From f2d9e67ac2fd92854a6baafe24939635d25b9602 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 18:57:43 +0200 Subject: [PATCH] add multi-account MCP tool disambiguation with prefix namespacing --- .../app/agents/new_chat/tools/mcp_tool.py | 174 ++++++++++++++---- 1 file changed, 136 insertions(+), 38 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 47ee16f7d..62ef56dd7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -16,18 +16,21 @@ clicking "Always Allow", which adds the tool name to the connector's import logging import time +from collections import defaultdict from typing import Any from langchain_core.tools import StructuredTool from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client from pydantic import BaseModel, create_model -from sqlalchemy import select +from sqlalchemy import cast, select +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type logger = logging.getLogger(__name__) @@ -123,7 +126,7 @@ async def _create_mcp_tool_from_definition_stdio( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = hitl_result.params + call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} try: async with mcp_client.connect(): @@ -163,41 +166,57 @@ async def _create_mcp_tool_from_definition_http( connector_name: str = "", connector_id: int | None = None, trusted_tools: list[str] | None = None, + readonly_tools: frozenset[str] | None = None, + tool_name_prefix: str | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). - All MCP tools are unconditionally wrapped with HITL approval. - ``request_approval()`` is called OUTSIDE the try/except so that - ``GraphInterrupt`` propagates cleanly to LangGraph. + Write tools are wrapped with HITL approval; read-only tools (listed in + ``readonly_tools``) execute immediately without user confirmation. + + When ``tool_name_prefix`` is set (multi-account disambiguation), the + tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``) + but the actual MCP ``call_tool`` still uses the original name. """ - tool_name = tool_def.get("name", "unnamed_tool") + original_tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) + is_readonly = readonly_tools is not None and original_tool_name in readonly_tools - logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}") + exposed_name = ( + f"{tool_name_prefix}_{original_tool_name}" + if tool_name_prefix + else original_tool_name + ) + if tool_name_prefix: + tool_description = f"[Account: {connector_name}] {tool_description}" - input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) + logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}") + + input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema) async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" - logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}") + logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}") - # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph - hitl_result = request_approval( - action_type="mcp_tool_call", - tool_name=tool_name, - params=kwargs, - context={ - "mcp_server": connector_name, - "tool_description": tool_description, - "mcp_transport": "http", - "mcp_connector_id": connector_id, - }, - trusted_tools=trusted_tools, - ) - if hitl_result.rejected: - return "Tool call rejected by user." - call_kwargs = hitl_result.params + if is_readonly: + call_kwargs = {k: v for k, v in kwargs.items() if v is not None} + else: + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=exposed_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "http", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} try: async with ( @@ -205,7 +224,9 @@ async def _create_mcp_tool_from_definition_http( ClientSession(read, write) as session, ): await session.initialize() - response = await session.call_tool(tool_name, arguments=call_kwargs) + response = await session.call_tool( + original_tool_name, arguments=call_kwargs, + ) result = [] for content in response.content: @@ -218,17 +239,17 @@ async def _create_mcp_tool_from_definition_http( result_str = "\n".join(result) if result else "" logger.info( - f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}" + f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}" ) return result_str except Exception as e: - error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}" + error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}" logger.exception(error_msg) return f"Error: {error_msg}" tool = StructuredTool( - name=tool_name, + name=exposed_name, description=tool_description, coroutine=mcp_http_tool_call, args_schema=input_model, @@ -236,12 +257,14 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, - "hitl": True, + "hitl": not is_readonly, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + "mcp_original_tool_name": original_tool_name, + "mcp_connector_id": connector_id, }, ) - logger.info(f"Created MCP tool (HTTP): '{tool_name}'") + logger.info(f"Created MCP tool (HTTP): '{exposed_name}'") return tool @@ -309,8 +332,20 @@ async def _load_http_mcp_tools( connector_name: str, server_config: dict[str, Any], trusted_tools: list[str] | None = None, + allowed_tools: list[str] | None = None, + readonly_tools: frozenset[str] | None = None, + tool_name_prefix: str | None = None, ) -> list[StructuredTool]: - """Load tools from an HTTP-based MCP server.""" + """Load tools from an HTTP-based MCP server. + + Args: + allowed_tools: If non-empty, only tools whose names appear in this + list are loaded. Empty/None means load everything (used for + user-managed generic MCP servers). + readonly_tools: Tool names that skip HITL approval (read-only operations). + tool_name_prefix: If set, each tool name is prefixed for multi-account + disambiguation (e.g. ``linear_25``). + """ tools: list[StructuredTool] = [] url = server_config.get("url") @@ -327,6 +362,8 @@ async def _load_http_mcp_tools( ) return tools + allowed_set = set(allowed_tools) if allowed_tools else None + try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), @@ -347,10 +384,21 @@ async def _load_http_mcp_tools( } ) - logger.info( - f"Discovered {len(tool_definitions)} tools from HTTP MCP server " - f"'{url}' (connector {connector_id})" - ) + total_discovered = len(tool_definitions) + + if allowed_set: + tool_definitions = [ + td for td in tool_definitions if td["name"] in allowed_set + ] + logger.info( + "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", + url, connector_id, len(tool_definitions), total_discovered, + ) + else: + logger.info( + "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", + total_discovered, url, connector_id, + ) for tool_def in tool_definitions: try: @@ -361,6 +409,8 @@ async def _load_http_mcp_tools( connector_name=connector_name, connector_id=connector_id, trusted_tools=trusted_tools, + readonly_tools=readonly_tools, + tool_name_prefix=tool_name_prefix, ) tools.append(tool) except Exception as e: @@ -532,15 +582,39 @@ async def load_mcp_tools( try: # Find all connectors with MCP server config: generic MCP_CONNECTOR type # and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth. + # Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config". result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.config.has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601 ), ) + connectors = list(result.scalars()) + + # Group connectors by type to detect multi-account scenarios. + # When >1 connector shares the same type, tool names would collide + # so we prefix them with "{service_key}_{connector_id}_". + type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list) + for connector in connectors: + ct = ( + connector.connector_type.value + if hasattr(connector.connector_type, "value") + else str(connector.connector_type) + ) + type_groups[ct].append(connector) + + multi_account_types: set[str] = { + ct for ct, group in type_groups.items() if len(group) > 1 + } + if multi_account_types: + logger.info( + "Multi-account detected for connector types: %s", + multi_account_types, + ) + tools: list[StructuredTool] = [] - for connector in result.scalars(): + for connector in connectors: try: cfg = connector.config or {} server_config = cfg.get("server_config", {}) @@ -558,6 +632,27 @@ async def load_mcp_tools( session, connector, cfg, server_config, ) + ct = ( + connector.connector_type.value + if hasattr(connector.connector_type, "value") + else str(connector.connector_type) + ) + + # Resolve the allowlist from the service registry (if any). + svc_cfg = get_service_by_connector_type(ct) + allowed_tools = svc_cfg.allowed_tools if svc_cfg else [] + readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset() + + # Build a prefix only when multiple accounts share the same type. + tool_name_prefix: str | None = None + if ct in multi_account_types and svc_cfg: + service_key = next( + (k for k, v in MCP_SERVICES.items() if v is svc_cfg), + None, + ) + if service_key: + tool_name_prefix = f"{service_key}_{connector.id}" + transport = server_config.get("transport", "stdio") if transport in ("streamable-http", "http", "sse"): @@ -566,6 +661,9 @@ async def load_mcp_tools( connector.name, server_config, trusted_tools=trusted_tools, + allowed_tools=allowed_tools, + readonly_tools=readonly_tools, + tool_name_prefix=tool_name_prefix, ) else: connector_tools = await _load_stdio_mcp_tools(