mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-27 19:25:15 +02:00
refactor(mcp): per-connector cache refresh on lifecycle events
Collapse the invalidate + warmup pair into a single refresh_mcp_tools_cache_for_connector(connector_id, search_space_id) helper and scope live discovery to the one connector that changed instead of the whole search space.

- new mcp_tool.discover_single_mcp_connector: load one connector, refresh OAuth if needed, force live MCP discovery so its cached_tools row is rewritten; returned wrappers are discarded since the in-process LRU is rebuilt lazily on the next user query
- mcp_tools_cache.refresh_mcp_tools_cache_for_connector: synchronously evicts the per-space LRU (LRU keys cannot scope finer) and schedules the per-connector prefetch via loop.create_task
- routes (OAuth callback, MCP POST, MCP PUT) collapse their two back-to-back calls into a single refresh call; DELETE handlers keep using bare invalidate_mcp_tools_cache (nothing to prefetch)

No new automated tests: the new functions are I/O glue (DB + network) where mocked unit tests would test implementation rather than behavior. The existing 9 unit tests for the cached_tools data shape are unchanged.
This commit is contained in:
parent
c0aa4261ac
commit
704d1bf18f
4 changed files with 161 additions and 11 deletions
|
|
@ -1007,6 +1007,94 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
_mcp_tools_cache.clear()
|
_mcp_tools_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
async def discover_single_mcp_connector(connector_id: int) -> None:
    """Run one live MCP discovery pass for a single connector.

    The connector is loaded by primary key and skipped unless its
    ``server_config`` is a non-empty dict using an HTTP-style transport
    (``streamable-http``, ``http`` or ``sse``); stdio connectors are not
    cached and therefore ignored. When the connector config carries an
    ``mcp_oauth`` section the token is refreshed and injected into the
    request headers first.

    ``_load_http_mcp_tools`` is then called with ``cached_tools=None`` so the
    lookup goes to the network; it persists ``cached_tools`` as a side effect
    of live discovery. The wrappers it returns are dropped on purpose: the
    in-process LRU is rebuilt lazily on the next user query.

    Args:
        connector_id: Primary key of the ``SearchSourceConnector`` to refresh.

    Returns:
        None. A timeout or any other ``Exception`` is logged at warning level
        and swallowed, so callers can treat this as fire-and-forget.
    """
    from app.db import async_session_maker

    t0 = time.perf_counter()
    http_transports = ("streamable-http", "http", "sse")
    try:
        async with async_session_maker() as session:
            connector = await session.get(SearchSourceConnector, connector_id)
            if connector is None:
                logger.info(
                    "discover_single_mcp_connector: connector %d not found",
                    connector_id,
                )
                return

            config = connector.config or {}
            server_config = config.get("server_config", {})
            # Nothing to discover without a usable server_config mapping.
            if not (server_config and isinstance(server_config, dict)):
                return

            # A missing transport is treated as stdio, which is never cached.
            if server_config.get("transport", "stdio") not in http_transports:
                return

            if config.get("mcp_oauth"):
                server_config = await _maybe_refresh_mcp_oauth_token(
                    session, connector, config, server_config
                )
                # Re-read the config after the refresh step before building
                # the auth headers from it.
                config = connector.config or {}
                server_config = _inject_oauth_headers(config, server_config)
                if server_config is None:
                    logger.info(
                        "discover_single_mcp_connector: OAuth token unavailable for connector %d",
                        connector_id,
                    )
                    return

            # connector_type may be an enum member or an already-plain value.
            raw_type = connector.connector_type
            type_name = raw_type.value if hasattr(raw_type, "value") else str(raw_type)

            service = get_service_by_connector_type(type_name)
            if service:
                allowed, readonly = service.allowed_tools, service.readonly_tools
            else:
                allowed, readonly = [], frozenset()

            discovery = _load_http_mcp_tools(
                connector.id,
                connector.name,
                server_config,
                trusted_tools=config.get("trusted_tools", []),
                allowed_tools=allowed,
                readonly_tools=readonly,
                tool_name_prefix=None,
                # No registered service definition means a generic MCP server.
                is_generic_mcp=service is None,
                bypass_internal_hitl=True,
                # None forces live discovery instead of reusing a cached list.
                cached_tools=None,
            )
            await asyncio.wait_for(
                discovery,
                timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
            )

            _perf_log.info(
                "[mcp_prefetch] connector=%s elapsed=%.3fs",
                connector_id,
                time.perf_counter() - t0,
            )
    except TimeoutError:
        logger.warning(
            "discover_single_mcp_connector: connector %d timed out after %ds",
            connector_id,
            _MCP_DISCOVERY_TIMEOUT_SECONDS,
        )
    except Exception:
        logger.warning(
            "discover_single_mcp_connector: failed for connector %d",
            connector_id,
            exc_info=True,
        )
|
||||||
|
|
||||||
|
|
||||||
async def load_mcp_tools(
|
async def load_mcp_tools(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -14,6 +15,8 @@ from app.db import SearchSourceConnector, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
|
|
||||||
class CachedMCPToolDef(BaseModel):
|
class CachedMCPToolDef(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -92,3 +95,51 @@ async def write_cached_tools(
|
||||||
connector_id,
|
connector_id,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_mcp_tools_cache_for_connector(
    connector_id: int,
    search_space_id: int,
) -> None:
    """Refresh MCP tool caches after one connector was created or changed.

    Two steps, both best-effort:

    1. Synchronously evict the in-process LRU for ``search_space_id``. The
       LRU is keyed per search space, so the whole space is evicted even
       though only one connector changed.
    2. Schedule a background task that performs live discovery for
       ``connector_id`` only, so its persisted ``cached_tools`` row is
       rewritten ahead of the next user query.

    Args:
        connector_id: Connector whose persisted tool list should be refetched.
        search_space_id: Search space whose in-process cache entry is evicted.

    Returns:
        None. Eviction errors are logged at debug level and swallowed; when no
        event loop is running in the current thread the prefetch step is
        skipped silently.
    """
    # Step 1: eviction. Imported lazily; any failure here (including the
    # import itself) must not reach the caller.
    try:
        from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache

        invalidate_mcp_tools_cache(search_space_id)
    except Exception:
        logger.debug(
            "MCP in-process cache eviction skipped for space %d",
            search_space_id,
            exc_info=True,
        )

    # Step 2: prefetch. Only possible from inside a running event loop.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        return

    prefetch = running_loop.create_task(_run_connector_prefetch(connector_id))
    # Hold a strong reference until the task finishes so it is not
    # garbage-collected mid-flight; the callback drops it again afterwards.
    _pending_prefetch_tasks.add(prefetch)
    prefetch.add_done_callback(_pending_prefetch_tasks.discard)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_connector_prefetch(connector_id: int) -> None:
    """Background-task body: run live discovery for one connector.

    Scheduled via ``loop.create_task`` by
    ``refresh_mcp_tools_cache_for_connector``. Nobody awaits the resulting
    task, so every ``Exception`` is logged here instead of being left on the
    task object.

    Args:
        connector_id: Connector to pass to ``discover_single_mcp_connector``.

    Returns:
        None. Failures are logged at warning level and swallowed.
    """
    try:
        # Lazy import kept inside the guarded region: if it fails, the error
        # is logged with context here rather than surfacing only as asyncio's
        # "Task exception was never retrieved".
        from app.agents.new_chat.tools.mcp_tool import discover_single_mcp_connector

        await discover_single_mcp_connector(connector_id)
    except Exception:
        logger.warning(
            "MCP background prefetch failed for connector_id=%d",
            connector_id,
            exc_info=True,
        )
|
||||||
|
|
|
||||||
|
|
@ -428,7 +428,7 @@ async def mcp_oauth_callback(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(db_connector)
|
await session.refresh(db_connector)
|
||||||
|
|
||||||
_invalidate_cache(space_id)
|
_refresh_mcp_cache(db_connector.id, space_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Re-authenticated %s MCP connector %s for user %s",
|
"Re-authenticated %s MCP connector %s for user %s",
|
||||||
|
|
@ -481,7 +481,7 @@ async def mcp_oauth_callback(
|
||||||
detail="A connector for this service already exists.",
|
detail="A connector for this service already exists.",
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
_invalidate_cache(space_id)
|
_refresh_mcp_cache(new_connector.id, space_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Created %s MCP connector %s for user %s in space %s",
|
"Created %s MCP connector %s for user %s in space %s",
|
||||||
|
|
@ -658,10 +658,17 @@ async def reauth_mcp_service(
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _invalidate_cache(space_id: int) -> None:
|
def _refresh_mcp_cache(connector_id: int, space_id: int) -> None:
|
||||||
try:
|
"""Evict the in-process MCP tool LRU and schedule background prefetch.
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(space_id)
|
Wraps :func:`refresh_mcp_tools_cache_for_connector` so any failure is
|
||||||
|
isolated from the OAuth response flow.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_mcp_tools_cache_for_connector(connector_id, space_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("MCP cache invalidation skipped", exc_info=True)
|
logger.debug("MCP cache refresh skipped", exc_info=True)
|
||||||
|
|
|
||||||
|
|
@ -2650,9 +2650,11 @@ async def create_mcp_connector(
|
||||||
f"for user {user.id} in search space {search_space_id}"
|
f"for user {user.id} in search space {search_space_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(search_space_id)
|
refresh_mcp_tools_cache_for_connector(db_connector.id, search_space_id)
|
||||||
|
|
||||||
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||||
return MCPConnectorRead.from_connector(connector_read)
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
@ -2828,9 +2830,11 @@ async def update_mcp_connector(
|
||||||
|
|
||||||
logger.info(f"Updated MCP connector {connector_id}")
|
logger.info(f"Updated MCP connector {connector_id}")
|
||||||
|
|
||||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
from app.agents.new_chat.tools.mcp_tools_cache import (
|
||||||
|
refresh_mcp_tools_cache_for_connector,
|
||||||
|
)
|
||||||
|
|
||||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
refresh_mcp_tools_cache_for_connector(connector.id, connector.search_space_id)
|
||||||
|
|
||||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||||
return MCPConnectorRead.from_connector(connector_read)
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue