mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Relocate the entire new_chat/tools/ package (62 files incl. registry, hitl, MCP cluster, and all connector subpackages: gmail/slack/discord/teams/drive/etc.) to the shared kernel. The package turned out to be a clean cohesive cluster: its only references to non-tools new_chat modules were comments, and its middleware deps were already flipped to shared in slice 5c. Flip 33 live importers (multi-agent, flows, routes, services, anonymous_agent, tests). Re-export shims remain for the frozen single-agent stack: a package __init__ mirroring the public surface (new_chat.__init__ imports it) plus invalid_tool + registry submodule shims (chat_deepagent imports those). Resolves slice 5c's two transient back-edges: shared/middleware/action_log (TYPE_CHECKING ToolDefinition) and tool_call_repair (local INVALID_TOOL_NAME) now point at app.agents.shared.tools.
1332 lines
48 KiB
Python
1332 lines
48 KiB
Python
"""MCP Tool Factory.
|
|
|
|
This module creates LangChain tools from MCP servers using the Model Context Protocol.
|
|
Tools are dynamically discovered from MCP servers - no manual configuration needed.
|
|
|
|
Supports both transport types:
|
|
- stdio: Local process-based MCP servers (command, args, env)
|
|
- streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers)
|
|
|
|
All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval.
|
|
Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless
|
|
they come from trusted servers." Users can bypass HITL for specific tools by
|
|
clicking "Always Allow", which adds the tool name to the connector's
|
|
``config.trusted_tools`` allow-list.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
if TYPE_CHECKING:
|
|
from app.utils.oauth_security import TokenEncryption
|
|
|
|
from langchain_core.tools import StructuredTool
|
|
from mcp import ClientSession
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
from sqlalchemy import cast, select
|
|
from sqlalchemy.dialects.postgresql import JSONB
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.agents.shared.middleware.dedup_tool_calls import dedup_key_full_args
|
|
from app.agents.shared.tools.hitl import request_approval
|
|
from app.agents.shared.tools.mcp_client import MCPClient
|
|
from app.agents.shared.tools.mcp_tools_cache import (
|
|
CachedMCPTools,
|
|
read_cached_tools,
|
|
write_cached_tools,
|
|
)
|
|
from app.db import SearchSourceConnector
|
|
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
|
from app.utils.perf import get_perf_logger
|
|
|
|
_perf_log = get_perf_logger()
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
|
_MCP_CACHE_MAX_SIZE = 50
|
|
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
|
_TOOL_CALL_MAX_RETRIES = 3
|
|
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
|
# Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and
|
|
# multi-agent paths cannot share tool closures with different HITL wiring.
|
|
_MCPCacheKey = tuple[int, bool]
|
|
_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {}
|
|
|
|
|
|
def _evict_expired_mcp_cache() -> None:
|
|
"""Remove expired entries from the MCP tools cache to prevent unbounded growth."""
|
|
now = time.monotonic()
|
|
expired = [
|
|
k
|
|
for k, (ts, _) in _mcp_tools_cache.items()
|
|
if now - ts >= _MCP_CACHE_TTL_SECONDS
|
|
]
|
|
for k in expired:
|
|
del _mcp_tools_cache[k]
|
|
if expired:
|
|
logger.debug("Evicted %d expired MCP cache entries", len(expired))
|
|
|
|
|
|
def _create_dynamic_input_model_from_schema(
|
|
tool_name: str,
|
|
input_schema: dict[str, Any],
|
|
) -> type[BaseModel]:
|
|
"""Create a Pydantic model from MCP tool's JSON schema.
|
|
|
|
Models always allow extra fields (``extra="allow"``) so that parameters
|
|
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
|
|
producing an empty ``$schema``-only object) can still be forwarded to the
|
|
MCP server.
|
|
|
|
When the schema declares **no** properties, a synthetic ``input_data``
|
|
field of type ``dict`` is injected so the LLM has a visible parameter to
|
|
populate. The caller should unpack ``input_data`` before forwarding to
|
|
the MCP server (see ``_unpack_synthetic_input_data``).
|
|
"""
|
|
properties = input_schema.get("properties", {})
|
|
required_fields = input_schema.get("required", [])
|
|
|
|
field_definitions = {}
|
|
for param_name, param_schema in properties.items():
|
|
param_description = param_schema.get("description", "")
|
|
is_required = param_name in required_fields
|
|
|
|
if is_required:
|
|
field_definitions[param_name] = (
|
|
Any,
|
|
Field(..., description=param_description),
|
|
)
|
|
else:
|
|
field_definitions[param_name] = (
|
|
Any | None,
|
|
Field(None, description=param_description),
|
|
)
|
|
|
|
if not properties:
|
|
field_definitions["input_data"] = (
|
|
dict[str, Any] | None,
|
|
Field(
|
|
None,
|
|
description=(
|
|
"Arguments to pass to this tool as a JSON object. "
|
|
"Infer sensible key names from the tool name and description "
|
|
'(e.g. {"search": "my query"} for a search tool).'
|
|
),
|
|
),
|
|
)
|
|
|
|
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
|
model = create_model(
|
|
model_name, __config__=ConfigDict(extra="allow"), **field_definitions
|
|
)
|
|
return model
|
|
|
|
|
|
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
|
|
|
|
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
|
|
adds a catch-all ``input_data: dict`` field. This helper merges that dict
|
|
back into the top-level kwargs so the MCP server receives flat arguments.
|
|
"""
|
|
input_data = kwargs.pop("input_data", None)
|
|
if isinstance(input_data, dict):
|
|
kwargs.update(input_data)
|
|
return kwargs
|
|
|
|
|
|
async def _create_mcp_tool_from_definition_stdio(
|
|
tool_def: dict[str, Any],
|
|
mcp_client: MCPClient,
|
|
*,
|
|
connector_name: str = "",
|
|
connector_id: int | None = None,
|
|
trusted_tools: list[str] | None = None,
|
|
bypass_internal_hitl: bool = False,
|
|
) -> StructuredTool:
|
|
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
|
|
|
Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
|
|
already gates the tool, otherwise the body's ``request_approval()`` is the
|
|
sole HITL gate (single-agent path).
|
|
"""
|
|
tool_name = tool_def.get("name", "unnamed_tool")
|
|
raw_description = tool_def.get("description", "No description provided")
|
|
tool_description = (
|
|
f"[MCP server: {connector_name}] {raw_description}"
|
|
if connector_name
|
|
else raw_description
|
|
)
|
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
|
|
|
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
|
|
|
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
|
|
|
async def mcp_tool_call(**kwargs) -> str:
|
|
"""Execute the MCP tool call via the client with retry support."""
|
|
logger.debug("MCP tool '%s' called", tool_name)
|
|
|
|
if bypass_internal_hitl:
|
|
call_kwargs = _unpack_synthetic_input_data(
|
|
{k: v for k, v in kwargs.items() if v is not None}
|
|
)
|
|
else:
|
|
# Outside try/except so ``GraphInterrupt`` propagates to LangGraph.
|
|
hitl_result = request_approval(
|
|
action_type="mcp_tool_call",
|
|
tool_name=tool_name,
|
|
params=kwargs,
|
|
context={
|
|
"mcp_server": connector_name,
|
|
"tool_description": raw_description,
|
|
"mcp_transport": "stdio",
|
|
"mcp_connector_id": connector_id,
|
|
},
|
|
trusted_tools=trusted_tools,
|
|
)
|
|
if hitl_result.rejected:
|
|
return "Tool call rejected by user."
|
|
call_kwargs = _unpack_synthetic_input_data(
|
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
|
)
|
|
|
|
last_error: Exception | None = None
|
|
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
|
try:
|
|
async with mcp_client.connect():
|
|
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
|
return str(result)
|
|
except Exception as e:
|
|
last_error = e
|
|
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
|
|
delay = _TOOL_CALL_RETRY_DELAY * (2**attempt)
|
|
logger.warning(
|
|
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
|
tool_name,
|
|
attempt + 1,
|
|
_TOOL_CALL_MAX_RETRIES,
|
|
e,
|
|
delay,
|
|
)
|
|
await asyncio.sleep(delay)
|
|
else:
|
|
logger.error(
|
|
"MCP tool '%s' failed after %d attempts: %s",
|
|
tool_name,
|
|
_TOOL_CALL_MAX_RETRIES,
|
|
e,
|
|
exc_info=True,
|
|
)
|
|
|
|
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
|
|
|
|
tool = StructuredTool(
|
|
name=tool_name,
|
|
description=tool_description,
|
|
coroutine=mcp_tool_call,
|
|
args_schema=input_model,
|
|
metadata={
|
|
"mcp_input_schema": input_schema,
|
|
"mcp_transport": "stdio",
|
|
"mcp_connector_name": connector_name or None,
|
|
"mcp_connector_id": connector_id,
|
|
"mcp_is_generic": True,
|
|
"hitl": True,
|
|
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
|
# would otherwise collapse legitimate batches.
|
|
"dedup_key": dedup_key_full_args,
|
|
},
|
|
)
|
|
|
|
logger.debug("Created MCP tool (stdio): '%s'", tool_name)
|
|
return tool
|
|
|
|
|
|
async def _create_mcp_tool_from_definition_http(
|
|
tool_def: dict[str, Any],
|
|
url: str,
|
|
headers: dict[str, str],
|
|
*,
|
|
connector_name: str = "",
|
|
connector_id: int | None = None,
|
|
trusted_tools: list[str] | None = None,
|
|
readonly_tools: frozenset[str] | None = None,
|
|
tool_name_prefix: str | None = None,
|
|
is_generic_mcp: bool = False,
|
|
bypass_internal_hitl: bool = False,
|
|
) -> StructuredTool:
|
|
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
|
|
|
Write tools are wrapped with HITL approval; read-only tools (listed in
|
|
``readonly_tools``) execute immediately without user confirmation. Set
|
|
``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
|
|
already gates the tool.
|
|
|
|
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
|
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
|
but the actual MCP ``call_tool`` still uses the original name.
|
|
"""
|
|
original_tool_name = tool_def.get("name", "unnamed_tool")
|
|
raw_description = tool_def.get("description", "No description provided")
|
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
|
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
|
|
|
exposed_name = (
|
|
f"{tool_name_prefix}_{original_tool_name}"
|
|
if tool_name_prefix
|
|
else original_tool_name
|
|
)
|
|
if tool_name_prefix:
|
|
tool_description = f"[Account: {connector_name}] {raw_description}"
|
|
elif is_generic_mcp and connector_name:
|
|
tool_description = f"[MCP server: {connector_name}] {raw_description}"
|
|
else:
|
|
tool_description = raw_description
|
|
|
|
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
|
|
|
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
|
|
|
async def _do_mcp_call(
|
|
call_headers: dict[str, str],
|
|
call_kwargs: dict[str, Any],
|
|
timeout: float = 60.0,
|
|
) -> str:
|
|
"""Execute a single MCP HTTP call with the given headers."""
|
|
call_start = time.perf_counter()
|
|
async with (
|
|
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
|
ClientSession(read, write) as session,
|
|
):
|
|
init_start = time.perf_counter()
|
|
await session.initialize()
|
|
init_elapsed = time.perf_counter() - init_start
|
|
|
|
tool_start = time.perf_counter()
|
|
response = await asyncio.wait_for(
|
|
session.call_tool(original_tool_name, arguments=call_kwargs),
|
|
timeout=timeout,
|
|
)
|
|
tool_elapsed = time.perf_counter() - tool_start
|
|
|
|
result = []
|
|
for content in response.content:
|
|
if hasattr(content, "text"):
|
|
result.append(content.text)
|
|
elif hasattr(content, "data"):
|
|
result.append(str(content.data))
|
|
else:
|
|
result.append(str(content))
|
|
|
|
payload = "\n".join(result) if result else ""
|
|
|
|
_perf_log.info(
|
|
"[mcp_http_call] connector=%s tool=%s init=%.3fs call=%.3fs total=%.3fs out_chars=%d",
|
|
connector_id,
|
|
original_tool_name,
|
|
init_elapsed,
|
|
tool_elapsed,
|
|
time.perf_counter() - call_start,
|
|
len(payload),
|
|
)
|
|
return payload
|
|
|
|
async def mcp_http_tool_call(**kwargs) -> str:
|
|
"""Execute the MCP tool call via HTTP transport."""
|
|
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
|
|
|
if is_readonly or bypass_internal_hitl:
|
|
call_kwargs = _unpack_synthetic_input_data(
|
|
{k: v for k, v in kwargs.items() if v is not None}
|
|
)
|
|
else:
|
|
hitl_result = request_approval(
|
|
action_type="mcp_tool_call",
|
|
tool_name=exposed_name,
|
|
params=kwargs,
|
|
context={
|
|
"mcp_server": connector_name,
|
|
"tool_description": raw_description,
|
|
"mcp_transport": "http",
|
|
"mcp_connector_id": connector_id,
|
|
},
|
|
trusted_tools=trusted_tools,
|
|
)
|
|
if hitl_result.rejected:
|
|
return "Tool call rejected by user."
|
|
call_kwargs = _unpack_synthetic_input_data(
|
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
|
)
|
|
|
|
try:
|
|
result_str = await _do_mcp_call(headers, call_kwargs)
|
|
logger.debug(
|
|
"MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)
|
|
)
|
|
return result_str
|
|
|
|
except Exception as first_err:
|
|
if not _is_auth_error(first_err) or connector_id is None:
|
|
logger.exception(
|
|
"MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err
|
|
)
|
|
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}"
|
|
|
|
logger.warning(
|
|
"MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s",
|
|
exposed_name,
|
|
connector_id,
|
|
)
|
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
|
if fresh_headers is None:
|
|
await _mark_connector_auth_expired(connector_id)
|
|
return (
|
|
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
|
"Please re-authenticate the connector in your settings."
|
|
)
|
|
|
|
try:
|
|
result_str = await _do_mcp_call(fresh_headers, call_kwargs)
|
|
logger.info(
|
|
"MCP HTTP tool '%s' succeeded after 401 recovery",
|
|
exposed_name,
|
|
)
|
|
return result_str
|
|
except Exception as retry_err:
|
|
logger.exception(
|
|
"MCP HTTP tool '%s' still failing after token refresh: %s",
|
|
exposed_name,
|
|
retry_err,
|
|
)
|
|
if _is_auth_error(retry_err):
|
|
await _mark_connector_auth_expired(connector_id)
|
|
return (
|
|
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
|
"Please re-authenticate the connector in your settings."
|
|
)
|
|
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}"
|
|
|
|
tool = StructuredTool(
|
|
name=exposed_name,
|
|
description=tool_description,
|
|
coroutine=mcp_http_tool_call,
|
|
args_schema=input_model,
|
|
metadata={
|
|
"mcp_input_schema": input_schema,
|
|
"mcp_transport": "http",
|
|
"mcp_url": url,
|
|
"mcp_connector_name": connector_name or None,
|
|
"mcp_is_generic": is_generic_mcp,
|
|
"hitl": not is_readonly,
|
|
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
|
# would otherwise collapse legitimate batches.
|
|
"dedup_key": dedup_key_full_args,
|
|
"mcp_original_tool_name": original_tool_name,
|
|
"mcp_connector_id": connector_id,
|
|
},
|
|
)
|
|
|
|
logger.debug("Created MCP tool (HTTP): '%s'", exposed_name)
|
|
return tool
|
|
|
|
|
|
async def _load_stdio_mcp_tools(
|
|
connector_id: int,
|
|
connector_name: str,
|
|
server_config: dict[str, Any],
|
|
trusted_tools: list[str] | None = None,
|
|
*,
|
|
bypass_internal_hitl: bool = False,
|
|
) -> list[StructuredTool]:
|
|
"""Load tools from a stdio-based MCP server."""
|
|
tools: list[StructuredTool] = []
|
|
|
|
command = server_config.get("command")
|
|
if not command or not isinstance(command, str):
|
|
logger.warning(
|
|
"MCP connector %d (name: '%s') missing or invalid command field, skipping",
|
|
connector_id,
|
|
connector_name,
|
|
)
|
|
return tools
|
|
|
|
args = server_config.get("args", [])
|
|
if not isinstance(args, list):
|
|
logger.warning(
|
|
"MCP connector %d (name: '%s') has invalid args field (must be list), skipping",
|
|
connector_id,
|
|
connector_name,
|
|
)
|
|
return tools
|
|
|
|
env = server_config.get("env", {})
|
|
if not isinstance(env, dict):
|
|
logger.warning(
|
|
"MCP connector %d (name: '%s') has invalid env field (must be dict), skipping",
|
|
connector_id,
|
|
connector_name,
|
|
)
|
|
return tools
|
|
|
|
mcp_client = MCPClient(command, args, env)
|
|
|
|
async with mcp_client.connect():
|
|
tool_definitions = await mcp_client.list_tools()
|
|
|
|
logger.info(
|
|
"Discovered %d tools from stdio MCP server '%s' (connector %d)",
|
|
len(tool_definitions),
|
|
command,
|
|
connector_id,
|
|
)
|
|
|
|
for tool_def in tool_definitions:
|
|
try:
|
|
tool = await _create_mcp_tool_from_definition_stdio(
|
|
tool_def,
|
|
mcp_client,
|
|
connector_name=connector_name,
|
|
connector_id=connector_id,
|
|
trusted_tools=trusted_tools,
|
|
bypass_internal_hitl=bypass_internal_hitl,
|
|
)
|
|
tools.append(tool)
|
|
except Exception as e:
|
|
logger.exception(
|
|
"Failed to create tool '%s' from connector %d: %s",
|
|
tool_def.get("name"),
|
|
connector_id,
|
|
e,
|
|
)
|
|
|
|
return tools
|
|
|
|
|
|
async def _load_http_mcp_tools(
|
|
connector_id: int,
|
|
connector_name: str,
|
|
server_config: dict[str, Any],
|
|
trusted_tools: list[str] | None = None,
|
|
allowed_tools: list[str] | None = None,
|
|
readonly_tools: frozenset[str] | None = None,
|
|
tool_name_prefix: str | None = None,
|
|
is_generic_mcp: bool = False,
|
|
*,
|
|
bypass_internal_hitl: bool = False,
|
|
cached_tools: CachedMCPTools | None = None,
|
|
) -> list[StructuredTool]:
|
|
"""Load tools from an HTTP-based MCP server.
|
|
|
|
Args:
|
|
allowed_tools: If non-empty, only tools whose names appear in this
|
|
list are loaded. Empty/None means load everything (used for
|
|
user-managed generic MCP servers).
|
|
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
|
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
|
disambiguation (e.g. ``linear_25``).
|
|
cached_tools: If provided, skip live discovery and rebuild wrappers
|
|
from the persisted definitions.
|
|
"""
|
|
tools: list[StructuredTool] = []
|
|
|
|
url = server_config.get("url")
|
|
if not url or not isinstance(url, str):
|
|
logger.warning(
|
|
"MCP connector %d (name: '%s') missing or invalid url field, skipping",
|
|
connector_id,
|
|
connector_name,
|
|
)
|
|
return tools
|
|
|
|
headers = server_config.get("headers", {})
|
|
if not isinstance(headers, dict):
|
|
logger.warning(
|
|
"MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping",
|
|
connector_id,
|
|
connector_name,
|
|
)
|
|
return tools
|
|
|
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
|
|
|
async def _discover(
|
|
disc_headers: dict[str, str],
|
|
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
|
|
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
|
|
async with (
|
|
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
|
ClientSession(read, write) as session,
|
|
):
|
|
init_result = await session.initialize()
|
|
server_info: dict[str, str | None] = {"name": None, "version": None}
|
|
si = getattr(init_result, "serverInfo", None)
|
|
if si is not None:
|
|
server_info["name"] = getattr(si, "name", None)
|
|
server_info["version"] = getattr(si, "version", None)
|
|
|
|
response = await session.list_tools()
|
|
return server_info, [
|
|
{
|
|
"name": tool.name,
|
|
"description": tool.description or "",
|
|
"input_schema": tool.inputSchema
|
|
if hasattr(tool, "inputSchema")
|
|
else {},
|
|
}
|
|
for tool in response.tools
|
|
]
|
|
|
|
if cached_tools is not None:
|
|
tool_definitions = [
|
|
{
|
|
"name": td.name,
|
|
"description": td.description,
|
|
"input_schema": td.input_schema,
|
|
}
|
|
for td in cached_tools.tools
|
|
]
|
|
else:
|
|
try:
|
|
server_info, tool_definitions = await _discover(headers)
|
|
except Exception as first_err:
|
|
if not _is_auth_error(first_err) or connector_id is None:
|
|
logger.exception(
|
|
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
|
url,
|
|
connector_id,
|
|
first_err,
|
|
)
|
|
return tools
|
|
|
|
logger.warning(
|
|
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
|
connector_id,
|
|
)
|
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
|
if fresh_headers is None:
|
|
await _mark_connector_auth_expired(connector_id)
|
|
logger.error(
|
|
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
|
connector_id,
|
|
)
|
|
return tools
|
|
|
|
try:
|
|
server_info, tool_definitions = await _discover(fresh_headers)
|
|
headers = fresh_headers
|
|
logger.info(
|
|
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
|
connector_id,
|
|
)
|
|
except Exception as retry_err:
|
|
logger.exception(
|
|
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
|
connector_id,
|
|
retry_err,
|
|
)
|
|
if _is_auth_error(retry_err):
|
|
await _mark_connector_auth_expired(connector_id)
|
|
return tools
|
|
|
|
await write_cached_tools(
|
|
connector_id,
|
|
tool_definitions,
|
|
server_name=server_info.get("name"),
|
|
server_version=server_info.get("version"),
|
|
transport=server_config.get("transport", "streamable-http"),
|
|
)
|
|
|
|
total_discovered = len(tool_definitions)
|
|
|
|
if allowed_set:
|
|
tool_definitions = [td for td in tool_definitions if td["name"] in allowed_set]
|
|
logger.info(
|
|
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
|
url,
|
|
connector_id,
|
|
len(tool_definitions),
|
|
total_discovered,
|
|
)
|
|
else:
|
|
logger.info(
|
|
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
|
total_discovered,
|
|
url,
|
|
connector_id,
|
|
)
|
|
|
|
for tool_def in tool_definitions:
|
|
try:
|
|
tool = await _create_mcp_tool_from_definition_http(
|
|
tool_def,
|
|
url,
|
|
headers,
|
|
connector_name=connector_name,
|
|
connector_id=connector_id,
|
|
trusted_tools=trusted_tools,
|
|
readonly_tools=readonly_tools,
|
|
tool_name_prefix=tool_name_prefix,
|
|
is_generic_mcp=is_generic_mcp,
|
|
bypass_internal_hitl=bypass_internal_hitl,
|
|
)
|
|
tools.append(tool)
|
|
except Exception as e:
|
|
logger.exception(
|
|
"Failed to create HTTP tool '%s' from connector %d: %s",
|
|
tool_def.get("name"),
|
|
connector_id,
|
|
e,
|
|
)
|
|
|
|
return tools
|
|
|
|
|
|
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
|
|
|
|
_token_enc: TokenEncryption | None = None
|
|
|
|
|
|
def _get_token_enc() -> TokenEncryption:
|
|
global _token_enc
|
|
if _token_enc is None:
|
|
from app.config import config as app_config
|
|
from app.utils.oauth_security import TokenEncryption
|
|
|
|
_token_enc = TokenEncryption(app_config.SECRET_KEY)
|
|
return _token_enc
|
|
|
|
|
|
def _inject_oauth_headers(
|
|
cfg: dict[str, Any],
|
|
server_config: dict[str, Any],
|
|
) -> dict[str, Any] | None:
|
|
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
|
|
|
The DB never stores plaintext tokens in ``server_config.headers``. This
|
|
function decrypts ``mcp_oauth.access_token`` at runtime and returns a
|
|
*copy* of ``server_config`` with the Authorization header set.
|
|
"""
|
|
mcp_oauth = cfg.get("mcp_oauth", {})
|
|
encrypted_token = mcp_oauth.get("access_token")
|
|
if not encrypted_token:
|
|
return server_config
|
|
|
|
try:
|
|
access_token = _get_token_enc().decrypt_token(encrypted_token)
|
|
|
|
result = dict(server_config)
|
|
result["headers"] = {
|
|
**server_config.get("headers", {}),
|
|
"Authorization": f"Bearer {access_token}",
|
|
}
|
|
return result
|
|
except Exception:
|
|
logger.error(
|
|
"Failed to decrypt MCP OAuth token — connector will be skipped",
|
|
exc_info=True,
|
|
)
|
|
return None
|
|
|
|
|
|
async def _refresh_connector_token(
|
|
session: AsyncSession,
|
|
connector: SearchSourceConnector,
|
|
) -> str | None:
|
|
"""Refresh the OAuth token for an MCP connector and persist the result.
|
|
|
|
This is the shared core used by both proactive (pre-expiry) and reactive
|
|
(401 recovery) refresh paths. It handles:
|
|
- Decrypting the current refresh token / client secret
|
|
- Calling the token endpoint
|
|
- Encrypting and persisting the new tokens
|
|
- Clearing ``auth_expired`` if it was set
|
|
- Invalidating the MCP tools cache
|
|
|
|
Returns the **plaintext** new access token on success, or ``None`` on
|
|
failure (no refresh token, IdP error, etc.).
|
|
"""
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
from app.services.mcp_oauth.discovery import refresh_access_token
|
|
|
|
cfg = connector.config or {}
|
|
mcp_oauth = cfg.get("mcp_oauth", {})
|
|
|
|
refresh_token = mcp_oauth.get("refresh_token")
|
|
if not refresh_token:
|
|
logger.warning(
|
|
"MCP connector %s: no refresh_token available",
|
|
connector.id,
|
|
)
|
|
return None
|
|
|
|
enc = _get_token_enc()
|
|
decrypted_refresh = enc.decrypt_token(refresh_token)
|
|
decrypted_secret = (
|
|
enc.decrypt_token(mcp_oauth["client_secret"])
|
|
if mcp_oauth.get("client_secret")
|
|
else ""
|
|
)
|
|
|
|
token_json = await refresh_access_token(
|
|
token_endpoint=mcp_oauth["token_endpoint"],
|
|
refresh_token=decrypted_refresh,
|
|
client_id=mcp_oauth["client_id"],
|
|
client_secret=decrypted_secret,
|
|
)
|
|
|
|
new_access = token_json.get("access_token")
|
|
if not new_access:
|
|
logger.warning(
|
|
"MCP connector %s: token refresh returned no access_token",
|
|
connector.id,
|
|
)
|
|
return None
|
|
|
|
new_expires_at = None
|
|
if token_json.get("expires_in"):
|
|
new_expires_at = datetime.now(UTC) + timedelta(
|
|
seconds=int(token_json["expires_in"])
|
|
)
|
|
|
|
updated_oauth = dict(mcp_oauth)
|
|
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
|
if token_json.get("refresh_token"):
|
|
updated_oauth["refresh_token"] = enc.encrypt_token(token_json["refresh_token"])
|
|
updated_oauth["expires_at"] = new_expires_at.isoformat() if new_expires_at else None
|
|
|
|
updated_cfg = {**cfg, "mcp_oauth": updated_oauth}
|
|
updated_cfg.pop("auth_expired", None)
|
|
connector.config = updated_cfg
|
|
flag_modified(connector, "config")
|
|
await session.commit()
|
|
await session.refresh(connector)
|
|
|
|
invalidate_mcp_tools_cache(connector.search_space_id)
|
|
|
|
return new_access
|
|
|
|
|
|
async def _maybe_refresh_mcp_oauth_token(
|
|
session: AsyncSession,
|
|
connector: SearchSourceConnector,
|
|
cfg: dict[str, Any],
|
|
server_config: dict[str, Any],
|
|
) -> dict[str, Any]:
|
|
"""Refresh the access token for an MCP OAuth connector if it is about to expire.
|
|
|
|
Returns the (possibly updated) ``server_config``.
|
|
"""
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
mcp_oauth = cfg.get("mcp_oauth", {})
|
|
expires_at_str = mcp_oauth.get("expires_at")
|
|
if not expires_at_str:
|
|
return server_config
|
|
|
|
try:
|
|
expires_at = datetime.fromisoformat(expires_at_str)
|
|
if expires_at.tzinfo is None:
|
|
expires_at = expires_at.replace(tzinfo=UTC)
|
|
|
|
if datetime.now(UTC) < expires_at - timedelta(
|
|
seconds=_TOKEN_REFRESH_BUFFER_SECONDS
|
|
):
|
|
return server_config
|
|
except (ValueError, TypeError):
|
|
return server_config
|
|
|
|
refresh_start = time.perf_counter()
|
|
try:
|
|
new_access = await _refresh_connector_token(session, connector)
|
|
if not new_access:
|
|
_perf_log.info(
|
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=no_token",
|
|
connector.id,
|
|
time.perf_counter() - refresh_start,
|
|
)
|
|
return server_config
|
|
|
|
logger.info(
|
|
"Proactively refreshed MCP OAuth token for connector %s", connector.id
|
|
)
|
|
_perf_log.info(
|
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=refreshed",
|
|
connector.id,
|
|
time.perf_counter() - refresh_start,
|
|
)
|
|
|
|
refreshed_config = dict(server_config)
|
|
refreshed_config["headers"] = {
|
|
**server_config.get("headers", {}),
|
|
"Authorization": f"Bearer {new_access}",
|
|
}
|
|
return refreshed_config
|
|
|
|
except Exception:
|
|
_perf_log.info(
|
|
"[mcp_oauth_refresh] connector=%s elapsed=%.3fs outcome=failed",
|
|
connector.id,
|
|
time.perf_counter() - refresh_start,
|
|
)
|
|
logger.warning(
|
|
"Failed to refresh MCP OAuth token for connector %s",
|
|
connector.id,
|
|
exc_info=True,
|
|
)
|
|
return server_config
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reactive 401 handling helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _is_auth_error(exc: Exception) -> bool:
|
|
"""Check if an exception indicates an HTTP 401 authentication failure."""
|
|
try:
|
|
import httpx
|
|
|
|
if isinstance(exc, httpx.HTTPStatusError):
|
|
return exc.response.status_code == 401
|
|
except ImportError:
|
|
pass
|
|
err_str = str(exc).lower()
|
|
return "401" in err_str or "unauthorized" in err_str
|
|
|
|
|
|
async def _force_refresh_and_get_headers(
|
|
connector_id: int,
|
|
) -> dict[str, str] | None:
|
|
"""Force-refresh OAuth token for a connector and return fresh HTTP headers.
|
|
|
|
Opens a **new** DB session so this can be called from inside tool closures
|
|
that don't have access to the original session.
|
|
|
|
Returns ``None`` when the connector is not OAuth-backed, has no
|
|
refresh token, or the refresh itself fails.
|
|
"""
|
|
from app.db import async_session_maker
|
|
|
|
try:
|
|
async with async_session_maker() as session:
|
|
result = await session.execute(
|
|
select(SearchSourceConnector).filter(
|
|
SearchSourceConnector.id == connector_id,
|
|
)
|
|
)
|
|
connector = result.scalars().first()
|
|
if not connector:
|
|
return None
|
|
|
|
cfg = connector.config or {}
|
|
if not cfg.get("mcp_oauth"):
|
|
return None
|
|
|
|
server_config = cfg.get("server_config", {})
|
|
|
|
new_access = await _refresh_connector_token(session, connector)
|
|
if not new_access:
|
|
return None
|
|
|
|
logger.info(
|
|
"Force-refreshed MCP OAuth token for connector %s (401 recovery)",
|
|
connector_id,
|
|
)
|
|
return {
|
|
**server_config.get("headers", {}),
|
|
"Authorization": f"Bearer {new_access}",
|
|
}
|
|
|
|
except Exception:
|
|
logger.warning(
|
|
"Failed to force-refresh MCP OAuth token for connector %s",
|
|
connector_id,
|
|
exc_info=True,
|
|
)
|
|
return None
|
|
|
|
|
|
async def _mark_connector_auth_expired(connector_id: int) -> None:
|
|
"""Set ``config.auth_expired = True`` so the frontend shows re-auth UI."""
|
|
from app.db import async_session_maker
|
|
|
|
try:
|
|
async with async_session_maker() as session:
|
|
result = await session.execute(
|
|
select(SearchSourceConnector).filter(
|
|
SearchSourceConnector.id == connector_id,
|
|
)
|
|
)
|
|
connector = result.scalars().first()
|
|
if not connector:
|
|
return
|
|
|
|
cfg = dict(connector.config or {})
|
|
if cfg.get("auth_expired"):
|
|
return
|
|
|
|
cfg["auth_expired"] = True
|
|
connector.config = cfg
|
|
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
flag_modified(connector, "config")
|
|
await session.commit()
|
|
|
|
logger.info(
|
|
"Marked MCP connector %s as auth_expired after unrecoverable 401",
|
|
connector_id,
|
|
)
|
|
invalidate_mcp_tools_cache(connector.search_space_id)
|
|
|
|
except Exception:
|
|
logger.warning(
|
|
"Failed to mark connector %s as auth_expired",
|
|
connector_id,
|
|
exc_info=True,
|
|
)
|
|
|
|
|
|
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
|
"""Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together)."""
|
|
if search_space_id is not None:
|
|
for key in [k for k in _mcp_tools_cache if k[0] == search_space_id]:
|
|
_mcp_tools_cache.pop(key, None)
|
|
else:
|
|
_mcp_tools_cache.clear()
|
|
|
|
|
|
async def discover_single_mcp_connector(connector_id: int) -> None:
|
|
"""Force live MCP discovery for one connector so its ``cached_tools`` row is fresh.
|
|
|
|
``_load_http_mcp_tools`` persists ``cached_tools`` as a side effect of any
|
|
live discovery; passing ``cached_tools=None`` here guarantees we go to the
|
|
network. The returned wrappers are discarded — the in-process LRU is
|
|
rebuilt lazily on the next user query. Stdio connectors are not cached and
|
|
are skipped.
|
|
"""
|
|
from app.db import async_session_maker
|
|
|
|
started = time.perf_counter()
|
|
try:
|
|
async with async_session_maker() as session:
|
|
connector = await session.get(SearchSourceConnector, connector_id)
|
|
if connector is None:
|
|
logger.info(
|
|
"discover_single_mcp_connector: connector %d not found",
|
|
connector_id,
|
|
)
|
|
return
|
|
|
|
cfg = connector.config or {}
|
|
server_config = cfg.get("server_config", {})
|
|
if not server_config or not isinstance(server_config, dict):
|
|
return
|
|
|
|
transport = server_config.get("transport", "stdio")
|
|
if transport not in ("streamable-http", "http", "sse"):
|
|
return
|
|
|
|
if cfg.get("mcp_oauth"):
|
|
server_config = await _maybe_refresh_mcp_oauth_token(
|
|
session, connector, cfg, server_config
|
|
)
|
|
cfg = connector.config or {}
|
|
server_config = _inject_oauth_headers(cfg, server_config)
|
|
if server_config is None:
|
|
logger.info(
|
|
"discover_single_mcp_connector: OAuth token unavailable for connector %d",
|
|
connector_id,
|
|
)
|
|
return
|
|
|
|
ct = (
|
|
connector.connector_type.value
|
|
if hasattr(connector.connector_type, "value")
|
|
else str(connector.connector_type)
|
|
)
|
|
svc_cfg = get_service_by_connector_type(ct)
|
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
|
|
|
await asyncio.wait_for(
|
|
_load_http_mcp_tools(
|
|
connector.id,
|
|
connector.name,
|
|
server_config,
|
|
trusted_tools=cfg.get("trusted_tools", []),
|
|
allowed_tools=allowed_tools,
|
|
readonly_tools=readonly_tools,
|
|
tool_name_prefix=None,
|
|
is_generic_mcp=svc_cfg is None,
|
|
bypass_internal_hitl=True,
|
|
cached_tools=None,
|
|
),
|
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
|
)
|
|
|
|
_perf_log.info(
|
|
"[mcp_prefetch] connector=%s elapsed=%.3fs",
|
|
connector_id,
|
|
time.perf_counter() - started,
|
|
)
|
|
except TimeoutError:
|
|
logger.warning(
|
|
"discover_single_mcp_connector: connector %d timed out after %ds",
|
|
connector_id,
|
|
_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
|
)
|
|
except Exception:
|
|
logger.warning(
|
|
"discover_single_mcp_connector: failed for connector %d",
|
|
connector_id,
|
|
exc_info=True,
|
|
)
|
|
|
|
|
|
async def load_mcp_tools(
|
|
session: AsyncSession,
|
|
search_space_id: int,
|
|
*,
|
|
bypass_internal_hitl: bool = False,
|
|
) -> list[StructuredTool]:
|
|
"""Load all MCP tools from the user's active MCP server connectors.
|
|
|
|
Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up
|
|
to 5 minutes; bypass is keyed because each variant builds a different tool
|
|
closure (with vs. without the in-wrapper ``request_approval`` gate).
|
|
"""
|
|
_evict_expired_mcp_cache()
|
|
|
|
now = time.monotonic()
|
|
cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl)
|
|
cached = _mcp_tools_cache.get(cache_key)
|
|
if cached is not None:
|
|
cached_at, cached_tools = cached
|
|
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
|
|
logger.info(
|
|
"Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)",
|
|
search_space_id,
|
|
len(cached_tools),
|
|
now - cached_at,
|
|
bypass_internal_hitl,
|
|
)
|
|
return list(cached_tools)
|
|
|
|
try:
|
|
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
|
|
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
|
|
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
|
|
result = await session.execute(
|
|
select(SearchSourceConnector).filter(
|
|
SearchSourceConnector.search_space_id == search_space_id,
|
|
cast(SearchSourceConnector.config, JSONB).has_key("server_config"),
|
|
),
|
|
)
|
|
|
|
connectors = list(result.scalars())
|
|
|
|
# Group connectors by type to detect multi-account scenarios.
|
|
# When >1 connector shares the same type, tool names would collide
|
|
# so we prefix them with "{service_key}_{connector_id}_".
|
|
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
|
|
for connector in connectors:
|
|
ct = (
|
|
connector.connector_type.value
|
|
if hasattr(connector.connector_type, "value")
|
|
else str(connector.connector_type)
|
|
)
|
|
type_groups[ct].append(connector)
|
|
|
|
multi_account_types: set[str] = {
|
|
ct for ct, group in type_groups.items() if len(group) > 1
|
|
}
|
|
if multi_account_types:
|
|
logger.info(
|
|
"Multi-account detected for connector types: %s",
|
|
multi_account_types,
|
|
)
|
|
|
|
discovery_tasks: list[dict[str, Any]] = []
|
|
for connector in connectors:
|
|
try:
|
|
cfg = connector.config or {}
|
|
server_config = cfg.get("server_config", {})
|
|
|
|
if not server_config or not isinstance(server_config, dict):
|
|
logger.warning(
|
|
"MCP connector %d (name: '%s') has invalid or missing server_config, skipping",
|
|
connector.id,
|
|
connector.name,
|
|
)
|
|
continue
|
|
|
|
if cfg.get("mcp_oauth"):
|
|
server_config = await _maybe_refresh_mcp_oauth_token(
|
|
session,
|
|
connector,
|
|
cfg,
|
|
server_config,
|
|
)
|
|
cfg = connector.config or {}
|
|
server_config = _inject_oauth_headers(cfg, server_config)
|
|
if server_config is None:
|
|
logger.warning(
|
|
"Skipping MCP connector %d — OAuth token decryption failed",
|
|
connector.id,
|
|
)
|
|
await _mark_connector_auth_expired(connector.id)
|
|
continue
|
|
|
|
trusted_tools = cfg.get("trusted_tools", [])
|
|
|
|
ct = (
|
|
connector.connector_type.value
|
|
if hasattr(connector.connector_type, "value")
|
|
else str(connector.connector_type)
|
|
)
|
|
|
|
svc_cfg = get_service_by_connector_type(ct)
|
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
|
|
|
tool_name_prefix: str | None = None
|
|
if ct in multi_account_types and svc_cfg:
|
|
service_key = next(
|
|
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
|
|
None,
|
|
)
|
|
if service_key:
|
|
tool_name_prefix = f"{service_key}_{connector.id}"
|
|
|
|
discovery_tasks.append(
|
|
{
|
|
"connector_id": connector.id,
|
|
"connector_name": connector.name,
|
|
"server_config": server_config,
|
|
"trusted_tools": trusted_tools,
|
|
"allowed_tools": allowed_tools,
|
|
"readonly_tools": readonly_tools,
|
|
"tool_name_prefix": tool_name_prefix,
|
|
"transport": server_config.get("transport", "stdio"),
|
|
"is_generic_mcp": svc_cfg is None,
|
|
"cached_tools": read_cached_tools(connector),
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.exception(
|
|
"Failed to prepare MCP connector %d: %s",
|
|
connector.id,
|
|
e,
|
|
)
|
|
|
|
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
|
discover_start = time.perf_counter()
|
|
transport = task["transport"]
|
|
cached_tools = task.get("cached_tools")
|
|
try:
|
|
if transport in ("streamable-http", "http", "sse"):
|
|
result = await asyncio.wait_for(
|
|
_load_http_mcp_tools(
|
|
task["connector_id"],
|
|
task["connector_name"],
|
|
task["server_config"],
|
|
trusted_tools=task["trusted_tools"],
|
|
allowed_tools=task["allowed_tools"],
|
|
readonly_tools=task["readonly_tools"],
|
|
tool_name_prefix=task["tool_name_prefix"],
|
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
|
bypass_internal_hitl=bypass_internal_hitl,
|
|
cached_tools=cached_tools,
|
|
),
|
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
|
)
|
|
else:
|
|
result = await asyncio.wait_for(
|
|
_load_stdio_mcp_tools(
|
|
task["connector_id"],
|
|
task["connector_name"],
|
|
task["server_config"],
|
|
trusted_tools=task["trusted_tools"],
|
|
bypass_internal_hitl=bypass_internal_hitl,
|
|
),
|
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
|
)
|
|
_perf_log.info(
|
|
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
|
|
task["connector_id"],
|
|
task["connector_name"],
|
|
transport,
|
|
len(result),
|
|
time.perf_counter() - discover_start,
|
|
"hit" if cached_tools is not None else "miss",
|
|
)
|
|
return result
|
|
except TimeoutError:
|
|
_perf_log.info(
|
|
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=timeout",
|
|
task["connector_id"],
|
|
task["connector_name"],
|
|
transport,
|
|
time.perf_counter() - discover_start,
|
|
)
|
|
logger.error(
|
|
"MCP connector %d timed out after %ds during discovery",
|
|
task["connector_id"],
|
|
_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
|
)
|
|
return []
|
|
except Exception as e:
|
|
_perf_log.info(
|
|
"[mcp_discover] connector=%s name=%r transport=%s elapsed=%.3fs outcome=error",
|
|
task["connector_id"],
|
|
task["connector_name"],
|
|
transport,
|
|
time.perf_counter() - discover_start,
|
|
)
|
|
logger.exception(
|
|
"Failed to load tools from MCP connector %d: %s",
|
|
task["connector_id"],
|
|
e,
|
|
)
|
|
return []
|
|
|
|
gather_start = time.perf_counter()
|
|
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
|
_perf_log.info(
|
|
"[mcp_discover] gather_wall=%.3fs connectors=%d total_tools=%d",
|
|
time.perf_counter() - gather_start,
|
|
len(discovery_tasks),
|
|
sum(len(r) for r in results),
|
|
)
|
|
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
|
|
|
_mcp_tools_cache[cache_key] = (now, tools)
|
|
|
|
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
|
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
|
del _mcp_tools_cache[oldest_key]
|
|
|
|
logger.info(
|
|
"Loaded %d MCP tools for search space %d (bypass_hitl=%s)",
|
|
len(tools),
|
|
search_space_id,
|
|
bypass_internal_hitl,
|
|
)
|
|
return tools
|
|
|
|
except Exception as e:
|
|
logger.exception("Failed to load MCP tools: %s", e)
|
|
return []
|