diff --git a/apps/agents/src/graph/swarm_wrapper.py b/apps/agents/src/graph/swarm_wrapper.py index 161c598a..9f113ccd 100644 --- a/apps/agents/src/graph/swarm_wrapper.py +++ b/apps/agents/src/graph/swarm_wrapper.py @@ -1,6 +1,6 @@ import logging import json - +import aiohttp # Import helper functions needed for get_agents from .helpers.access import ( get_tool_config_by_name, @@ -17,6 +17,9 @@ from typing import Any # Create a dedicated logger for swarm wrapper #logger = logging.getLogger("swarm_wrapper") #logger.setLevel(logging.INFO) +import asyncio +from mcp import ClientSession +from mcp.client.sse import sse_client from pydantic import BaseModel from typing import List, Optional, Dict @@ -27,22 +30,117 @@ class NewResponse(BaseModel): tokens_used: Optional[dict] = {} error_msg: Optional[str] = "" +async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str: + """ + Handles tool execution by either using mock instructions or generating a response. + + Args: + tool_name: The name of the tool + args: The arguments passed to the tool + tool_config: The configuration of the tool + + Returns: + The response from the tool + """ + print(f"Mock tool called for: {tool_name}") + + # For non-mocked tools, generate a realistic response + description = mock_instructions + + messages = [ + {"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. Here are the mock instructions: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."}, + {"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."} + ] + + print(f"Generating simulated response for tool: {tool_name}") + response_content = generate_openai_output(messages, output_type='text', model="gpt-4o") + return response_content + +async def call_webhook(tool_name: str, args: str) -> str: + """ + Calls the webhook with the given tool name and arguments. + + Args: + tool_name (str): The name of the tool to call. + args (str): The arguments for the tool as a JSON string. + + Returns: + str: The response from the webhook, or an error message if the call fails. + """ + webhook_url = "http://localhost:4020/tool_call" + content_dict = { + "toolCall": { + "function": { + "name": tool_name, + "arguments": args # Assumes args is a valid JSON string + } + } + } + request_body = { + "content": json.dumps(content_dict) + } + try: + async with aiohttp.ClientSession() as session: + async with session.post(webhook_url, json=request_body) as response: + if response.status == 200: + response_json = await response.json() + return response_json.get("result", "") + else: + error_msg = await response.text() + print(f"Webhook error: {error_msg}") + return f"Error: {error_msg}" + except Exception as e: + print(f"Exception in call_webhook: {str(e)}") + return f"Error: Failed to call webhook - {str(e)}" + +async def call_mcp(tool_name: str, args: str, mcp_server_name: str, mcp_servers: dict) -> str: + """ + Calls the MCP with the given tool name and arguments. + """ + server_url = "http://localhost:8000/sse" #mcp_servers.get(tool_name, None) + print(args) + async with sse_client(url=server_url) as streams: + # Create a client session using the SSE streams + async with ClientSession(*streams) as session: + # Initialize the session (perform handshake with the server) + await session.initialize() + # Call the tool on the server and await the response + response = await session.call_tool(tool_name, arguments=json.loads(args)) + + # Print the response received from the server + print("Server response:", response) + + return response + async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict) -> str: + """ + Handles all tool calls by dispatching to the mock_tool function. + + Args: + ctx: The run context wrapper + args: The arguments passed to the tool + tool_name: The name of the tool being called + tool_config: The configuration of the tool + + Returns: + The response from the tool + """ print(f"Catch all called for tool: {tool_name}") print(f"Args: {args}") print(f"Tool config: {tool_config}") - #if tool_config.get("mock", False): - #& return tool_config.get("mockInstructions", "No mock instructions provided") - description = tool_config.get("description", "") - - messages = [ - {"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. The tool has this description: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."}, - {"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."} - - ] - response_content = generate_openai_output(messages, output_type='text', model="gpt-4o") - print(response_content) - return(response_content) + response_content = None + # Check if this tool should be mocked + if tool_config.get("mockTool", False): + # Call mock_tool to handle the response (it will decide whether to use mock instructions or generate a response) + response_content = await mock_tool(tool_name, args, tool_config) + print(response_content) + elif tool_config.get("isMcp", False): + response_content = await call_mcp(tool_name, args, tool_config.get("mcpServerName", ""), {}) + print(response_content) + else: + response_content = await call_webhook(tool_name, args) + print(response_content) + return response_content def get_agents(agent_configs, tool_configs): """ diff --git a/apps/agents/tests/sample_requests/default_example.json b/apps/agents/tests/sample_requests/default_example.json index 7dd1c4ec..8fc308c6 100644 --- a/apps/agents/tests/sample_requests/default_example.json +++ b/apps/agents/tests/sample_requests/default_example.json @@ -106,7 +106,9 @@ "required": [ "order_id" ] - } + }, + "mockTool": true, + "mockInstructions": "Return a mock response for Door Dash order details." }, { "name": "get_delivery_status",