perf(mcp): persist list_tools discovery in connector.config.cached_tools

Skip the ~1-3s MCP initialize + list_tools handshake on every cache miss
by reading tool definitions from the connector row we already load. Lazy
populate on first miss, self-heal on corrupt cache, zero schema migration.
This commit is contained in:
CREDO23 2026-05-20 16:11:07 +02:00
parent db8bffab38
commit c0aa4261ac
3 changed files with 304 additions and 42 deletions

View file

@ -36,6 +36,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient from app.agents.new_chat.tools.mcp_client import MCPClient
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPTools,
read_cached_tools,
write_cached_tools,
)
from app.db import SearchSourceConnector from app.db import SearchSourceConnector
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
@ -516,6 +521,7 @@ async def _load_http_mcp_tools(
is_generic_mcp: bool = False, is_generic_mcp: bool = False,
*, *,
bypass_internal_hitl: bool = False, bypass_internal_hitl: bool = False,
cached_tools: CachedMCPTools | None = None,
) -> list[StructuredTool]: ) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server. """Load tools from an HTTP-based MCP server.
@ -526,6 +532,8 @@ async def _load_http_mcp_tools(
readonly_tools: Tool names that skip HITL approval (read-only operations). readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``). disambiguation (e.g. ``linear_25``).
cached_tools: If provided, skip live discovery and rebuild wrappers
from the persisted definitions.
""" """
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
@ -549,15 +557,23 @@ async def _load_http_mcp_tools(
allowed_set = set(allowed_tools) if allowed_tools else None allowed_set = set(allowed_tools) if allowed_tools else None
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]: async def _discover(
"""Connect, initialize, and list tools from the MCP server.""" disc_headers: dict[str, str],
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
async with ( async with (
streamablehttp_client(url, headers=disc_headers) as (read, write, _), streamablehttp_client(url, headers=disc_headers) as (read, write, _),
ClientSession(read, write) as session, ClientSession(read, write) as session,
): ):
await session.initialize() init_result = await session.initialize()
server_info: dict[str, str | None] = {"name": None, "version": None}
si = getattr(init_result, "serverInfo", None)
if si is not None:
server_info["name"] = getattr(si, "name", None)
server_info["version"] = getattr(si, "version", None)
response = await session.list_tools() response = await session.list_tools()
return [ return server_info, [
{ {
"name": tool.name, "name": tool.name,
"description": tool.description or "", "description": tool.description or "",
@ -568,47 +584,65 @@ async def _load_http_mcp_tools(
for tool in response.tools for tool in response.tools
] ]
try: if cached_tools is not None:
tool_definitions = await _discover(headers) tool_definitions = [
except Exception as first_err: {
if not _is_auth_error(first_err) or connector_id is None: "name": td.name,
logger.exception( "description": td.description,
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s", "input_schema": td.input_schema,
url, }
connector_id, for td in cached_tools.tools
first_err, ]
) else:
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try: try:
tool_definitions = await _discover(fresh_headers) server_info, tool_definitions = await _discover(headers)
headers = fresh_headers except Exception as first_err:
logger.info( if not _is_auth_error(first_err) or connector_id is None:
"HTTP MCP discovery for connector %d succeeded after 401 recovery", logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url,
connector_id,
first_err,
)
return tools
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id, connector_id,
) )
except Exception as retry_err: fresh_headers = await _force_refresh_and_get_headers(connector_id)
logger.exception( if fresh_headers is None:
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id) await _mark_connector_auth_expired(connector_id)
return tools logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try:
server_info, tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id,
retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return tools
await write_cached_tools(
connector_id,
tool_definitions,
server_name=server_info.get("name"),
server_version=server_info.get("version"),
transport=server_config.get("transport", "streamable-http"),
)
total_discovered = len(tool_definitions) total_discovered = len(tool_definitions)
@ -1099,6 +1133,7 @@ async def load_mcp_tools(
"tool_name_prefix": tool_name_prefix, "tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"), "transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None, "is_generic_mcp": svc_cfg is None,
"cached_tools": read_cached_tools(connector),
} }
) )
@ -1112,6 +1147,7 @@ async def load_mcp_tools(
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]: async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
discover_start = time.perf_counter() discover_start = time.perf_counter()
transport = task["transport"] transport = task["transport"]
cached_tools = task.get("cached_tools")
try: try:
if transport in ("streamable-http", "http", "sse"): if transport in ("streamable-http", "http", "sse"):
result = await asyncio.wait_for( result = await asyncio.wait_for(
@ -1125,6 +1161,7 @@ async def load_mcp_tools(
tool_name_prefix=task["tool_name_prefix"], tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False), is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl, bypass_internal_hitl=bypass_internal_hitl,
cached_tools=cached_tools,
), ),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
) )
@ -1140,12 +1177,13 @@ async def load_mcp_tools(
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
) )
_perf_log.info( _perf_log.info(
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs", "[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
task["connector_id"], task["connector_id"],
task["connector_name"], task["connector_name"],
transport, transport,
len(result), len(result),
time.perf_counter() - discover_start, time.perf_counter() - discover_start,
"hit" if cached_tools is not None else "miss",
) )
return result return result
except TimeoutError: except TimeoutError:

View file

@ -0,0 +1,94 @@
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
class CachedMCPToolDef(BaseModel):
    """A single persisted MCP tool definition; only ``name`` is required."""

    # Tool name.
    name: str
    # Tool description; falls back to an empty string when omitted.
    description: str = ""
    # Tool input schema (presumably JSON Schema — confirm against the MCP
    # server output); falls back to a fresh empty dict per instance.
    input_schema: dict[str, Any] = Field(default_factory=dict)
class CachedMCPTools(BaseModel):
    """Shape of the ``cached_tools`` entry stored in a connector's ``config``."""

    # When the tool list was captured; ``write_cached_tools`` sets this to the
    # current UTC time.
    discovered_at: datetime
    # Optional server / connection metadata; each defaults to ``None``.
    server_version: str | None = None
    server_name: str | None = None
    transport: str | None = None
    # Required list of tool definitions; a payload without it fails validation.
    tools: list[CachedMCPToolDef]
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
    """Parse ``connector.config["cached_tools"]`` into a ``CachedMCPTools``.

    Returns ``None`` when the entry is absent, empty, not a dict, or fails
    model validation, so the caller can fall back to live discovery. A
    validation failure is logged as a warning.
    """
    stored = (connector.config or {}).get("cached_tools")

    # Only a non-empty dict is worth validating; anything else means "no cache".
    if stored and isinstance(stored, dict):
        try:
            parsed = CachedMCPTools.model_validate(stored)
        except ValidationError as exc:
            logger.warning(
                "MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
                connector.id,
                exc,
            )
            parsed = None
        return parsed

    return None
async def write_cached_tools(
connector_id: int,
tool_definitions: list[dict[str, Any]],
*,
server_name: str | None = None,
server_version: str | None = None,
transport: str | None = None,
) -> None:
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
payload = CachedMCPTools(
discovered_at=datetime.now(UTC),
server_version=server_version,
server_name=server_name,
transport=transport,
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
)
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if connector is None:
return
cfg = dict(connector.config or {})
cfg["cached_tools"] = payload.model_dump(mode="json")
connector.config = cfg
flag_modified(connector, "config")
await session.commit()
logger.info(
"Persisted cached_tools for MCP connector %d (%d tools)",
connector_id,
len(payload.tools),
)
except Exception:
logger.warning(
"Failed to persist cached_tools for MCP connector %d",
connector_id,
exc_info=True,
)

View file

@ -0,0 +1,130 @@
"""Unit tests for ``mcp_tools_cache``."""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from app.agents.new_chat.tools.mcp_tools_cache import (
CachedMCPToolDef,
CachedMCPTools,
read_cached_tools,
)
pytestmark = pytest.mark.unit
def _make_connector(config: dict | None) -> SimpleNamespace:
    """Build a stand-in connector exposing only ``id`` (fixed at 42) and ``config``."""
    stub = SimpleNamespace(id=42)
    stub.config = config
    return stub
def test_read_returns_none_when_config_is_none() -> None:
    """A connector whose ``config`` is ``None`` yields no cached tools."""
    connector = _make_connector(None)
    assert read_cached_tools(connector) is None
def test_read_returns_none_when_cached_tools_missing() -> None:
    """A config without a ``cached_tools`` key yields no cached tools."""
    connector = _make_connector({"server_config": {}})
    assert read_cached_tools(connector) is None
def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
    """Non-dict ``cached_tools`` values (a list, a string) are treated as no cache."""
    for bad_value in ([], "stale"):
        connector = _make_connector({"cached_tools": bad_value})
        assert read_cached_tools(connector) is None
def test_read_parses_minimal_valid_payload() -> None:
    """A payload with only ``discovered_at`` and ``tools`` parses; metadata stays ``None``."""
    tool_entry = {
        "name": "list_issues",
        "description": "List Linear issues",
        "input_schema": {"type": "object"},
    }
    config = {
        "cached_tools": {
            "discovered_at": "2026-05-20T10:00:00+00:00",
            "tools": [tool_entry],
        }
    }

    parsed = read_cached_tools(_make_connector(config))

    assert parsed is not None
    # None of the optional metadata fields were supplied.
    for optional_field in (
        parsed.server_version,
        parsed.server_name,
        parsed.transport,
    ):
        assert optional_field is None
    assert len(parsed.tools) == 1
    (only_tool,) = parsed.tools
    assert only_tool.name == "list_issues"
def test_read_parses_full_payload_with_serverinfo() -> None:
    """Server metadata and tool order survive parsing of a fully populated payload."""
    cached = {
        "discovered_at": "2026-05-20T10:00:00+00:00",
        "server_version": "1.2.3",
        "server_name": "atlassian-mcp",
        "transport": "streamable-http",
        "tools": [
            {"name": "create_issue", "input_schema": {}},
            {"name": "list_issues", "input_schema": {}},
        ],
    }

    parsed = read_cached_tools(_make_connector({"cached_tools": cached}))

    assert parsed is not None
    assert parsed.server_version == "1.2.3"
    assert parsed.server_name == "atlassian-mcp"
    assert parsed.transport == "streamable-http"
    tool_names = [tool.name for tool in parsed.tools]
    assert tool_names == ["create_issue", "list_issues"]
def test_read_returns_none_for_corrupt_payload(caplog) -> None:
    """An invalid payload returns ``None`` and logs a ``corrupt cached_tools`` message."""
    corrupt = {
        "discovered_at": "not-a-date",
        "tools": "should-be-a-list",
    }

    parsed = read_cached_tools(_make_connector({"cached_tools": corrupt}))

    assert parsed is None
    logged = [record.getMessage() for record in caplog.records]
    assert any("corrupt cached_tools" in message for message in logged)
def test_read_returns_none_when_tools_missing() -> None:
    """A payload lacking the required ``tools`` list is rejected."""
    config = {"cached_tools": {"discovered_at": "2026-05-20T10:00:00+00:00"}}
    connector = _make_connector(config)
    assert read_cached_tools(connector) is None
def test_tool_def_defaults_description_and_schema() -> None:
    """A tool definition given only a name gets an empty description and schema."""
    tool = CachedMCPToolDef.model_validate({"name": "ping"})
    assert (tool.description, tool.input_schema) == ("", {})
def test_model_dump_json_mode_is_round_trippable() -> None:
    """A JSON-mode dump validates back into a model with the same timestamp and tools."""
    captured_at = datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC)
    original = CachedMCPTools(
        discovered_at=captured_at,
        server_version="1.2.3",
        server_name="atlassian-mcp",
        transport="streamable-http",
        tools=[CachedMCPToolDef(name="list_issues")],
    )

    dumped = original.model_dump(mode="json")
    first_tool = dumped["tools"][0]
    assert dumped["discovered_at"] == "2026-05-20T10:00:00Z"
    assert first_tool["name"] == "list_issues"

    restored = CachedMCPTools.model_validate(dumped)
    assert restored.discovered_at == original.discovered_at
    assert restored.tools[0].name == "list_issues"