perf(mcp): persist list_tools discovery in connector.config.cached_tools

Skip the ~1-3s MCP initialize + list_tools handshake on every cache miss
by reading tool definitions from the connector row we already load. Lazy
populate on first miss, self-heal on corrupt cache, zero schema migration.
This commit is contained in:
CREDO23 2026-05-20 16:11:07 +02:00
parent db8bffab38
commit c0aa4261ac
3 changed files with 304 additions and 42 deletions

View file

@ -36,6 +36,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPTools,
read_cached_tools,
write_cached_tools,
)
from app.db import SearchSourceConnector
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
from app.utils.perf import get_perf_logger
@ -516,6 +521,7 @@ async def _load_http_mcp_tools(
is_generic_mcp: bool = False,
*,
bypass_internal_hitl: bool = False,
cached_tools: CachedMCPTools | None = None,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -526,6 +532,8 @@ async def _load_http_mcp_tools(
readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``).
cached_tools: If provided, skip live discovery and rebuild wrappers
from the persisted definitions.
"""
tools: list[StructuredTool] = []
@ -549,15 +557,23 @@ async def _load_http_mcp_tools(
allowed_set = set(allowed_tools) if allowed_tools else None
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
"""Connect, initialize, and list tools from the MCP server."""
async def _discover(
disc_headers: dict[str, str],
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
async with (
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
init_result = await session.initialize()
server_info: dict[str, str | None] = {"name": None, "version": None}
si = getattr(init_result, "serverInfo", None)
if si is not None:
server_info["name"] = getattr(si, "name", None)
server_info["version"] = getattr(si, "version", None)
response = await session.list_tools()
return [
return server_info, [
{
"name": tool.name,
"description": tool.description or "",
@ -568,47 +584,65 @@ async def _load_http_mcp_tools(
for tool in response.tools
]
try:
tool_definitions = await _discover(headers)
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url,
connector_id,
first_err,
)
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
if cached_tools is not None:
tool_definitions = [
{
"name": td.name,
"description": td.description,
"input_schema": td.input_schema,
}
for td in cached_tools.tools
]
else:
try:
tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
server_info, tool_definitions = await _discover(headers)
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url,
connector_id,
first_err,
)
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
return tools
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try:
server_info, tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return tools
await write_cached_tools(
connector_id,
tool_definitions,
server_name=server_info.get("name"),
server_version=server_info.get("version"),
transport=server_config.get("transport", "streamable-http"),
)
total_discovered = len(tool_definitions)
@ -1099,6 +1133,7 @@ async def load_mcp_tools(
"tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None,
"cached_tools": read_cached_tools(connector),
}
)
@ -1112,6 +1147,7 @@ async def load_mcp_tools(
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
discover_start = time.perf_counter()
transport = task["transport"]
cached_tools = task.get("cached_tools")
try:
if transport in ("streamable-http", "http", "sse"):
result = await asyncio.wait_for(
@ -1125,6 +1161,7 @@ async def load_mcp_tools(
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl,
cached_tools=cached_tools,
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
@ -1140,12 +1177,13 @@ async def load_mcp_tools(
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
_perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs",
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
task["connector_id"],
task["connector_name"],
transport,
len(result),
time.perf_counter() - discover_start,
"hit" if cached_tools is not None else "miss",
)
return result
except TimeoutError:

View file

@ -0,0 +1,94 @@
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
class CachedMCPToolDef(BaseModel):
name: str
description: str = ""
input_schema: dict[str, Any] = Field(default_factory=dict)
class CachedMCPTools(BaseModel):
discovered_at: datetime
server_version: str | None = None
server_name: str | None = None
transport: str | None = None
tools: list[CachedMCPToolDef]
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
"""Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
cfg = connector.config or {}
raw = cfg.get("cached_tools")
if not raw or not isinstance(raw, dict):
return None
try:
return CachedMCPTools.model_validate(raw)
except ValidationError as exc:
logger.warning(
"MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
connector.id,
exc,
)
return None
async def write_cached_tools(
connector_id: int,
tool_definitions: list[dict[str, Any]],
*,
server_name: str | None = None,
server_version: str | None = None,
transport: str | None = None,
) -> None:
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
payload = CachedMCPTools(
discovered_at=datetime.now(UTC),
server_version=server_version,
server_name=server_name,
transport=transport,
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
)
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if connector is None:
return
cfg = dict(connector.config or {})
cfg["cached_tools"] = payload.model_dump(mode="json")
connector.config = cfg
flag_modified(connector, "config")
await session.commit()
logger.info(
"Persisted cached_tools for MCP connector %d (%d tools)",
connector_id,
len(payload.tools),
)
except Exception:
logger.warning(
"Failed to persist cached_tools for MCP connector %d",
connector_id,
exc_info=True,
)