diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py index 437f93043..d4dbe2a0c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -4,6 +4,7 @@ This module provides a client for communicating with MCP servers via stdio trans It handles server lifecycle management, tool discovery, and tool execution. """ +import asyncio import logging import os from contextlib import asynccontextmanager @@ -14,6 +15,11 @@ from mcp.client.stdio import StdioServerParameters, stdio_client logger = logging.getLogger(__name__) +# Retry configuration +MAX_RETRIES = 3 +RETRY_DELAY = 1.0 # seconds +RETRY_BACKOFF = 2.0 # exponential backoff multiplier + class MCPClient: """Client for communicating with an MCP server.""" @@ -35,44 +41,86 @@ class MCPClient: self.session: ClientSession | None = None @asynccontextmanager - async def connect(self): + async def connect(self, max_retries: int = MAX_RETRIES): """Connect to the MCP server and manage its lifecycle. + Args: + max_retries: Maximum number of connection retry attempts + Yields: ClientSession: Active MCP session for making requests + Raises: + RuntimeError: If all connection attempts fail + """ - try: - # Merge env vars with current environment - server_env = os.environ.copy() - server_env.update(self.env) + last_error = None + delay = RETRY_DELAY - # Create server parameters with env - server_params = StdioServerParameters( - command=self.command, args=self.args, env=server_env - ) + for attempt in range(max_retries): + try: + # Merge env vars with current environment + server_env = os.environ.copy() + server_env.update(self.env) - # Spawn server process and create session - # Note: Cannot combine these context managers because ClientSession - # needs the read/write streams from stdio_client - async with stdio_client(server=server_params) as (read, write): # noqa: SIM117 - async with ClientSession(read, write) as session: - # Initialize the connection - await session.initialize() - self.session = session - logger.info( - "Connected to MCP server: %s %s", - self.command, - " ".join(self.args), + # Create server parameters with env + server_params = StdioServerParameters( + command=self.command, args=self.args, env=server_env + ) + + # Spawn server process and create session + # Note: Cannot combine these context managers because ClientSession + # needs the read/write streams from stdio_client + async with stdio_client(server=server_params) as (read, write): # noqa: SIM117 + async with ClientSession(read, write) as session: + # Initialize the connection + await session.initialize() + self.session = session + + if attempt > 0: + logger.info( + "Connected to MCP server on attempt %d: %s %s", + attempt + 1, + self.command, + " ".join(self.args), + ) + else: + logger.info( + "Connected to MCP server: %s %s", + self.command, + " ".join(self.args), + ) + yield session + return # Success, exit retry loop + + except Exception as e: + last_error = e + if attempt < max_retries - 1: + logger.warning( + "MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...", + attempt + 1, + max_retries, + e, + delay, ) - yield session + await asyncio.sleep(delay) + delay *= RETRY_BACKOFF # Exponential backoff + else: + logger.error( + "Failed to connect to MCP server after %d attempts: %s", + max_retries, + e, + exc_info=True, + ) + finally: + self.session = None - except Exception as e: - logger.error("Failed to connect to MCP server: %s", e, exc_info=True) - raise - finally: - self.session = None - logger.info("Disconnected from MCP server: %s", self.command) + # All retries exhausted + error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts" + if last_error: + error_msg += f": {last_error}" + logger.error(error_msg) + raise RuntimeError(error_msg) from last_error async def list_tools(self) -> list[dict[str, Any]]: """List all tools available from the MCP server. diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 50339fb93..d7c9210af 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -90,16 +90,22 @@ async def _create_mcp_tool_from_definition( input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_tool_call(**kwargs) -> str: - """Execute the MCP tool call via the client.""" + """Execute the MCP tool call via the client with retry support.""" logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") try: - # Connect to server and call tool + # Connect to server and call tool (connect has built-in retry logic) async with mcp_client.connect(): result = await mcp_client.call_tool(tool_name, kwargs) return str(result) + except RuntimeError as e: + # Connection failures after all retries + error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}" + logger.error(error_msg) + return f"Error: {error_msg}" except Exception as e: - error_msg = f"MCP tool '{tool_name}' failed: {e!s}" + # Tool execution or other errors + error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}" logger.exception(error_msg) return f"Error: {error_msg}" @@ -146,21 +152,38 @@ async def load_mcp_tools( tools: list[StructuredTool] = [] for connector in result.scalars(): try: - # Extract single server config + # Early validation: Extract and validate connector config config = connector.config or {} server_config = config.get("server_config", {}) - if not server_config: - logger.warning(f"MCP connector {connector.id} missing server_config, skipping") + # Validate server_config exists and is a dict + if not server_config or not isinstance(server_config, dict): + logger.warning( + f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping" + ) continue + # Validate required command field command = server_config.get("command") - args = server_config.get("args", []) - env = server_config.get("env", {}) - - if not command: + if not command or not isinstance(command, str): logger.warning( - f"MCP connector {connector.id} missing command, skipping" + f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping" + ) + continue + + # Validate args field (must be list if present) + args = server_config.get("args", []) + if not isinstance(args, list): + logger.warning( + f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping" + ) + continue + + # Validate env field (must be dict if present) + env = server_config.get("env", {}) + if not isinstance(env, dict): + logger.warning( + f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping" ) continue