mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
perf: parallelize MCP connector discovery with per-connector timeout
This commit is contained in:
parent
e3172dc282
commit
0b3551bd06
1 changed files with 57 additions and 28 deletions
|
|
@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
|
@ -41,6 +42,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
_MCP_CACHE_MAX_SIZE = 50
|
||||
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||
|
||||
|
||||
|
|
@ -869,7 +871,7 @@ async def load_mcp_tools(
|
|||
multi_account_types,
|
||||
)
|
||||
|
||||
tools: list[StructuredTool] = []
|
||||
discovery_tasks: list[dict[str, Any]] = []
|
||||
for connector in connectors:
|
||||
try:
|
||||
cfg = connector.config or {}
|
||||
|
|
@ -882,14 +884,10 @@ async def load_mcp_tools(
|
|||
)
|
||||
continue
|
||||
|
||||
# For MCP OAuth connectors: refresh if needed, then decrypt the
|
||||
# access token and inject it into headers at runtime. The DB
|
||||
# intentionally does NOT store plaintext tokens in server_config.
|
||||
if cfg.get("mcp_oauth"):
|
||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||
session, connector, cfg, server_config,
|
||||
)
|
||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||
cfg = connector.config or {}
|
||||
server_config = _inject_oauth_headers(cfg, server_config)
|
||||
if server_config is None:
|
||||
|
|
@ -911,7 +909,6 @@ async def load_mcp_tools(
|
|||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||
|
||||
# Build a prefix only when multiple accounts share the same type.
|
||||
tool_name_prefix: str | None = None
|
||||
if ct in multi_account_types and svc_cfg:
|
||||
service_key = next(
|
||||
|
|
@ -921,34 +918,66 @@ async def load_mcp_tools(
|
|||
if service_key:
|
||||
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
allowed_tools=allowed_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
|
||||
tools.extend(connector_tools)
|
||||
discovery_tasks.append({
|
||||
"connector_id": connector.id,
|
||||
"connector_name": connector.name,
|
||||
"server_config": server_config,
|
||||
"trusted_tools": trusted_tools,
|
||||
"allowed_tools": allowed_tools,
|
||||
"readonly_tools": readonly_tools,
|
||||
"tool_name_prefix": tool_name_prefix,
|
||||
"transport": server_config.get("transport", "stdio"),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
"Failed to prepare MCP connector %d: %s",
|
||||
connector.id, e,
|
||||
)
|
||||
|
||||
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||
try:
|
||||
if task["transport"] in ("streamable-http", "http", "sse"):
|
||||
return await asyncio.wait_for(
|
||||
_load_http_mcp_tools(
|
||||
task["connector_id"],
|
||||
task["connector_name"],
|
||||
task["server_config"],
|
||||
trusted_tools=task["trusted_tools"],
|
||||
allowed_tools=task["allowed_tools"],
|
||||
readonly_tools=task["readonly_tools"],
|
||||
tool_name_prefix=task["tool_name_prefix"],
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(
|
||||
_load_stdio_mcp_tools(
|
||||
task["connector_id"],
|
||||
task["connector_name"],
|
||||
task["server_config"],
|
||||
trusted_tools=task["trusted_tools"],
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"MCP connector %d timed out after %ds during discovery",
|
||||
task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
task["connector_id"], e,
|
||||
)
|
||||
return []
|
||||
|
||||
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||
tools: list[StructuredTool] = [
|
||||
tool for sublist in results for tool in sublist
|
||||
]
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue