add multi-account MCP tool disambiguation with prefix namespacing

This commit is contained in:
CREDO23 2026-04-22 18:57:43 +02:00
parent 9eb54bc4af
commit f2d9e67ac2

View file

@ -16,18 +16,21 @@ clicking "Always Allow", which adds the tool name to the connector's
import logging
import time
from collections import defaultdict
from typing import Any
from langchain_core.tools import StructuredTool
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, create_model
from sqlalchemy import select
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
logger = logging.getLogger(__name__)
@ -123,7 +126,7 @@ async def _create_mcp_tool_from_definition_stdio(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = hitl_result.params
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try:
async with mcp_client.connect():
@ -163,41 +166,57 @@ async def _create_mcp_tool_from_definition_http(
connector_name: str = "",
connector_id: int | None = None,
trusted_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
All MCP tools are unconditionally wrapped with HITL approval.
``request_approval()`` is called OUTSIDE the try/except so that
``GraphInterrupt`` propagates cleanly to LangGraph.
Write tools are wrapped with HITL approval; read-only tools (listed in
``readonly_tools``) execute immediately without user confirmation.
When ``tool_name_prefix`` is set (multi-account disambiguation), the
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
but the actual MCP ``call_tool`` still uses the original name.
"""
tool_name = tool_def.get("name", "unnamed_tool")
original_tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
exposed_name = (
f"{tool_name_prefix}_{original_tool_name}"
if tool_name_prefix
else original_tool_name
)
if tool_name_prefix:
tool_description = f"[Account: {connector_name}] {tool_description}"
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}")
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport."""
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}")
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = hitl_result.params
if is_readonly:
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
else:
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=exposed_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try:
async with (
@ -205,7 +224,9 @@ async def _create_mcp_tool_from_definition_http(
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(tool_name, arguments=call_kwargs)
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
)
result = []
for content in response.content:
@ -218,17 +239,17 @@ async def _create_mcp_tool_from_definition_http(
result_str = "\n".join(result) if result else ""
logger.info(
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}"
)
return result_str
except Exception as e:
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
logger.exception(error_msg)
return f"Error: {error_msg}"
tool = StructuredTool(
name=tool_name,
name=exposed_name,
description=tool_description,
coroutine=mcp_http_tool_call,
args_schema=input_model,
@ -236,12 +257,14 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema,
"mcp_transport": "http",
"mcp_url": url,
"hitl": True,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
"mcp_original_tool_name": original_tool_name,
"mcp_connector_id": connector_id,
},
)
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
logger.info(f"Created MCP tool (HTTP): '{exposed_name}'")
return tool
@ -309,8 +332,20 @@ async def _load_http_mcp_tools(
connector_name: str,
server_config: dict[str, Any],
trusted_tools: list[str] | None = None,
allowed_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server."""
"""Load tools from an HTTP-based MCP server.
Args:
allowed_tools: If non-empty, only tools whose names appear in this
list are loaded. Empty/None means load everything (used for
user-managed generic MCP servers).
readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``).
"""
tools: list[StructuredTool] = []
url = server_config.get("url")
@ -327,6 +362,8 @@ async def _load_http_mcp_tools(
)
return tools
allowed_set = set(allowed_tools) if allowed_tools else None
try:
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
@ -347,10 +384,21 @@ async def _load_http_mcp_tools(
}
)
logger.info(
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
f"'{url}' (connector {connector_id})"
)
total_discovered = len(tool_definitions)
if allowed_set:
tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions:
try:
@ -361,6 +409,8 @@ async def _load_http_mcp_tools(
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
tools.append(tool)
except Exception as e:
@ -532,15 +582,39 @@ async def load_mcp_tools(
try:
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.config.has_key("server_config"), # noqa: W601
cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
),
)
connectors = list(result.scalars())
# Group connectors by type to detect multi-account scenarios.
# When >1 connector shares the same type, tool names would collide
# so we prefix them with "{service_key}_{connector_id}_".
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
for connector in connectors:
ct = (
connector.connector_type.value
if hasattr(connector.connector_type, "value")
else str(connector.connector_type)
)
type_groups[ct].append(connector)
multi_account_types: set[str] = {
ct for ct, group in type_groups.items() if len(group) > 1
}
if multi_account_types:
logger.info(
"Multi-account detected for connector types: %s",
multi_account_types,
)
tools: list[StructuredTool] = []
for connector in result.scalars():
for connector in connectors:
try:
cfg = connector.config or {}
server_config = cfg.get("server_config", {})
@ -558,6 +632,27 @@ async def load_mcp_tools(
session, connector, cfg, server_config,
)
ct = (
connector.connector_type.value
if hasattr(connector.connector_type, "value")
else str(connector.connector_type)
)
# Resolve the allowlist from the service registry (if any).
svc_cfg = get_service_by_connector_type(ct)
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
# Build a prefix only when multiple accounts share the same type.
tool_name_prefix: str | None = None
if ct in multi_account_types and svc_cfg:
service_key = next(
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
None,
)
if service_key:
tool_name_prefix = f"{service_key}_{connector.id}"
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
@ -566,6 +661,9 @@ async def load_mcp_tools(
connector.name,
server_config,
trusted_tools=trusted_tools,
allowed_tools=allowed_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
else:
connector_tools = await _load_stdio_mcp_tools(