diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index ddd65c7a7..7909657e0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's from __future__ import annotations +import asyncio import logging import time from collections import defaultdict @@ -41,6 +42,7 @@ logger = logging.getLogger(__name__) _MCP_CACHE_TTL_SECONDS = 300 # 5 minutes _MCP_CACHE_MAX_SIZE = 50 +_MCP_DISCOVERY_TIMEOUT_SECONDS = 30 _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} @@ -869,7 +871,7 @@ async def load_mcp_tools( multi_account_types, ) - tools: list[StructuredTool] = [] + discovery_tasks: list[dict[str, Any]] = [] for connector in connectors: try: cfg = connector.config or {} @@ -882,14 +884,10 @@ async def load_mcp_tools( ) continue - # For MCP OAuth connectors: refresh if needed, then decrypt the - # access token and inject it into headers at runtime. The DB - # intentionally does NOT store plaintext tokens in server_config. if cfg.get("mcp_oauth"): server_config = await _maybe_refresh_mcp_oauth_token( session, connector, cfg, server_config, ) - # Re-read cfg after potential refresh (connector was reloaded from DB). cfg = connector.config or {} server_config = _inject_oauth_headers(cfg, server_config) if server_config is None: @@ -911,7 +909,6 @@ async def load_mcp_tools( allowed_tools = svc_cfg.allowed_tools if svc_cfg else [] readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset() - # Build a prefix only when multiple accounts share the same type. tool_name_prefix: str | None = None if ct in multi_account_types and svc_cfg: service_key = next( @@ -921,34 +918,66 @@ async def load_mcp_tools( if service_key: tool_name_prefix = f"{service_key}_{connector.id}" - transport = server_config.get("transport", "stdio") - - if transport in ("streamable-http", "http", "sse"): - connector_tools = await _load_http_mcp_tools( - connector.id, - connector.name, - server_config, - trusted_tools=trusted_tools, - allowed_tools=allowed_tools, - readonly_tools=readonly_tools, - tool_name_prefix=tool_name_prefix, - ) - else: - connector_tools = await _load_stdio_mcp_tools( - connector.id, - connector.name, - server_config, - trusted_tools=trusted_tools, - ) - - tools.extend(connector_tools) + discovery_tasks.append({ + "connector_id": connector.id, + "connector_name": connector.name, + "server_config": server_config, + "trusted_tools": trusted_tools, + "allowed_tools": allowed_tools, + "readonly_tools": readonly_tools, + "tool_name_prefix": tool_name_prefix, + "transport": server_config.get("transport", "stdio"), + }) except Exception as e: logger.exception( - "Failed to load tools from MCP connector %d: %s", + "Failed to prepare MCP connector %d: %s", connector.id, e, ) + async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]: + try: + if task["transport"] in ("streamable-http", "http", "sse"): + return await asyncio.wait_for( + _load_http_mcp_tools( + task["connector_id"], + task["connector_name"], + task["server_config"], + trusted_tools=task["trusted_tools"], + allowed_tools=task["allowed_tools"], + readonly_tools=task["readonly_tools"], + tool_name_prefix=task["tool_name_prefix"], + ), + timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, + ) + else: + return await asyncio.wait_for( + _load_stdio_mcp_tools( + task["connector_id"], + task["connector_name"], + task["server_config"], + trusted_tools=task["trusted_tools"], + ), + timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + logger.error( + "MCP connector %d timed out after %ds during discovery", + task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS, + ) + return [] + except Exception as e: + logger.exception( + "Failed to load tools from MCP connector %d: %s", + task["connector_id"], e, + ) + return [] + + results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) + tools: list[StructuredTool] = [ + tool for sublist in results for tool in sublist + ] + _mcp_tools_cache[search_space_id] = (now, tools) if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE: