mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
add multi-account MCP tool disambiguation with prefix namespacing
This commit is contained in:
parent
9eb54bc4af
commit
f2d9e67ac2
1 changed files with 136 additions and 38 deletions
|
|
@ -16,18 +16,21 @@ clicking "Always Allow", which adds the tool name to the connector's
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
from sqlalchemy import select
|
from sqlalchemy import cast, select
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -123,7 +126,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
)
|
)
|
||||||
if hitl_result.rejected:
|
if hitl_result.rejected:
|
||||||
return "Tool call rejected by user."
|
return "Tool call rejected by user."
|
||||||
call_kwargs = hitl_result.params
|
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with mcp_client.connect():
|
async with mcp_client.connect():
|
||||||
|
|
@ -163,41 +166,57 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
connector_name: str = "",
|
connector_name: str = "",
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
trusted_tools: list[str] | None = None,
|
trusted_tools: list[str] | None = None,
|
||||||
|
readonly_tools: frozenset[str] | None = None,
|
||||||
|
tool_name_prefix: str | None = None,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||||
|
|
||||||
All MCP tools are unconditionally wrapped with HITL approval.
|
Write tools are wrapped with HITL approval; read-only tools (listed in
|
||||||
``request_approval()`` is called OUTSIDE the try/except so that
|
``readonly_tools``) execute immediately without user confirmation.
|
||||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
|
||||||
|
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
||||||
|
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
||||||
|
but the actual MCP ``call_tool`` still uses the original name.
|
||||||
"""
|
"""
|
||||||
tool_name = tool_def.get("name", "unnamed_tool")
|
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
tool_description = tool_def.get("description", "No description provided")
|
tool_description = tool_def.get("description", "No description provided")
|
||||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||||
|
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||||
|
|
||||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
exposed_name = (
|
||||||
|
f"{tool_name_prefix}_{original_tool_name}"
|
||||||
|
if tool_name_prefix
|
||||||
|
else original_tool_name
|
||||||
|
)
|
||||||
|
if tool_name_prefix:
|
||||||
|
tool_description = f"[Account: {connector_name}] {tool_description}"
|
||||||
|
|
||||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}")
|
||||||
|
|
||||||
|
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||||
|
|
||||||
async def mcp_http_tool_call(**kwargs) -> str:
|
async def mcp_http_tool_call(**kwargs) -> str:
|
||||||
"""Execute the MCP tool call via HTTP transport."""
|
"""Execute the MCP tool call via HTTP transport."""
|
||||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}")
|
||||||
|
|
||||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
if is_readonly:
|
||||||
hitl_result = request_approval(
|
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
action_type="mcp_tool_call",
|
else:
|
||||||
tool_name=tool_name,
|
hitl_result = request_approval(
|
||||||
params=kwargs,
|
action_type="mcp_tool_call",
|
||||||
context={
|
tool_name=exposed_name,
|
||||||
"mcp_server": connector_name,
|
params=kwargs,
|
||||||
"tool_description": tool_description,
|
context={
|
||||||
"mcp_transport": "http",
|
"mcp_server": connector_name,
|
||||||
"mcp_connector_id": connector_id,
|
"tool_description": tool_description,
|
||||||
},
|
"mcp_transport": "http",
|
||||||
trusted_tools=trusted_tools,
|
"mcp_connector_id": connector_id,
|
||||||
)
|
},
|
||||||
if hitl_result.rejected:
|
trusted_tools=trusted_tools,
|
||||||
return "Tool call rejected by user."
|
)
|
||||||
call_kwargs = hitl_result.params
|
if hitl_result.rejected:
|
||||||
|
return "Tool call rejected by user."
|
||||||
|
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with (
|
async with (
|
||||||
|
|
@ -205,7 +224,9 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
response = await session.call_tool(tool_name, arguments=call_kwargs)
|
response = await session.call_tool(
|
||||||
|
original_tool_name, arguments=call_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for content in response.content:
|
for content in response.content:
|
||||||
|
|
@ -218,17 +239,17 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
|
|
||||||
result_str = "\n".join(result) if result else ""
|
result_str = "\n".join(result) if result else ""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
|
f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}"
|
||||||
)
|
)
|
||||||
return result_str
|
return result_str
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
|
error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||||
logger.exception(error_msg)
|
logger.exception(error_msg)
|
||||||
return f"Error: {error_msg}"
|
return f"Error: {error_msg}"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
name=tool_name,
|
name=exposed_name,
|
||||||
description=tool_description,
|
description=tool_description,
|
||||||
coroutine=mcp_http_tool_call,
|
coroutine=mcp_http_tool_call,
|
||||||
args_schema=input_model,
|
args_schema=input_model,
|
||||||
|
|
@ -236,12 +257,14 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
"mcp_input_schema": input_schema,
|
"mcp_input_schema": input_schema,
|
||||||
"mcp_transport": "http",
|
"mcp_transport": "http",
|
||||||
"mcp_url": url,
|
"mcp_url": url,
|
||||||
"hitl": True,
|
"hitl": not is_readonly,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||||
|
"mcp_original_tool_name": original_tool_name,
|
||||||
|
"mcp_connector_id": connector_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
|
logger.info(f"Created MCP tool (HTTP): '{exposed_name}'")
|
||||||
return tool
|
return tool
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -309,8 +332,20 @@ async def _load_http_mcp_tools(
|
||||||
connector_name: str,
|
connector_name: str,
|
||||||
server_config: dict[str, Any],
|
server_config: dict[str, Any],
|
||||||
trusted_tools: list[str] | None = None,
|
trusted_tools: list[str] | None = None,
|
||||||
|
allowed_tools: list[str] | None = None,
|
||||||
|
readonly_tools: frozenset[str] | None = None,
|
||||||
|
tool_name_prefix: str | None = None,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server."""
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_tools: If non-empty, only tools whose names appear in this
|
||||||
|
list are loaded. Empty/None means load everything (used for
|
||||||
|
user-managed generic MCP servers).
|
||||||
|
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||||
|
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||||
|
disambiguation (e.g. ``linear_25``).
|
||||||
|
"""
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
|
|
||||||
url = server_config.get("url")
|
url = server_config.get("url")
|
||||||
|
|
@ -327,6 +362,8 @@ async def _load_http_mcp_tools(
|
||||||
)
|
)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||||
|
|
@ -347,10 +384,21 @@ async def _load_http_mcp_tools(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
total_discovered = len(tool_definitions)
|
||||||
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
|
|
||||||
f"'{url}' (connector {connector_id})"
|
if allowed_set:
|
||||||
)
|
tool_definitions = [
|
||||||
|
td for td in tool_definitions if td["name"] in allowed_set
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||||
|
url, connector_id, len(tool_definitions), total_discovered,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||||
|
total_discovered, url, connector_id,
|
||||||
|
)
|
||||||
|
|
||||||
for tool_def in tool_definitions:
|
for tool_def in tool_definitions:
|
||||||
try:
|
try:
|
||||||
|
|
@ -361,6 +409,8 @@ async def _load_http_mcp_tools(
|
||||||
connector_name=connector_name,
|
connector_name=connector_name,
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
trusted_tools=trusted_tools,
|
trusted_tools=trusted_tools,
|
||||||
|
readonly_tools=readonly_tools,
|
||||||
|
tool_name_prefix=tool_name_prefix,
|
||||||
)
|
)
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -532,15 +582,39 @@ async def load_mcp_tools(
|
||||||
try:
|
try:
|
||||||
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
|
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
|
||||||
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
|
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
|
||||||
|
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
SearchSourceConnector.config.has_key("server_config"), # noqa: W601
|
cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
connectors = list(result.scalars())
|
||||||
|
|
||||||
|
# Group connectors by type to detect multi-account scenarios.
|
||||||
|
# When >1 connector shares the same type, tool names would collide
|
||||||
|
# so we prefix them with "{service_key}_{connector_id}_".
|
||||||
|
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
|
||||||
|
for connector in connectors:
|
||||||
|
ct = (
|
||||||
|
connector.connector_type.value
|
||||||
|
if hasattr(connector.connector_type, "value")
|
||||||
|
else str(connector.connector_type)
|
||||||
|
)
|
||||||
|
type_groups[ct].append(connector)
|
||||||
|
|
||||||
|
multi_account_types: set[str] = {
|
||||||
|
ct for ct, group in type_groups.items() if len(group) > 1
|
||||||
|
}
|
||||||
|
if multi_account_types:
|
||||||
|
logger.info(
|
||||||
|
"Multi-account detected for connector types: %s",
|
||||||
|
multi_account_types,
|
||||||
|
)
|
||||||
|
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
for connector in result.scalars():
|
for connector in connectors:
|
||||||
try:
|
try:
|
||||||
cfg = connector.config or {}
|
cfg = connector.config or {}
|
||||||
server_config = cfg.get("server_config", {})
|
server_config = cfg.get("server_config", {})
|
||||||
|
|
@ -558,6 +632,27 @@ async def load_mcp_tools(
|
||||||
session, connector, cfg, server_config,
|
session, connector, cfg, server_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ct = (
|
||||||
|
connector.connector_type.value
|
||||||
|
if hasattr(connector.connector_type, "value")
|
||||||
|
else str(connector.connector_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve the allowlist from the service registry (if any).
|
||||||
|
svc_cfg = get_service_by_connector_type(ct)
|
||||||
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||||
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||||
|
|
||||||
|
# Build a prefix only when multiple accounts share the same type.
|
||||||
|
tool_name_prefix: str | None = None
|
||||||
|
if ct in multi_account_types and svc_cfg:
|
||||||
|
service_key = next(
|
||||||
|
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if service_key:
|
||||||
|
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||||
|
|
||||||
transport = server_config.get("transport", "stdio")
|
transport = server_config.get("transport", "stdio")
|
||||||
|
|
||||||
if transport in ("streamable-http", "http", "sse"):
|
if transport in ("streamable-http", "http", "sse"):
|
||||||
|
|
@ -566,6 +661,9 @@ async def load_mcp_tools(
|
||||||
connector.name,
|
connector.name,
|
||||||
server_config,
|
server_config,
|
||||||
trusted_tools=trusted_tools,
|
trusted_tools=trusted_tools,
|
||||||
|
allowed_tools=allowed_tools,
|
||||||
|
readonly_tools=readonly_tools,
|
||||||
|
tool_name_prefix=tool_name_prefix,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
connector_tools = await _load_stdio_mcp_tools(
|
connector_tools = await _load_stdio_mcp_tools(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue