diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 4b204ffa9..89aa13620 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -314,6 +314,20 @@ async def create_surfsense_deep_agent( _t0 = time.perf_counter() _enabled_tool_names = {t.name for t in tools} _user_disabled_tool_names = set(disabled_tools) if disabled_tools else set() + + # Collect generic MCP connector info so the system prompt can route queries + # to their tools instead of falling back to "not in knowledge base". + _mcp_connector_tools: dict[str, list[str]] = {} + for t in tools: + meta = getattr(t, "metadata", None) or {} + if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"): + _mcp_connector_tools.setdefault( + meta["mcp_connector_name"], [], + ).append(t.name) + + if _mcp_connector_tools: + _perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools) + if agent_config is not None: system_prompt = build_configurable_system_prompt( custom_system_instructions=agent_config.system_instructions, @@ -322,12 +336,14 @@ async def create_surfsense_deep_agent( thread_visibility=thread_visibility, enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, + mcp_connector_tools=_mcp_connector_tools, ) else: system_prompt = build_surfsense_system_prompt( thread_visibility=thread_visibility, enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, + mcp_connector_tools=_mcp_connector_tools, ) _perf_log.info( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 3182735d9..e77132182 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format """ +def _build_mcp_routing_block( + mcp_connector_tools: dict[str, list[str]] | None, +) -> str: + """Build an additional tool routing block for generic MCP connectors. + + When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know + those tools exist and should be called directly — not searched in the + knowledge base. + """ + if not mcp_connector_tools: + return "" + + lines = [ + "\n", + "You also have direct tools from these user-connected MCP servers.", + "Their data is NEVER in the knowledge base — call their tools directly.", + "", + ] + for server_name, tool_names in mcp_connector_tools.items(): + lines.append(f"- {server_name} → {', '.join(tool_names)}") + lines.append("\n") + return "\n".join(lines) + + def build_surfsense_system_prompt( today: datetime | None = None, thread_visibility: ChatVisibility | None = None, enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, ) -> str: """ Build the SurfSense system prompt with default settings. @@ -834,6 +859,9 @@ def build_surfsense_system_prompt( thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. + mcp_connector_tools: Mapping of MCP server display name → list of tool names + for generic MCP connectors. Injected into the system prompt so the LLM + knows to call these tools directly. Returns: Complete system prompt string @@ -841,6 +869,7 @@ def build_surfsense_system_prompt( visibility = thread_visibility or ChatVisibility.PRIVATE system_instructions = _get_system_instructions(visibility, today) + system_instructions += _build_mcp_routing_block(mcp_connector_tools) tools_instructions = _get_tools_instructions( visibility, enabled_tool_names, disabled_tool_names ) @@ -856,6 +885,7 @@ def build_configurable_system_prompt( thread_visibility: ChatVisibility | None = None, enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, ) -> str: """ Build a configurable SurfSense system prompt based on NewLLMConfig settings. @@ -877,6 +907,9 @@ def build_configurable_system_prompt( thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. + mcp_connector_tools: Mapping of MCP server display name → list of tool names + for generic MCP connectors. Injected into the system prompt so the LLM + knows to call these tools directly. Returns: Complete system prompt string @@ -894,6 +927,8 @@ def build_configurable_system_prompt( else: system_instructions = "" + system_instructions += _build_mcp_routing_block(mcp_connector_tools) + # Tools instructions: only include enabled tools, note disabled ones tools_instructions = _get_tools_instructions( thread_visibility, enabled_tool_names, disabled_tool_names diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py index 44c48344c..b46ddbcc5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -45,6 +45,18 @@ class MCPClient: async def connect(self, max_retries: int = MAX_RETRIES): """Connect to the MCP server and manage its lifecycle. + Retries only apply to the **connection** phase (spawning the process, + initialising the session). Once the session is yielded to the caller, + any exception raised by the caller propagates normally -- the context + manager will NOT retry after ``yield``. + + Previous implementation wrapped both connection AND yield inside the + retry loop. Because ``@asynccontextmanager`` only allows a single + ``yield``, a failure after yield caused the generator to attempt a + second yield on retry, triggering + ``RuntimeError("generator didn't stop after athrow()")`` and orphaning + the stdio subprocess. + Args: max_retries: Maximum number of connection retry attempts @@ -57,26 +69,22 @@ class MCPClient: """ last_error = None delay = RETRY_DELAY + connected = False for attempt in range(max_retries): try: - # Merge env vars with current environment server_env = os.environ.copy() server_env.update(self.env) - # Create server parameters with env server_params = StdioServerParameters( command=self.command, args=self.args, env=server_env ) - # Spawn server process and create session - # Note: Cannot combine these context managers because ClientSession - # needs the read/write streams from stdio_client async with stdio_client(server=server_params) as (read, write): # noqa: SIM117 async with ClientSession(read, write) as session: - # Initialize the connection await session.initialize() self.session = session + connected = True if attempt > 0: logger.info( @@ -91,10 +99,16 @@ class MCPClient: self.command, " ".join(self.args), ) - yield session - return # Success, exit retry loop + try: + yield session + finally: + self.session = None + return except Exception as e: + self.session = None + if connected: + raise last_error = e if attempt < max_retries - 1: logger.warning( @@ -105,7 +119,7 @@ class MCPClient: delay, ) await asyncio.sleep(delay) - delay *= RETRY_BACKOFF # Exponential backoff + delay *= RETRY_BACKOFF else: logger.error( "Failed to connect to MCP server after %d attempts: %s", @@ -113,10 +127,7 @@ class MCPClient: e, exc_info=True, ) - finally: - self.session = None - # All retries exhausted error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts" if last_error: error_msg += f": {last_error}" @@ -161,12 +172,18 @@ class MCPClient: logger.error("Failed to list tools from MCP server: %s", e, exc_info=True) raise - async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + async def call_tool( + self, + tool_name: str, + arguments: dict[str, Any], + timeout: float = 60.0, + ) -> Any: """Call a tool on the MCP server. Args: tool_name: Name of the tool to call arguments: Arguments to pass to the tool + timeout: Maximum seconds to wait for the tool to respond Returns: Tool execution result @@ -185,10 +202,11 @@ class MCPClient: "Calling MCP tool '%s' with arguments: %s", tool_name, arguments ) - # Call tools/call RPC method - response = await self.session.call_tool(tool_name, arguments=arguments) + response = await asyncio.wait_for( + self.session.call_tool(tool_name, arguments=arguments), + timeout=timeout, + ) - # Extract content from response result = [] for content in response.content: if hasattr(content, "text"): @@ -202,15 +220,17 @@ class MCPClient: logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) return result_str + except asyncio.TimeoutError: + logger.error( + "MCP tool '%s' timed out after %.0fs", tool_name, timeout + ) + return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s" except RuntimeError as e: - # Handle validation errors from MCP server responses - # Some MCP servers (like server-memory) return extra fields not in their schema if "Invalid structured content" in str(e): logger.warning( "MCP server returned data not matching its schema, but continuing: %s", e, ) - # Try to extract result from error message or return a success message return "Operation completed (server returned unexpected format)" raise except (ValueError, TypeError, AttributeError, KeyError) as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index b0dcd72b6..dfee24516 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from langchain_core.tools import StructuredTool from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, create_model from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession @@ -43,6 +43,8 @@ logger = logging.getLogger(__name__) _MCP_CACHE_TTL_SECONDS = 300 # 5 minutes _MCP_CACHE_MAX_SIZE = 50 _MCP_DISCOVERY_TIMEOUT_SECONDS = 30 +_TOOL_CALL_MAX_RETRIES = 3 +_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} @@ -64,7 +66,18 @@ def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], ) -> type[BaseModel]: - """Create a Pydantic model from MCP tool's JSON schema.""" + """Create a Pydantic model from MCP tool's JSON schema. + + Models always allow extra fields (``extra="allow"``) so that parameters + missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema`` + producing an empty ``$schema``-only object) can still be forwarded to the + MCP server. + + When the schema declares **no** properties, a synthetic ``input_data`` + field of type ``dict`` is injected so the LLM has a visible parameter to + populate. The caller should unpack ``input_data`` before forwarding to + the MCP server (see ``_unpack_synthetic_input_data``). + """ properties = input_schema.get("properties", {}) required_fields = input_schema.get("required", []) @@ -84,8 +97,35 @@ def _create_dynamic_input_model_from_schema( Field(None, description=param_description), ) + if not properties: + field_definitions["input_data"] = ( + dict[str, Any] | None, + Field( + None, + description=( + "Arguments to pass to this tool as a JSON object. " + "Infer sensible key names from the tool name and description " + "(e.g. {\"search\": \"my query\"} for a search tool)." + ), + ), + ) + model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" - return create_model(model_name, **field_definitions) + model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions) + return model + + +def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]: + """Unpack the synthetic ``input_data`` field into top-level kwargs. + + When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema`` + adds a catch-all ``input_data: dict`` field. This helper merges that dict + back into the top-level kwargs so the MCP server receives flat arguments. + """ + input_data = kwargs.pop("input_data", None) + if isinstance(input_data, dict): + kwargs.update(input_data) + return kwargs async def _create_mcp_tool_from_definition_stdio( @@ -103,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio( ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") - tool_description = tool_def.get("description", "No description provided") + raw_description = tool_def.get("description", "No description provided") + tool_description = ( + f"[MCP server: {connector_name}] {raw_description}" + if connector_name + else raw_description + ) input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema) @@ -121,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio( params=kwargs, context={ "mcp_server": connector_name, - "tool_description": tool_description, + "tool_description": raw_description, "mcp_transport": "stdio", "mcp_connector_id": connector_id, }, @@ -129,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) - try: - async with mcp_client.connect(): - result = await mcp_client.call_tool(tool_name, call_kwargs) - return str(result) - except RuntimeError as e: - logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e) - return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}" - except Exception as e: - logger.exception("MCP tool '%s' execution failed: %s", tool_name, e) - return f"Error: MCP tool '{tool_name}' execution failed: {e!s}" + last_error: Exception | None = None + for attempt in range(_TOOL_CALL_MAX_RETRIES): + try: + async with mcp_client.connect(): + result = await mcp_client.call_tool(tool_name, call_kwargs) + return str(result) + except Exception as e: + last_error = e + if attempt < _TOOL_CALL_MAX_RETRIES - 1: + delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt) + logger.warning( + "MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...", + tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay, + ) + await asyncio.sleep(delay) + else: + logger.error( + "MCP tool '%s' failed after %d attempts: %s", + tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True, + ) + + return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}" tool = StructuredTool( name=tool_name, @@ -150,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio( metadata={ "mcp_input_schema": input_schema, "mcp_transport": "stdio", + "mcp_connector_name": connector_name or None, + "mcp_is_generic": True, "hitl": True, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), }, @@ -169,6 +230,7 @@ async def _create_mcp_tool_from_definition_http( trusted_tools: list[str] | None = None, readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, + is_generic_mcp: bool = False, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). @@ -180,7 +242,7 @@ async def _create_mcp_tool_from_definition_http( but the actual MCP ``call_tool`` still uses the original name. """ original_tool_name = tool_def.get("name", "unnamed_tool") - tool_description = tool_def.get("description", "No description provided") + raw_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) is_readonly = readonly_tools is not None and original_tool_name in readonly_tools @@ -190,7 +252,11 @@ async def _create_mcp_tool_from_definition_http( else original_tool_name ) if tool_name_prefix: - tool_description = f"[Account: {connector_name}] {tool_description}" + tool_description = f"[Account: {connector_name}] {raw_description}" + elif is_generic_mcp and connector_name: + tool_description = f"[MCP server: {connector_name}] {raw_description}" + else: + tool_description = raw_description logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema) @@ -199,6 +265,7 @@ async def _create_mcp_tool_from_definition_http( async def _do_mcp_call( call_headers: dict[str, str], call_kwargs: dict[str, Any], + timeout: float = 60.0, ) -> str: """Execute a single MCP HTTP call with the given headers.""" async with ( @@ -206,8 +273,9 @@ async def _create_mcp_tool_from_definition_http( ClientSession(read, write) as session, ): await session.initialize() - response = await session.call_tool( - original_tool_name, arguments=call_kwargs, + response = await asyncio.wait_for( + session.call_tool(original_tool_name, arguments=call_kwargs), + timeout=timeout, ) result = [] @@ -226,7 +294,9 @@ async def _create_mcp_tool_from_definition_http( logger.debug("MCP HTTP tool '%s' called", exposed_name) if is_readonly: - call_kwargs = {k: v for k, v in kwargs.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in kwargs.items() if v is not None} + ) else: hitl_result = request_approval( action_type="mcp_tool_call", @@ -234,7 +304,7 @@ async def _create_mcp_tool_from_definition_http( params=kwargs, context={ "mcp_server": connector_name, - "tool_description": tool_description, + "tool_description": raw_description, "mcp_transport": "http", "mcp_connector_id": connector_id, }, @@ -242,7 +312,9 @@ async def _create_mcp_tool_from_definition_http( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) try: result_str = await _do_mcp_call(headers, call_kwargs) @@ -295,6 +367,8 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, + "mcp_connector_name": connector_name or None, + "mcp_is_generic": is_generic_mcp, "hitl": not is_readonly, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), "mcp_original_tool_name": original_tool_name, @@ -376,6 +450,7 @@ async def _load_http_mcp_tools( allowed_tools: list[str] | None = None, readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, + is_generic_mcp: bool = False, ) -> list[StructuredTool]: """Load tools from an HTTP-based MCP server. @@ -492,6 +567,7 @@ async def _load_http_mcp_tools( trusted_tools=trusted_tools, readonly_tools=readonly_tools, tool_name_prefix=tool_name_prefix, + is_generic_mcp=is_generic_mcp, ) tools.append(tool) except Exception as e: @@ -928,6 +1004,7 @@ async def load_mcp_tools( "readonly_tools": readonly_tools, "tool_name_prefix": tool_name_prefix, "transport": server_config.get("transport", "stdio"), + "is_generic_mcp": svc_cfg is None, }) except Exception as e: @@ -948,6 +1025,7 @@ async def load_mcp_tools( allowed_tools=task["allowed_tools"], readonly_tools=task["readonly_tools"], tool_name_prefix=task["tool_name_prefix"], + is_generic_mcp=task.get("is_generic_mcp", False), ), timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, )