mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
perf(mcp): persist list_tools discovery in connector.config.cached_tools
Skip the ~1-3s MCP initialize + list_tools handshake on every cache miss by reading tool definitions from the connector row we already load. Lazy populate on first miss, self-heal on corrupt cache, zero schema migration.
This commit is contained in:
parent
db8bffab38
commit
c0aa4261ac
3 changed files with 304 additions and 42 deletions
|
|
@ -36,6 +36,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||
CachedMCPTools,
|
||||
read_cached_tools,
|
||||
write_cached_tools,
|
||||
)
|
||||
from app.db import SearchSourceConnector
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
@ -516,6 +521,7 @@ async def _load_http_mcp_tools(
|
|||
is_generic_mcp: bool = False,
|
||||
*,
|
||||
bypass_internal_hitl: bool = False,
|
||||
cached_tools: CachedMCPTools | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
|
|
@ -526,6 +532,8 @@ async def _load_http_mcp_tools(
|
|||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||
disambiguation (e.g. ``linear_25``).
|
||||
cached_tools: If provided, skip live discovery and rebuild wrappers
|
||||
from the persisted definitions.
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
|
|
@ -549,15 +557,23 @@ async def _load_http_mcp_tools(
|
|||
|
||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||
|
||||
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
||||
"""Connect, initialize, and list tools from the MCP server."""
|
||||
async def _discover(
|
||||
disc_headers: dict[str, str],
|
||||
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
|
||||
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
|
||||
async with (
|
||||
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
init_result = await session.initialize()
|
||||
server_info: dict[str, str | None] = {"name": None, "version": None}
|
||||
si = getattr(init_result, "serverInfo", None)
|
||||
if si is not None:
|
||||
server_info["name"] = getattr(si, "name", None)
|
||||
server_info["version"] = getattr(si, "version", None)
|
||||
|
||||
response = await session.list_tools()
|
||||
return [
|
||||
return server_info, [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
|
|
@ -568,47 +584,65 @@ async def _load_http_mcp_tools(
|
|||
for tool in response.tools
|
||||
]
|
||||
|
||||
try:
|
||||
tool_definitions = await _discover(headers)
|
||||
except Exception as first_err:
|
||||
if not _is_auth_error(first_err) or connector_id is None:
|
||||
logger.exception(
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url,
|
||||
connector_id,
|
||||
first_err,
|
||||
)
|
||||
return tools
|
||||
|
||||
logger.warning(
|
||||
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||
connector_id,
|
||||
)
|
||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||
if fresh_headers is None:
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
logger.error(
|
||||
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||
connector_id,
|
||||
)
|
||||
return tools
|
||||
|
||||
if cached_tools is not None:
|
||||
tool_definitions = [
|
||||
{
|
||||
"name": td.name,
|
||||
"description": td.description,
|
||||
"input_schema": td.input_schema,
|
||||
}
|
||||
for td in cached_tools.tools
|
||||
]
|
||||
else:
|
||||
try:
|
||||
tool_definitions = await _discover(fresh_headers)
|
||||
headers = fresh_headers
|
||||
logger.info(
|
||||
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||
server_info, tool_definitions = await _discover(headers)
|
||||
except Exception as first_err:
|
||||
if not _is_auth_error(first_err) or connector_id is None:
|
||||
logger.exception(
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url,
|
||||
connector_id,
|
||||
first_err,
|
||||
)
|
||||
return tools
|
||||
|
||||
logger.warning(
|
||||
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||
connector_id,
|
||||
)
|
||||
except Exception as retry_err:
|
||||
logger.exception(
|
||||
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||
connector_id,
|
||||
retry_err,
|
||||
)
|
||||
if _is_auth_error(retry_err):
|
||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||
if fresh_headers is None:
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
return tools
|
||||
logger.error(
|
||||
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||
connector_id,
|
||||
)
|
||||
return tools
|
||||
|
||||
try:
|
||||
server_info, tool_definitions = await _discover(fresh_headers)
|
||||
headers = fresh_headers
|
||||
logger.info(
|
||||
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||
connector_id,
|
||||
)
|
||||
except Exception as retry_err:
|
||||
logger.exception(
|
||||
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||
connector_id,
|
||||
retry_err,
|
||||
)
|
||||
if _is_auth_error(retry_err):
|
||||
await _mark_connector_auth_expired(connector_id)
|
||||
return tools
|
||||
|
||||
await write_cached_tools(
|
||||
connector_id,
|
||||
tool_definitions,
|
||||
server_name=server_info.get("name"),
|
||||
server_version=server_info.get("version"),
|
||||
transport=server_config.get("transport", "streamable-http"),
|
||||
)
|
||||
|
||||
total_discovered = len(tool_definitions)
|
||||
|
||||
|
|
@ -1099,6 +1133,7 @@ async def load_mcp_tools(
|
|||
"tool_name_prefix": tool_name_prefix,
|
||||
"transport": server_config.get("transport", "stdio"),
|
||||
"is_generic_mcp": svc_cfg is None,
|
||||
"cached_tools": read_cached_tools(connector),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -1112,6 +1147,7 @@ async def load_mcp_tools(
|
|||
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||
discover_start = time.perf_counter()
|
||||
transport = task["transport"]
|
||||
cached_tools = task.get("cached_tools")
|
||||
try:
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
result = await asyncio.wait_for(
|
||||
|
|
@ -1125,6 +1161,7 @@ async def load_mcp_tools(
|
|||
tool_name_prefix=task["tool_name_prefix"],
|
||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||
bypass_internal_hitl=bypass_internal_hitl,
|
||||
cached_tools=cached_tools,
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
|
@ -1140,12 +1177,13 @@ async def load_mcp_tools(
|
|||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs",
|
||||
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
|
||||
task["connector_id"],
|
||||
task["connector_name"],
|
||||
transport,
|
||||
len(result),
|
||||
time.perf_counter() - discover_start,
|
||||
"hit" if cached_tools is not None else "miss",
|
||||
)
|
||||
return result
|
||||
except TimeoutError:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import SearchSourceConnector, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CachedMCPToolDef(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CachedMCPTools(BaseModel):
|
||||
discovered_at: datetime
|
||||
server_version: str | None = None
|
||||
server_name: str | None = None
|
||||
transport: str | None = None
|
||||
tools: list[CachedMCPToolDef]
|
||||
|
||||
|
||||
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
|
||||
"""Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
|
||||
cfg = connector.config or {}
|
||||
raw = cfg.get("cached_tools")
|
||||
if not raw or not isinstance(raw, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
return CachedMCPTools.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
|
||||
connector.id,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def write_cached_tools(
|
||||
connector_id: int,
|
||||
tool_definitions: list[dict[str, Any]],
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
server_version: str | None = None,
|
||||
transport: str | None = None,
|
||||
) -> None:
|
||||
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
|
||||
payload = CachedMCPTools(
|
||||
discovered_at=datetime.now(UTC),
|
||||
server_version=server_version,
|
||||
server_name=server_name,
|
||||
transport=transport,
|
||||
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
|
||||
)
|
||||
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if connector is None:
|
||||
return
|
||||
|
||||
cfg = dict(connector.config or {})
|
||||
cfg["cached_tools"] = payload.model_dump(mode="json")
|
||||
connector.config = cfg
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Persisted cached_tools for MCP connector %d (%d tools)",
|
||||
connector_id,
|
||||
len(payload.tools),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist cached_tools for MCP connector %d",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue