diff --git a/apps/rowboat_agents/src/graph/swarm_wrapper.py b/apps/rowboat_agents/src/graph/swarm_wrapper.py index 6f428a49..48b723cc 100644 --- a/apps/rowboat_agents/src/graph/swarm_wrapper.py +++ b/apps/rowboat_agents/src/graph/swarm_wrapper.py @@ -37,7 +37,7 @@ class NewResponse(BaseModel): tokens_used: Optional[dict] = {} error_msg: Optional[str] = "" -async def mock_tool(tool_name: str, args: str, tool_config: str) -> str: +async def mock_tool(tool_name: str, args: str, description: str, mock_instructions: str) -> str: """ Handles tool execution by either using mock instructions or generating a response. @@ -51,9 +51,6 @@ async def mock_tool(tool_name: str, args: str, tool_config: str) -> str: """ print(f"Mock tool called for: {tool_name}") - # For non-mocked tools, generate a realistic response - description = tool_config.get("description", "") - mock_instructions = tool_config.get("mockInstructions", "") messages = [ {"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'.Here is the description of the tool: {description}. Here are the instructions for the mock tool: {mock_instructions}. Generate a realistic response as if the tool was actually executed with the given parameters."}, @@ -140,9 +137,12 @@ async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool asyncio.set_event_loop(loop) response_content = None - if tool_config.get("mockTool", False): + if tool_config.get("mockTool", False) or complete_request.get("testProfile", {}).get("mockTools", False): # Call mock_tool to handle the response (it will decide whether to use mock instructions or generate a response) - response_content = await mock_tool(tool_name, args, tool_config) + if complete_request.get("testProfile", {}).get("mockPrompt", ""): + response_content = await mock_tool(tool_name, args, tool_config.get("description", ""), complete_request.get("testProfile", {}).get("mockPrompt", "")) + else: + response_content = await mock_tool(tool_name, args, tool_config.get("description", ""), tool_config.get("mockInstructions", "")) print(response_content) elif tool_config.get("isMcp", False): mcp_server_name = tool_config.get("mcpServerName", "")