diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
index b3c26f331..3d4679fb8 100644
--- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
+++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
@@ -36,6 +36,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
 from app.agents.new_chat.tools.hitl import request_approval
 from app.agents.new_chat.tools.mcp_client import MCPClient
+from app.agents.new_chat.tools.mcp_tools_cache import (
+    CachedMCPTools,
+    read_cached_tools,
+    write_cached_tools,
+)
 from app.db import SearchSourceConnector
 from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
 from app.utils.perf import get_perf_logger
@@ -516,6 +521,7 @@ async def _load_http_mcp_tools(
     is_generic_mcp: bool = False,
     *,
     bypass_internal_hitl: bool = False,
+    cached_tools: CachedMCPTools | None = None,
 ) -> list[StructuredTool]:
     """Load tools from an HTTP-based MCP server.
 
@@ -526,6 +532,8 @@ async def _load_http_mcp_tools(
         readonly_tools: Tool names that skip HITL approval (read-only operations).
         tool_name_prefix: If set, each tool name is prefixed for multi-account
             disambiguation (e.g. ``linear_25``).
+        cached_tools: If provided, skip live discovery and rebuild wrappers
+            from the persisted definitions.
     """
     tools: list[StructuredTool] = []
 
@@ -549,15 +557,23 @@ async def _load_http_mcp_tools(
 
     allowed_set = set(allowed_tools) if allowed_tools else None
 
-    async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
-        """Connect, initialize, and list tools from the MCP server."""
+    async def _discover(
+        disc_headers: dict[str, str],
+    ) -> tuple[dict[str, str | None], list[dict[str, Any]]]:
+        """Connect, initialize, and list tools — returns (serverInfo, tools)."""
         async with (
             streamablehttp_client(url, headers=disc_headers) as (read, write, _),
             ClientSession(read, write) as session,
         ):
-            await session.initialize()
+            init_result = await session.initialize()
+            server_info: dict[str, str | None] = {"name": None, "version": None}
+            si = getattr(init_result, "serverInfo", None)
+            if si is not None:
+                server_info["name"] = getattr(si, "name", None)
+                server_info["version"] = getattr(si, "version", None)
+
             response = await session.list_tools()
-            return [
+            return server_info, [
                 {
                     "name": tool.name,
                     "description": tool.description or "",
@@ -568,47 +584,65 @@ async def _load_http_mcp_tools(
                 for tool in response.tools
             ]
 
-    try:
-        tool_definitions = await _discover(headers)
-    except Exception as first_err:
-        if not _is_auth_error(first_err) or connector_id is None:
-            logger.exception(
-                "Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
-                url,
-                connector_id,
-                first_err,
-            )
-            return tools
-
-        logger.warning(
-            "HTTP MCP discovery for connector %d got 401 — attempting token refresh",
-            connector_id,
-        )
-        fresh_headers = await _force_refresh_and_get_headers(connector_id)
-        if fresh_headers is None:
-            await _mark_connector_auth_expired(connector_id)
-            logger.error(
-                "HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
-                connector_id,
-            )
-            return tools
-
+    if cached_tools is not None:
+        tool_definitions = [
+            {
+                "name": td.name,
+                "description": td.description,
+                "input_schema": td.input_schema,
+            }
+            for td in cached_tools.tools
+        ]
+    else:
         try:
-            tool_definitions = await _discover(fresh_headers)
-            headers = fresh_headers
-            logger.info(
-                "HTTP MCP discovery for connector %d succeeded after 401 recovery",
+            server_info, tool_definitions = await _discover(headers)
+        except Exception as first_err:
+            if not _is_auth_error(first_err) or connector_id is None:
+                logger.exception(
+                    "Failed to connect to HTTP MCP server at '%s' (connector %s): %s",
+                    url,
+                    connector_id,
+                    first_err,
+                )
+                return tools
+
+            logger.warning(
+                "HTTP MCP discovery for connector %d got 401 — attempting token refresh",
                 connector_id,
             )
-        except Exception as retry_err:
-            logger.exception(
-                "HTTP MCP discovery for connector %d still failing after refresh: %s",
-                connector_id,
-                retry_err,
-            )
-            if _is_auth_error(retry_err):
+            fresh_headers = await _force_refresh_and_get_headers(connector_id)
+            if fresh_headers is None:
                 await _mark_connector_auth_expired(connector_id)
-            return tools
+                logger.error(
+                    "HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
+                    connector_id,
+                )
+                return tools
+
+            try:
+                server_info, tool_definitions = await _discover(fresh_headers)
+                headers = fresh_headers
+                logger.info(
+                    "HTTP MCP discovery for connector %d succeeded after 401 recovery",
+                    connector_id,
+                )
+            except Exception as retry_err:
+                logger.exception(
+                    "HTTP MCP discovery for connector %d still failing after refresh: %s",
+                    connector_id,
+                    retry_err,
+                )
+                if _is_auth_error(retry_err):
+                    await _mark_connector_auth_expired(connector_id)
+                return tools
+
+        await write_cached_tools(
+            connector_id,
+            tool_definitions,
+            server_name=server_info.get("name"),
+            server_version=server_info.get("version"),
+            transport=server_config.get("transport", "streamable-http"),
+        )
 
     total_discovered = len(tool_definitions)
 
@@ -1099,6 +1133,7 @@ async def load_mcp_tools(
                         "tool_name_prefix": tool_name_prefix,
                         "transport": server_config.get("transport", "stdio"),
                         "is_generic_mcp": svc_cfg is None,
+                        "cached_tools": read_cached_tools(connector),
                     }
                 )
 
@@ -1112,6 +1147,7 @@ async def load_mcp_tools(
     async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]:
         discover_start = time.perf_counter()
         transport = task["transport"]
+        cached_tools = task.get("cached_tools")
         try:
             if transport in ("streamable-http", "http", "sse"):
                 result = await asyncio.wait_for(
@@ -1125,6 +1161,7 @@ async def load_mcp_tools(
                         tool_name_prefix=task["tool_name_prefix"],
                         is_generic_mcp=task.get("is_generic_mcp", False),
                         bypass_internal_hitl=bypass_internal_hitl,
+                        cached_tools=cached_tools,
                     ),
                     timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
                 )
@@ -1140,12 +1177,13 @@ async def load_mcp_tools(
                     timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
                 )
             _perf_log.info(
-                "[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs",
+                "[mcp_discover] connector=%s name=%r transport=%s tools=%d elapsed=%.3fs cache=%s",
                 task["connector_id"],
                 task["connector_name"],
                 transport,
                 len(result),
                 time.perf_counter() - discover_start,
+                "hit" if cached_tools is not None else "miss",
             )
             return result
         except TimeoutError:
diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py
new file mode 100644
index 000000000..3c79ed1d3
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tools_cache.py
@@ -0,0 +1,108 @@
+"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
+
+from __future__ import annotations
+
+import logging
+from datetime import UTC, datetime
+from typing import Any
+
+from pydantic import BaseModel, Field, ValidationError
+from sqlalchemy import select
+from sqlalchemy.orm.attributes import flag_modified
+
+from app.db import SearchSourceConnector, async_session_maker
+
+logger = logging.getLogger(__name__)
+
+
+class CachedMCPToolDef(BaseModel):
+    """One persisted tool definition (name, description, input schema)."""
+
+    name: str
+    description: str = ""
+    input_schema: dict[str, Any] = Field(default_factory=dict)
+
+
+class CachedMCPTools(BaseModel):
+    """Snapshot of a connector's discovered tools plus discovery metadata."""
+
+    discovered_at: datetime
+    server_version: str | None = None
+    server_name: str | None = None
+    transport: str | None = None
+    tools: list[CachedMCPToolDef]
+
+
+def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
+    """Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
+    cfg = connector.config or {}
+    raw = cfg.get("cached_tools")
+    if not raw or not isinstance(raw, dict):
+        return None
+
+    try:
+        return CachedMCPTools.model_validate(raw)
+    except ValidationError as exc:
+        logger.warning(
+            "MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
+            connector.id,
+            exc,
+        )
+        return None
+
+
+async def write_cached_tools(
+    connector_id: int | None,
+    tool_definitions: list[dict[str, Any]],
+    *,
+    server_name: str | None = None,
+    server_version: str | None = None,
+    transport: str | None = None,
+) -> None:
+    """Best-effort persist of freshly discovered tool definitions.
+
+    Uses its own session so a write failure cannot poison the caller's
+    transaction. Never raises: payload validation runs inside the guarded
+    block, so a malformed tool definition only costs the cache entry instead
+    of aborting the discovery that just succeeded.
+    """
+    if connector_id is None:
+        # No connector row to attach the cache to.
+        return
+
+    try:
+        payload = CachedMCPTools(
+            discovered_at=datetime.now(UTC),
+            server_version=server_version,
+            server_name=server_name,
+            transport=transport,
+            tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
+        )
+
+        async with async_session_maker() as session:
+            result = await session.execute(
+                select(SearchSourceConnector).filter(
+                    SearchSourceConnector.id == connector_id,
+                )
+            )
+            connector = result.scalars().first()
+            if connector is None:
+                return
+
+            cfg = dict(connector.config or {})
+            cfg["cached_tools"] = payload.model_dump(mode="json")
+            connector.config = cfg
+            flag_modified(connector, "config")
+            await session.commit()
+
+        logger.info(
+            "Persisted cached_tools for MCP connector %d (%d tools)",
+            connector_id,
+            len(payload.tools),
+        )
+    except Exception:
+        logger.warning(
+            "Failed to persist cached_tools for MCP connector %d",
+            connector_id,
+            exc_info=True,
+        )
diff --git a/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py b/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py
new file mode 100644
index 000000000..bae97ba9f
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/new_chat/tools/test_mcp_tools_cache.py
@@ -0,0 +1,143 @@
+"""Unit tests for ``mcp_tools_cache``."""
+
+from __future__ import annotations
+
+import asyncio
+from datetime import UTC, datetime
+from types import SimpleNamespace
+
+import pytest
+
+from app.agents.new_chat.tools.mcp_tools_cache import (
+    CachedMCPToolDef,
+    CachedMCPTools,
+    read_cached_tools,
+    write_cached_tools,
+)
+
+pytestmark = pytest.mark.unit
+
+
+def _make_connector(config: dict | None) -> SimpleNamespace:
+    return SimpleNamespace(id=42, config=config)
+
+
+def test_read_returns_none_when_config_is_none() -> None:
+    assert read_cached_tools(_make_connector(None)) is None
+
+
+def test_read_returns_none_when_cached_tools_missing() -> None:
+    assert read_cached_tools(_make_connector({"server_config": {}})) is None
+
+
+def test_read_returns_none_when_cached_tools_is_not_a_dict() -> None:
+    assert read_cached_tools(_make_connector({"cached_tools": []})) is None
+    assert read_cached_tools(_make_connector({"cached_tools": "stale"})) is None
+
+
+def test_read_parses_minimal_valid_payload() -> None:
+    parsed = read_cached_tools(
+        _make_connector(
+            {
+                "cached_tools": {
+                    "discovered_at": "2026-05-20T10:00:00+00:00",
+                    "tools": [
+                        {
+                            "name": "list_issues",
+                            "description": "List Linear issues",
+                            "input_schema": {"type": "object"},
+                        }
+                    ],
+                }
+            }
+        )
+    )
+    assert parsed is not None
+    assert parsed.server_version is None
+    assert parsed.server_name is None
+    assert parsed.transport is None
+    assert len(parsed.tools) == 1
+    assert parsed.tools[0].name == "list_issues"
+
+
+def test_read_parses_full_payload_with_serverinfo() -> None:
+    parsed = read_cached_tools(
+        _make_connector(
+            {
+                "cached_tools": {
+                    "discovered_at": "2026-05-20T10:00:00+00:00",
+                    "server_version": "1.2.3",
+                    "server_name": "atlassian-mcp",
+                    "transport": "streamable-http",
+                    "tools": [
+                        {"name": "create_issue", "input_schema": {}},
+                        {"name": "list_issues", "input_schema": {}},
+                    ],
+                }
+            }
+        )
+    )
+    assert parsed is not None
+    assert parsed.server_version == "1.2.3"
+    assert parsed.server_name == "atlassian-mcp"
+    assert parsed.transport == "streamable-http"
+    assert [t.name for t in parsed.tools] == ["create_issue", "list_issues"]
+
+
+def test_read_returns_none_for_corrupt_payload(caplog) -> None:
+    parsed = read_cached_tools(
+        _make_connector(
+            {
+                "cached_tools": {
+                    "discovered_at": "not-a-date",
+                    "tools": "should-be-a-list",
+                }
+            }
+        )
+    )
+    assert parsed is None
+    assert any("corrupt cached_tools" in r.getMessage() for r in caplog.records)
+
+
+def test_read_returns_none_when_tools_missing() -> None:
+    parsed = read_cached_tools(
+        _make_connector(
+            {"cached_tools": {"discovered_at": "2026-05-20T10:00:00+00:00"}}
+        )
+    )
+    assert parsed is None
+
+
+def test_tool_def_defaults_description_and_schema() -> None:
+    td = CachedMCPToolDef.model_validate({"name": "ping"})
+    assert td.description == ""
+    assert td.input_schema == {}
+
+
+def test_model_dump_json_mode_is_round_trippable() -> None:
+    original = CachedMCPTools(
+        discovered_at=datetime(2026, 5, 20, 10, 0, 0, tzinfo=UTC),
+        server_version="1.2.3",
+        server_name="atlassian-mcp",
+        transport="streamable-http",
+        tools=[CachedMCPToolDef(name="list_issues")],
+    )
+    payload = original.model_dump(mode="json")
+
+    assert payload["discovered_at"] == "2026-05-20T10:00:00Z"
+    assert payload["tools"][0]["name"] == "list_issues"
+
+    reparsed = CachedMCPTools.model_validate(payload)
+    assert reparsed.discovered_at == original.discovered_at
+    assert reparsed.tools[0].name == "list_issues"
+
+
+def test_write_is_noop_without_connector_id() -> None:
+    assert asyncio.run(write_cached_tools(None, [{"name": "ping"}])) is None
+
+
+def test_write_never_raises_on_invalid_tool_definitions(caplog) -> None:
+    asyncio.run(write_cached_tools(42, [{"description": "missing name"}]))
+    assert any(
+        "Failed to persist cached_tools" in r.getMessage() for r in caplog.records
+    )