mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
perf(mcp): persist list_tools discovery in connector.config.cached_tools
Skip the ~1-3s MCP initialize + list_tools handshake on every cache miss by reading tool definitions from the connector row we already load. Lazy populate on first miss, self-heal on corrupt cache, zero schema migration.
This commit is contained in:
parent
db8bffab38
commit
c0aa4261ac
3 changed files with 304 additions and 42 deletions
|
|
@ -36,6 +36,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
CachedMCPTools,
|
||||||
|
read_cached_tools,
|
||||||
|
write_cached_tools,
|
||||||
|
)
|
||||||
from app.db import SearchSourceConnector
|
from app.db import SearchSourceConnector
|
||||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
@ -516,6 +521,7 @@ async def _load_http_mcp_tools(
|
||||||
is_generic_mcp: bool = False,
|
is_generic_mcp: bool = False,
|
||||||
*,
|
*,
|
||||||
bypass_internal_hitl: bool = False,
|
bypass_internal_hitl: bool = False,
|
||||||
|
cached_tools: CachedMCPTools | None = None,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server.
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
|
@ -526,6 +532,8 @@ async def _load_http_mcp_tools(
|
||||||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||||
disambiguation (e.g. ``linear_25``).
|
disambiguation (e.g. ``linear_25``).
|
||||||
|
cached_tools: If provided, skip live discovery and rebuild wrappers
|
||||||
|
from the persisted definitions.
|
||||||
"""
|
"""
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
|
|
||||||
|
|
@ -549,15 +557,23 @@ async def _load_http_mcp_tools(
|
||||||
|
|
||||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||||
|
|
||||||
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
async def _discover(
|
||||||
"""Connect, initialize, and list tools from the MCP server."""
|
disc_headers: dict[str, str],
|
||||||
|
) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
|
||||||
|
"""Connect, initialize, and list tools — returns (serverInfo, tools)."""
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
init_result = await session.initialize()
|
||||||
|
server_info: dict[str, str | None] = {"name": None, "version": None}
|
||||||
|
si = getattr(init_result, "serverInfo", None)
|
||||||
|
if si is not None:
|
||||||
|
server_info["name"] = getattr(si, "name", None)
|
||||||
|
server_info["version"] = getattr(si, "version", None)
|
||||||
|
|
||||||
response = await session.list_tools()
|
response = await session.list_tools()
|
||||||
return [
|
return server_info, [
|
||||||
{
|
{
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description or "",
|
"description": tool.description or "",
|
||||||
|
|
@ -568,47 +584,65 @@ async def _load_http_mcp_tools(
|
||||||
for tool in response.tools
|
for tool in response.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
if cached_tools is not None:
|
||||||
tool_definitions = await _discover(headers)
|
tool_definitions = [
|
||||||
except Exception as first_err:
|
{
|
||||||
if not _is_auth_error(first_err) or connector_id is None:
|
"name": td.name,
|
||||||
logger.exception(
|
"description": td.description,
|
||||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
"input_schema": td.input_schema,
|
||||||
url,
|
}
|
||||||
connector_id,
|
for td in cached_tools.tools
|
||||||
first_err,
|
]
|
||||||
)
|
else:
|
||||||
return tools
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
|
||||||
connector_id,
|
|
||||||
)
|
|
||||||
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
|
||||||
if fresh_headers is None:
|
|
||||||
await _mark_connector_auth_expired(connector_id)
|
|
||||||
logger.error(
|
|
||||||
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
|
||||||
connector_id,
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_definitions = await _discover(fresh_headers)
|
server_info, tool_definitions = await _discover(headers)
|
||||||
headers = fresh_headers
|
except Exception as first_err:
|
||||||
logger.info(
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
logger.exception(
|
||||||
|
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||||
|
url,
|
||||||
|
connector_id,
|
||||||
|
first_err,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||||
connector_id,
|
connector_id,
|
||||||
)
|
)
|
||||||
except Exception as retry_err:
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
logger.exception(
|
if fresh_headers is None:
|
||||||
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
|
||||||
connector_id,
|
|
||||||
retry_err,
|
|
||||||
)
|
|
||||||
if _is_auth_error(retry_err):
|
|
||||||
await _mark_connector_auth_expired(connector_id)
|
await _mark_connector_auth_expired(connector_id)
|
||||||
return tools
|
logger.error(
|
||||||
|
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
server_info, tool_definitions = await _discover(fresh_headers)
|
||||||
|
headers = fresh_headers
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
except Exception as retry_err:
|
||||||
|
logger.exception(
|
||||||
|
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||||
|
connector_id,
|
||||||
|
retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
await write_cached_tools(
|
||||||
|
connector_id,
|
||||||
|
tool_definitions,
|
||||||
|
server_name=server_info.get("name"),
|
||||||
|
server_version=server_info.get("version"),
|
||||||
|
transport=server_config.get("transport", "streamable-http"),
|
||||||
|
)
|
||||||
|
|
||||||
total_discovered = len(tool_definitions)
|
total_discovered = len(tool_definitions)
|
||||||
|
|
||||||
|
|
@ -1099,6 +1133,7 @@ async def load_mcp_tools(
|
||||||
"tool_name_prefix": tool_name_prefix,
|
"tool_name_prefix": tool_name_prefix,
|
||||||
"transport": server_config.get("transport", "stdio"),
|
"transport": server_config.get("transport", "stdio"),
|
||||||
"is_generic_mcp": svc_cfg is None,
|
"is_generic_mcp": svc_cfg is None,
|
||||||
|
"cached_tools": read_cached_tools(connector),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1112,6 +1147,7 @@ async def load_mcp_tools(
|
||||||
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
|
||||||
discover_start = time.perf_counter()
|
discover_start = time.perf_counter()
|
||||||
transport = task["transport"]
|
transport = task["transport"]
|
||||||
|
cached_tools = task.get("cached_tools")
|
||||||
try:
|
try:
|
||||||
if transport in ("streamable-http", "http", "sse"):
|
if transport in ("streamable-http", "http", "sse"):
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
|
|
@ -1125,6 +1161,7 @@ async def load_mcp_tools(
|
||||||
tool_name_prefix=task["tool_name_prefix"],
|
tool_name_prefix=task["tool_name_prefix"],
|
||||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||||
bypass_internal_hitl=bypass_internal_hitl,
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
|
cached_tools=cached_tools,
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
|
@ -1140,12 +1177,13 @@ async def load_mcp_tools(
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs",
|
"[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
|
||||||
task["connector_id"],
|
task["connector_id"],
|
||||||
task["connector_name"],
|
task["connector_name"],
|
||||||
transport,
|
transport,
|
||||||
len(result),
|
len(result),
|
||||||
time.perf_counter() - discover_start,
|
time.perf_counter() - discover_start,
|
||||||
|
"hit" if cached_tools is not None else "miss",
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, async_session_maker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMCPToolDef(BaseModel):
    """A single persisted MCP tool definition.

    Only ``name`` is required; the remaining fields fall back to empty
    values when they are absent from the stored payload.
    """

    # Tool name.
    name: str
    # Tool description; empty string when the payload omits it.
    description: str = Field(default="")
    # Input schema for the tool; a fresh empty dict when omitted.
    input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMCPTools(BaseModel):
    """Snapshot of a ``list_tools`` discovery stored under ``config["cached_tools"]``.

    ``discovered_at`` and ``tools`` are required; the server/transport
    metadata fields are optional and default to ``None``.
    """

    # When the snapshot was taken.
    discovered_at: datetime
    # Server version, if one was recorded.
    server_version: str | None = Field(default=None)
    # Server name, if one was recorded.
    server_name: str | None = Field(default=None)
    # Transport string recorded alongside the snapshot, if any.
    transport: str | None = Field(default=None)
    # The persisted tool definitions.
    tools: list[CachedMCPToolDef]
|
||||||
|
|
||||||
|
|
||||||
|
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
    """Parse ``connector.config["cached_tools"]`` into a ``CachedMCPTools``.

    Args:
        connector: Connector row (or any object exposing ``id`` and ``config``).

    Returns:
        The validated cache, or ``None`` when the entry is missing, empty,
        not a dict, or fails validation. The caller is expected to fall back
        to live discovery on ``None``.
    """
    stored = (connector.config or {}).get("cached_tools")

    # Missing, empty, or non-dict entries count as "no cache" and are not logged.
    if not (stored and isinstance(stored, dict)):
        return None

    try:
        parsed = CachedMCPTools.model_validate(stored)
    except ValidationError as exc:
        # A corrupt entry is reported and then treated as a cache miss.
        logger.warning(
            "MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
            connector.id,
            exc,
        )
        return None
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
async def write_cached_tools(
    connector_id: int,
    tool_definitions: list[dict[str, Any]],
    *,
    server_name: str | None = None,
    server_version: str | None = None,
    transport: str | None = None,
) -> None:
    """Best-effort persist of discovered tools into ``config["cached_tools"]``.

    Opens its own session so that a failed write cannot poison the caller's
    transaction. Every failure — including a tool definition that does not
    validate — is logged as a warning and swallowed; this function never
    raises to the caller.

    Args:
        connector_id: Primary key of the connector row to update. ``None``
            is tolerated and makes the call a no-op.
        tool_definitions: Tool dicts to validate as ``CachedMCPToolDef``.
        server_name: Optional server name to record with the snapshot.
        server_version: Optional server version to record with the snapshot.
        transport: Optional transport string to record with the snapshot.
    """
    # Without a connector id there is no row to update; skip the DB round-trip.
    if connector_id is None:
        return

    try:
        # Built inside the ``try`` so that a malformed tool definition is
        # handled like any other persistence failure instead of propagating.
        payload = CachedMCPTools(
            discovered_at=datetime.now(UTC),
            server_version=server_version,
            server_name=server_name,
            transport=transport,
            tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
        )

        async with async_session_maker() as session:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == connector_id,
                )
            )
            connector = result.scalars().first()
            if connector is None:
                return

            # Copy the config, replace only the cache key, and reassign so
            # the other keys on the row are carried over unchanged.
            cfg = dict(connector.config or {})
            cfg["cached_tools"] = payload.model_dump(mode="json")
            connector.config = cfg
            # Explicitly mark the column dirty so the change is flushed.
            flag_modified(connector, "config")
            await session.commit()

        logger.info(
            "Persisted cached_tools for MCP connector %d (%d tools)",
            connector_id,
            len(payload.tools),
        )
    except Exception:
        logger.warning(
            "Failed to persist cached_tools for MCP connector %d",
            connector_id,
            exc_info=True,
        )
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
"""Unit tests for ``mcp_tools_cache``."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
CachedMCPToolDef,
|
||||||
|
CachedMCPTools,
|
||||||
|
read_cached_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connector(config: dict | None) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(id=42, config=config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_config_is_none() -> None:
    """A connector whose ``config`` is ``None`` yields no cached tools."""
    connector = _make_connector(None)
    assert read_cached_tools(connector) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_cached_tools_missing() -> None:
    """A config without a ``cached_tools`` key yields no cached tools."""
    connector = _make_connector({"server_config": {}})
    assert read_cached_tools(connector) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
    """Non-dict ``cached_tools`` values (a list, a string) are rejected."""
    for bad_value in ([], "stale"):
        connector = _make_connector({"cached_tools": bad_value})
        assert read_cached_tools(connector) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_parses_minimal_valid_payload() -> None:
    """A payload with only ``discovered_at`` and ``tools`` parses; metadata is ``None``."""
    tool_entry = {
        "name": "list_issues",
        "description": "List Linear issues",
        "input_schema": {"type": "object"},
    }
    stored = {
        "discovered_at": "2026-05-20T10:00:00+00:00",
        "tools": [tool_entry],
    }

    result = read_cached_tools(_make_connector({"cached_tools": stored}))

    assert result is not None
    assert result.server_version is None
    assert result.server_name is None
    assert result.transport is None
    assert len(result.tools) == 1
    assert result.tools[0].name == "list_issues"
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_parses_full_payload_with_serverinfo() -> None:
    """Server metadata and tool order survive parsing of a fully populated payload."""
    stored = {
        "discovered_at": "2026-05-20T10:00:00+00:00",
        "server_version": "1.2.3",
        "server_name": "atlassian-mcp",
        "transport": "streamable-http",
        "tools": [
            {"name": "create_issue", "input_schema": {}},
            {"name": "list_issues", "input_schema": {}},
        ],
    }

    result = read_cached_tools(_make_connector({"cached_tools": stored}))

    assert result is not None
    assert result.server_version == "1.2.3"
    assert result.server_name == "atlassian-mcp"
    assert result.transport == "streamable-http"
    assert [tool.name for tool in result.tools] == ["create_issue", "list_issues"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_for_corrupt_payload(caplog) -> None:
    """A payload that fails validation returns ``None`` and logs a warning."""
    broken = {
        "discovered_at": "not-a-date",
        "tools": "should-be-a-list",
    }

    result = read_cached_tools(_make_connector({"cached_tools": broken}))

    assert result is None
    messages = [record.getMessage() for record in caplog.records]
    assert any("corrupt cached_tools" in message for message in messages)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_tools_missing() -> None:
    """A payload lacking the required ``tools`` key is rejected."""
    stored = {"discovered_at": "2026-05-20T10:00:00+00:00"}
    result = read_cached_tools(_make_connector({"cached_tools": stored}))
    assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_def_defaults_description_and_schema() -> None:
    """A tool definition with only a name gets an empty description and schema."""
    tool = CachedMCPToolDef.model_validate({"name": "ping"})
    assert tool.description == ""
    assert tool.input_schema == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_dump_json_mode_is_round_trippable() -> None:
    """``model_dump(mode="json")`` output validates back into an equal snapshot."""
    snapshot = CachedMCPTools(
        discovered_at=datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC),
        server_version="1.2.3",
        server_name="atlassian-mcp",
        transport="streamable-http",
        tools=[CachedMCPToolDef(name="list_issues")],
    )

    dumped = snapshot.model_dump(mode="json")
    assert dumped["discovered_at"] == "2026-05-20T10:00:00Z"
    assert dumped["tools"][0]["name"] == "list_issues"

    restored = CachedMCPTools.model_validate(dumped)
    assert restored.discovered_at == snapshot.discovered_at
    assert restored.tools[0].name == "list_issues"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue