mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
add multi-account MCP tool disambiguation with prefix namespacing
This commit is contained in:
parent
9eb54bc4af
commit
f2d9e67ac2
1 changed files with 136 additions and 38 deletions
|
|
@ -16,18 +16,21 @@ clicking "Always Allow", which adds the tool name to the connector's
|
|||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -123,7 +126,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
|
||||
try:
|
||||
async with mcp_client.connect():
|
||||
|
|
@ -163,41 +166,57 @@ async def _create_mcp_tool_from_definition_http(
|
|||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
Write tools are wrapped with HITL approval; read-only tools (listed in
|
||||
``readonly_tools``) execute immediately without user confirmation.
|
||||
|
||||
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
||||
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
||||
but the actual MCP ``call_tool`` still uses the original name.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
exposed_name = (
|
||||
f"{tool_name_prefix}_{original_tool_name}"
|
||||
if tool_name_prefix
|
||||
else original_tool_name
|
||||
)
|
||||
if tool_name_prefix:
|
||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}")
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}")
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
if is_readonly:
|
||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
else:
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=exposed_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
|
||||
try:
|
||||
async with (
|
||||
|
|
@ -205,7 +224,9 @@ async def _create_mcp_tool_from_definition_http(
|
|||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(tool_name, arguments=call_kwargs)
|
||||
response = await session.call_tool(
|
||||
original_tool_name, arguments=call_kwargs,
|
||||
)
|
||||
|
||||
result = []
|
||||
for content in response.content:
|
||||
|
|
@ -218,17 +239,17 @@ async def _create_mcp_tool_from_definition_http(
|
|||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info(
|
||||
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
|
||||
f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}"
|
||||
)
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
|
||||
error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
name=exposed_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_http_tool_call,
|
||||
args_schema=input_model,
|
||||
|
|
@ -236,12 +257,14 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
"hitl": True,
|
||||
"hitl": not is_readonly,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
"mcp_original_tool_name": original_tool_name,
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
|
||||
logger.info(f"Created MCP tool (HTTP): '{exposed_name}'")
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -309,8 +332,20 @@ async def _load_http_mcp_tools(
|
|||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server."""
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
allowed_tools: If non-empty, only tools whose names appear in this
|
||||
list are loaded. Empty/None means load everything (used for
|
||||
user-managed generic MCP servers).
|
||||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||
disambiguation (e.g. ``linear_25``).
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
url = server_config.get("url")
|
||||
|
|
@ -327,6 +362,8 @@ async def _load_http_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
|
|
@ -347,10 +384,21 @@ async def _load_http_mcp_tools(
|
|||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
|
||||
f"'{url}' (connector {connector_id})"
|
||||
)
|
||||
total_discovered = len(tool_definitions)
|
||||
|
||||
if allowed_set:
|
||||
tool_definitions = [
|
||||
td for td in tool_definitions if td["name"] in allowed_set
|
||||
]
|
||||
logger.info(
|
||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||
url, connector_id, len(tool_definitions), total_discovered,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||
total_discovered, url, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
|
|
@ -361,6 +409,8 @@ async def _load_http_mcp_tools(
|
|||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
|
|
@ -532,15 +582,39 @@ async def load_mcp_tools(
|
|||
try:
|
||||
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
|
||||
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
|
||||
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.config.has_key("server_config"), # noqa: W601
|
||||
cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
|
||||
),
|
||||
)
|
||||
|
||||
connectors = list(result.scalars())
|
||||
|
||||
# Group connectors by type to detect multi-account scenarios.
|
||||
# When >1 connector shares the same type, tool names would collide
|
||||
# so we prefix them with "{service_key}_{connector_id}_".
|
||||
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
|
||||
for connector in connectors:
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
type_groups[ct].append(connector)
|
||||
|
||||
multi_account_types: set[str] = {
|
||||
ct for ct, group in type_groups.items() if len(group) > 1
|
||||
}
|
||||
if multi_account_types:
|
||||
logger.info(
|
||||
"Multi-account detected for connector types: %s",
|
||||
multi_account_types,
|
||||
)
|
||||
|
||||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
for connector in connectors:
|
||||
try:
|
||||
cfg = connector.config or {}
|
||||
server_config = cfg.get("server_config", {})
|
||||
|
|
@ -558,6 +632,27 @@ async def load_mcp_tools(
|
|||
session, connector, cfg, server_config,
|
||||
)
|
||||
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
|
||||
# Resolve the allowlist from the service registry (if any).
|
||||
svc_cfg = get_service_by_connector_type(ct)
|
||||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||
|
||||
# Build a prefix only when multiple accounts share the same type.
|
||||
tool_name_prefix: str | None = None
|
||||
if ct in multi_account_types and svc_cfg:
|
||||
service_key = next(
|
||||
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
|
||||
None,
|
||||
)
|
||||
if service_key:
|
||||
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
|
|
@ -566,6 +661,9 @@ async def load_mcp_tools(
|
|||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
allowed_tools=allowed_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue