mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
refactor(mcp): per-connector cache refresh on lifecycle events
Collapse the invalidate + warmup pair into a single refresh_mcp_tools_cache_for_connector(connector_id, search_space_id) helper and scope live discovery to the one connector that changed instead of the whole search space.

- new mcp_tool.discover_single_mcp_connector: load one connector, refresh OAuth if needed, force live MCP discovery so its cached_tools row is rewritten; returned wrappers are discarded since the in-process LRU is rebuilt lazily on the next user query
- mcp_tools_cache.refresh_mcp_tools_cache_for_connector: synchronously evicts the per-space LRU (LRU keys cannot scope finer) and schedules the per-connector prefetch via loop.create_task
- routes (OAuth callback, MCP POST, MCP PUT) collapse their two back-to-back calls into a single refresh call; DELETE handlers keep using bare invalidate_mcp_tools_cache (nothing to prefetch)

No new automated tests: the new functions are I/O glue (DB + network) where mocked unit tests would test implementation rather than behavior. The existing 9 unit tests for the cached_tools data shape are unchanged.
This commit is contained in:
parent
c0aa4261ac
commit
704d1bf18f
4 changed files with 161 additions and 11 deletions
|
|
@ -1007,6 +1007,94 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
|||
_mcp_tools_cache.clear()
|
||||
|
||||
|
||||
async def discover_single_mcp_connector(connector_id: int) -> None:
    """Re-run live tool discovery for one HTTP-style MCP connector.

    The connector is loaded in a fresh session; when its config carries
    ``mcp_oauth`` the token is refreshed and the auth headers are injected
    before discovery. ``_load_http_mcp_tools`` is then called with
    ``cached_tools=None`` so it has to go to the network, and that loader
    persists ``cached_tools`` as a side effect of any live discovery. The
    wrappers it returns are dropped on purpose: the in-process LRU is rebuilt
    lazily on the next user query.

    The call is a no-op when the connector does not exist, has no usable
    ``server_config`` dict, uses a transport other than ``streamable-http`` /
    ``http`` / ``sse`` (stdio connectors are not cached), or when no OAuth
    token is available.

    Args:
        connector_id: Primary key of the ``SearchSourceConnector`` to refresh.

    Returns:
        None. A timeout or any other ``Exception`` is logged as a warning and
        swallowed instead of being propagated to the caller.
    """
    from app.db import async_session_maker

    t0 = time.perf_counter()
    http_transports = ("streamable-http", "http", "sse")

    try:
        async with async_session_maker() as db:
            row = await db.get(SearchSourceConnector, connector_id)
            if row is None:
                logger.info(
                    "discover_single_mcp_connector: connector %d not found",
                    connector_id,
                )
                return

            config = row.config or {}
            server_cfg = config.get("server_config", {})

            # Guard clauses: need a non-empty dict describing an HTTP-style
            # transport; a missing transport defaults to stdio and is skipped.
            if not (server_cfg and isinstance(server_cfg, dict)):
                return
            if server_cfg.get("transport", "stdio") not in http_transports:
                return

            if config.get("mcp_oauth"):
                server_cfg = await _maybe_refresh_mcp_oauth_token(
                    db, row, config, server_cfg
                )
                # Re-read the config after the refresh step before building
                # the auth headers from it.
                config = row.config or {}
                server_cfg = _inject_oauth_headers(config, server_cfg)
                if server_cfg is None:
                    logger.info(
                        "discover_single_mcp_connector: OAuth token unavailable for connector %d",
                        connector_id,
                    )
                    return

            # connector_type may be an enum member or a plain value.
            raw_type = row.connector_type
            if hasattr(raw_type, "value"):
                type_key = raw_type.value
            else:
                type_key = str(raw_type)

            service = get_service_by_connector_type(type_key)
            if service:
                allowed, readonly = service.allowed_tools, service.readonly_tools
            else:
                allowed, readonly = [], frozenset()

            discovery = _load_http_mcp_tools(
                row.id,
                row.name,
                server_cfg,
                trusted_tools=config.get("trusted_tools", []),
                allowed_tools=allowed,
                readonly_tools=readonly,
                tool_name_prefix=None,
                is_generic_mcp=service is None,
                bypass_internal_hitl=True,
                cached_tools=None,  # None forces a live (network) discovery
            )
            # Result intentionally ignored; only the persisted side effect matters.
            await asyncio.wait_for(
                discovery, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS
            )

        _perf_log.info(
            "[mcp_prefetch] connector=%s elapsed=%.3fs",
            connector_id,
            time.perf_counter() - t0,
        )
    except TimeoutError:
        logger.warning(
            "discover_single_mcp_connector: connector %d timed out after %ds",
            connector_id,
            _MCP_DISCOVERY_TIMEOUT_SECONDS,
        )
    except Exception:
        logger.warning(
            "discover_single_mcp_connector: failed for connector %d",
            connector_id,
            exc_info=True,
        )
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
|
@ -14,6 +15,8 @@ from app.db import SearchSourceConnector, async_session_maker
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
|
||||
class CachedMCPToolDef(BaseModel):
|
||||
name: str
|
||||
|
|
@ -92,3 +95,51 @@ async def write_cached_tools(
|
|||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def refresh_mcp_tools_cache_for_connector(
    connector_id: int,
    search_space_id: int,
) -> None:
    """Bring the MCP tool cache up to date after one connector changed.

    Two steps, both best-effort:

    1. Evict the in-process LRU for ``search_space_id`` right away. LRU keys
       are per-space, so the eviction cannot be narrowed to one connector.
    2. If an event loop is running, schedule a background live discovery for
       ``connector_id`` only, so its persisted ``cached_tools`` row is
       refreshed ahead of the next user query. Without a running loop this
       step is skipped.

    Args:
        connector_id: Connector whose tools should be rediscovered.
        search_space_id: Search space whose in-process cache is evicted.

    Returns:
        None. A failed eviction is logged at debug level and otherwise ignored.
    """
    try:
        # Imported here, at call time, rather than at module top level.
        from app.agents.new_chat.tools.mcp_tool import (
            invalidate_mcp_tools_cache as _evict_space_cache,
        )

        _evict_space_cache(search_space_id)
    except Exception:
        logger.debug(
            "MCP in-process cache eviction skipped for space %d",
            search_space_id,
            exc_info=True,
        )

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Called outside an event loop: there is nowhere to run the prefetch.
        pass
    else:
        prefetch = running_loop.create_task(_run_connector_prefetch(connector_id))
        # Keep a reference until the task completes; it removes itself from
        # the set through the done callback.
        _pending_prefetch_tasks.add(prefetch)
        prefetch.add_done_callback(_pending_prefetch_tasks.discard)
|
||||
|
||||
|
||||
async def _run_connector_prefetch(connector_id: int) -> None:
    """Run live discovery for one connector as a background task body.

    Delegates to ``discover_single_mcp_connector``; any ``Exception`` raised
    by that call is logged as a warning with its traceback and not re-raised.

    Args:
        connector_id: Connector to rediscover.
    """
    # Imported at call time rather than at module top level.
    from app.agents.new_chat.tools.mcp_tool import (
        discover_single_mcp_connector as _discover,
    )

    try:
        await _discover(connector_id)
    except Exception:
        logger.warning(
            "MCP background prefetch failed for connector_id=%d",
            connector_id,
            exc_info=True,
        )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue