add multi-account MCP tool disambiguation with prefix namespacing

This commit is contained in:
CREDO23 2026-04-22 18:57:43 +02:00
parent 9eb54bc4af
commit f2d9e67ac2

View file

@ -16,18 +16,21 @@ clicking "Always Allow", which adds the tool name to the connector's
import logging import logging
import time import time
from collections import defaultdict
from typing import Any from typing import Any
from langchain_core.tools import StructuredTool from langchain_core.tools import StructuredTool
from mcp import ClientSession from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from sqlalchemy import select from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -123,7 +126,7 @@ async def _create_mcp_tool_from_definition_stdio(
) )
if hitl_result.rejected: if hitl_result.rejected:
return "Tool call rejected by user." return "Tool call rejected by user."
call_kwargs = hitl_result.params call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try: try:
async with mcp_client.connect(): async with mcp_client.connect():
@ -163,41 +166,57 @@ async def _create_mcp_tool_from_definition_http(
connector_name: str = "", connector_name: str = "",
connector_id: int | None = None, connector_id: int | None = None,
trusted_tools: list[str] | None = None, trusted_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
) -> StructuredTool: ) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport). """Create a LangChain tool from an MCP tool definition (HTTP transport).
All MCP tools are unconditionally wrapped with HITL approval. Write tools are wrapped with HITL approval; read-only tools (listed in
``request_approval()`` is called OUTSIDE the try/except so that ``readonly_tools``) execute immediately without user confirmation.
``GraphInterrupt`` propagates cleanly to LangGraph.
When ``tool_name_prefix`` is set (multi-account disambiguation), the
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
but the actual MCP ``call_tool`` still uses the original name.
""" """
tool_name = tool_def.get("name", "unnamed_tool") original_tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided") tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}") exposed_name = (
f"{tool_name_prefix}_{original_tool_name}"
if tool_name_prefix
else original_tool_name
)
if tool_name_prefix:
tool_description = f"[Account: {connector_name}] {tool_description}"
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}")
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
async def mcp_http_tool_call(**kwargs) -> str: async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport.""" """Execute the MCP tool call via HTTP transport."""
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}") logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}")
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph if is_readonly:
hitl_result = request_approval( call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
action_type="mcp_tool_call", else:
tool_name=tool_name, hitl_result = request_approval(
params=kwargs, action_type="mcp_tool_call",
context={ tool_name=exposed_name,
"mcp_server": connector_name, params=kwargs,
"tool_description": tool_description, context={
"mcp_transport": "http", "mcp_server": connector_name,
"mcp_connector_id": connector_id, "tool_description": tool_description,
}, "mcp_transport": "http",
trusted_tools=trusted_tools, "mcp_connector_id": connector_id,
) },
if hitl_result.rejected: trusted_tools=trusted_tools,
return "Tool call rejected by user." )
call_kwargs = hitl_result.params if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try: try:
async with ( async with (
@ -205,7 +224,9 @@ async def _create_mcp_tool_from_definition_http(
ClientSession(read, write) as session, ClientSession(read, write) as session,
): ):
await session.initialize() await session.initialize()
response = await session.call_tool(tool_name, arguments=call_kwargs) response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
)
result = [] result = []
for content in response.content: for content in response.content:
@ -218,17 +239,17 @@ async def _create_mcp_tool_from_definition_http(
result_str = "\n".join(result) if result else "" result_str = "\n".join(result) if result else ""
logger.info( logger.info(
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}" f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}"
) )
return result_str return result_str
except Exception as e: except Exception as e:
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}" error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
logger.exception(error_msg) logger.exception(error_msg)
return f"Error: {error_msg}" return f"Error: {error_msg}"
tool = StructuredTool( tool = StructuredTool(
name=tool_name, name=exposed_name,
description=tool_description, description=tool_description,
coroutine=mcp_http_tool_call, coroutine=mcp_http_tool_call,
args_schema=input_model, args_schema=input_model,
@ -236,12 +257,14 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema, "mcp_input_schema": input_schema,
"mcp_transport": "http", "mcp_transport": "http",
"mcp_url": url, "mcp_url": url,
"hitl": True, "hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None), "hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
"mcp_original_tool_name": original_tool_name,
"mcp_connector_id": connector_id,
}, },
) )
logger.info(f"Created MCP tool (HTTP): '{tool_name}'") logger.info(f"Created MCP tool (HTTP): '{exposed_name}'")
return tool return tool
@ -309,8 +332,20 @@ async def _load_http_mcp_tools(
connector_name: str, connector_name: str,
server_config: dict[str, Any], server_config: dict[str, Any],
trusted_tools: list[str] | None = None, trusted_tools: list[str] | None = None,
allowed_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
) -> list[StructuredTool]: ) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.""" """Load tools from an HTTP-based MCP server.
Args:
allowed_tools: If non-empty, only tools whose names appear in this
list are loaded. Empty/None means load everything (used for
user-managed generic MCP servers).
readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``).
"""
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
url = server_config.get("url") url = server_config.get("url")
@ -327,6 +362,8 @@ async def _load_http_mcp_tools(
) )
return tools return tools
allowed_set = set(allowed_tools) if allowed_tools else None
try: try:
async with ( async with (
streamablehttp_client(url, headers=headers) as (read, write, _), streamablehttp_client(url, headers=headers) as (read, write, _),
@ -347,10 +384,21 @@ async def _load_http_mcp_tools(
} }
) )
logger.info( total_discovered = len(tool_definitions)
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
f"'{url}' (connector {connector_id})" if allowed_set:
) tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions: for tool_def in tool_definitions:
try: try:
@ -361,6 +409,8 @@ async def _load_http_mcp_tools(
connector_name=connector_name, connector_name=connector_name,
connector_id=connector_id, connector_id=connector_id,
trusted_tools=trusted_tools, trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
) )
tools.append(tool) tools.append(tool)
except Exception as e: except Exception as e:
@ -532,15 +582,39 @@ async def load_mcp_tools(
try: try:
# Find all connectors with MCP server config: generic MCP_CONNECTOR type # Find all connectors with MCP server config: generic MCP_CONNECTOR type
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth. # and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
result = await session.execute( result = await session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.config.has_key("server_config"), # noqa: W601 cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
), ),
) )
connectors = list(result.scalars())
# Group connectors by type to detect multi-account scenarios.
# When >1 connector shares the same type, tool names would collide
# so we prefix them with "{service_key}_{connector_id}_".
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
for connector in connectors:
ct = (
connector.connector_type.value
if hasattr(connector.connector_type, "value")
else str(connector.connector_type)
)
type_groups[ct].append(connector)
multi_account_types: set[str] = {
ct for ct, group in type_groups.items() if len(group) > 1
}
if multi_account_types:
logger.info(
"Multi-account detected for connector types: %s",
multi_account_types,
)
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
for connector in result.scalars(): for connector in connectors:
try: try:
cfg = connector.config or {} cfg = connector.config or {}
server_config = cfg.get("server_config", {}) server_config = cfg.get("server_config", {})
@ -558,6 +632,27 @@ async def load_mcp_tools(
session, connector, cfg, server_config, session, connector, cfg, server_config,
) )
ct = (
connector.connector_type.value
if hasattr(connector.connector_type, "value")
else str(connector.connector_type)
)
# Resolve the allowlist from the service registry (if any).
svc_cfg = get_service_by_connector_type(ct)
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
# Build a prefix only when multiple accounts share the same type.
tool_name_prefix: str | None = None
if ct in multi_account_types and svc_cfg:
service_key = next(
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
None,
)
if service_key:
tool_name_prefix = f"{service_key}_{connector.id}"
transport = server_config.get("transport", "stdio") transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"): if transport in ("streamable-http", "http", "sse"):
@ -566,6 +661,9 @@ async def load_mcp_tools(
connector.name, connector.name,
server_config, server_config,
trusted_tools=trusted_tools, trusted_tools=trusted_tools,
allowed_tools=allowed_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
) )
else: else:
connector_tools = await _load_stdio_mcp_tools( connector_tools = await _load_stdio_mcp_tools(