perf: parallelize MCP connector discovery with per-connector timeout

This commit is contained in:
CREDO23 2026-04-23 08:40:06 +02:00
parent e3172dc282
commit 0b3551bd06

View file

@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's
from __future__ import annotations
import asyncio
import logging
import time
from collections import defaultdict
@ -41,6 +42,7 @@ logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
@ -869,7 +871,7 @@ async def load_mcp_tools(
multi_account_types,
)
tools: list[StructuredTool] = []
discovery_tasks: list[dict[str, Any]] = []
for connector in connectors:
try:
cfg = connector.config or {}
@ -882,14 +884,10 @@ async def load_mcp_tools(
)
continue
# For MCP OAuth connectors: refresh if needed, then decrypt the
# access token and inject it into headers at runtime. The DB
# intentionally does NOT store plaintext tokens in server_config.
if cfg.get("mcp_oauth"):
server_config = await _maybe_refresh_mcp_oauth_token(
session, connector, cfg, server_config,
)
# Re-read cfg after potential refresh (connector was reloaded from DB).
cfg = connector.config or {}
server_config = _inject_oauth_headers(cfg, server_config)
if server_config is None:
@ -911,7 +909,6 @@ async def load_mcp_tools(
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
# Build a prefix only when multiple accounts share the same type.
tool_name_prefix: str | None = None
if ct in multi_account_types and svc_cfg:
service_key = next(
@ -921,34 +918,66 @@ async def load_mcp_tools(
if service_key:
tool_name_prefix = f"{service_key}_{connector.id}"
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
connector_tools = await _load_http_mcp_tools(
connector.id,
connector.name,
server_config,
trusted_tools=trusted_tools,
allowed_tools=allowed_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
else:
connector_tools = await _load_stdio_mcp_tools(
connector.id,
connector.name,
server_config,
trusted_tools=trusted_tools,
)
tools.extend(connector_tools)
discovery_tasks.append({
"connector_id": connector.id,
"connector_name": connector.name,
"server_config": server_config,
"trusted_tools": trusted_tools,
"allowed_tools": allowed_tools,
"readonly_tools": readonly_tools,
"tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"),
})
except Exception as e:
logger.exception(
"Failed to load tools from MCP connector %d: %s",
"Failed to prepare MCP connector %d: %s",
connector.id, e,
)
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
    """Discover tools for one prepared connector, bounded by a timeout.

    Args:
        task: Prepared discovery descriptor with connector_id, connector_name,
            server_config, transport, and the per-connector tool filters.

    Returns:
        The connector's discovered tools, or an empty list on timeout or any
        discovery error. Failures are logged and never propagated, so a single
        bad connector cannot fail the surrounding gather.
    """
    try:
        # Both transports share the same timeout policy: build the right
        # discovery coroutine first, then bound it with wait_for exactly once
        # instead of duplicating the timeout wrapping in each branch.
        if task["transport"] in ("streamable-http", "http", "sse"):
            discovery = _load_http_mcp_tools(
                task["connector_id"],
                task["connector_name"],
                task["server_config"],
                trusted_tools=task["trusted_tools"],
                allowed_tools=task["allowed_tools"],
                readonly_tools=task["readonly_tools"],
                tool_name_prefix=task["tool_name_prefix"],
            )
        else:
            # Anything that is not an HTTP-style transport falls back to stdio.
            discovery = _load_stdio_mcp_tools(
                task["connector_id"],
                task["connector_name"],
                task["server_config"],
                trusted_tools=task["trusted_tools"],
            )
        return await asyncio.wait_for(
            discovery, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        logger.error(
            "MCP connector %d timed out after %ds during discovery",
            task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS,
        )
        return []
    except Exception as e:
        logger.exception(
            "Failed to load tools from MCP connector %d: %s",
            task["connector_id"], e,
        )
        return []
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
tools: list[StructuredTool] = [
tool for sublist in results for tool in sublist
]
_mcp_tools_cache[search_space_id] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE: