diff --git a/surfsense_backend/scripts/probe_mcp_session_lifetime.py b/surfsense_backend/scripts/probe_mcp_session_lifetime.py new file mode 100644 index 000000000..66be5bc14 --- /dev/null +++ b/surfsense_backend/scripts/probe_mcp_session_lifetime.py @@ -0,0 +1,563 @@ +"""Probe MCP server session lifetime / staleness behavior — read-only. + +Goal +---- +Empirically answer two questions for our actual third-party MCP servers +(Atlassian, Linear, Slack, ClickUp, Airtable, ...): + +1. How expensive is the initial ``initialize`` handshake (``init=`` cost)? +2. How long can a ``ClientSession`` sit idle and still survive a + subsequent ``list_tools()`` call? + +This script informs the design choice between + +* per-call sessions (current, ~1s init tax per call), +* per-turn session reuse (LangChain-style, holds a session for the + duration of a chat turn), +* a long-lived session pool (IBM-style, sessions reused across turns). + +The probe is read-only: it only ever calls ``session.list_tools()``, +which is the safest MCP method. No tool calls against user data are +performed. + +Usage +----- +Run from the repo root or from ``surfsense_backend/``:: + + uv run python -m scripts.probe_mcp_session_lifetime + uv run python -m scripts.probe_mcp_session_lifetime --quick + uv run python -m scripts.probe_mcp_session_lifetime --connectors 7,19,20 + uv run python -m scripts.probe_mcp_session_lifetime --intervals 5,30,60,300 + +Output +------ +* Live progress to stderr (``[connector=7 t=+30s] OK 0.142s``). +* Final per-connector table to stdout. +* Raw results JSON to ``./mcp_session_probe_.json``. + +The default test reaches 1800s of idle (~30 min). Use ``--quick`` to +stop at 60s for fast iteration. All connectors probe concurrently so +total wall-clock time equals the longest interval, not the sum. +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import os +import sys +import time +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Any + +_HERE = os.path.dirname(os.path.abspath(__file__)) +_BACKEND_ROOT = os.path.dirname(_HERE) +if _BACKEND_ROOT not in sys.path: + sys.path.insert(0, _BACKEND_ROOT) + +import httpx # noqa: E402 +from mcp import ClientSession # noqa: E402 +from mcp.client.streamable_http import streamable_http_client # noqa: E402 +from sqlalchemy import cast, select # noqa: E402 +from sqlalchemy.dialects.postgresql import JSONB # noqa: E402 + +from app.agents.new_chat.tools.mcp_tool import ( # noqa: E402 + _inject_oauth_headers, + _maybe_refresh_mcp_oauth_token, +) +from app.db import SearchSourceConnector, async_session_maker # noqa: E402 + +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", + stream=sys.stderr, +) +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("mcp").setLevel(logging.ERROR) +logger = logging.getLogger("mcp_probe") +logger.setLevel(logging.INFO) + + +DEFAULT_INTERVALS_SECONDS = [5, 30, 60, 300, 900, 1800] +QUICK_INTERVALS_SECONDS = [5, 30, 60] +PER_CALL_TIMEOUT_SECONDS = 60.0 + + +@dataclass +class CheckpointResult: + """One ``list_tools()`` call against a long-lived session.""" + + idle_seconds_target: int + elapsed_since_open_seconds: float + elapsed_since_last_call_seconds: float + success: bool + latency_seconds: float | None + tools_returned: int | None + error_type: str | None + error_message: str | None + + +@dataclass +class ConnectorProbeResult: + """Per-connector aggregated probe outcome.""" + + connector_id: int + connector_name: str + connector_type: str + url: str + init_latency_seconds: float | None + first_call_latency_seconds: float | None + checkpoints: list[CheckpointResult] = field(default_factory=list) + fatal_error: str | None = None + + +# --------------------------------------------------------------------------- +# Connector loading + auth +# --------------------------------------------------------------------------- + + +async def _fetch_connectors( + connector_ids: list[int] | None, +) -> list[SearchSourceConnector]: + """Pull every MCP-shaped connector (or only the requested IDs).""" + async with async_session_maker() as session: + stmt = select(SearchSourceConnector).filter( + cast(SearchSourceConnector.config, JSONB).has_key("server_config"), + ) + if connector_ids: + stmt = stmt.filter(SearchSourceConnector.id.in_(connector_ids)) + result = await session.execute(stmt) + connectors = list(result.scalars()) + + if connector_ids: + found_ids = {c.id for c in connectors} + missing = [cid for cid in connector_ids if cid not in found_ids] + if missing: + logger.warning("Requested connector IDs not found: %s", missing) + return connectors + + +async def _resolve_authed_server_config( + connector: SearchSourceConnector, +) -> dict[str, Any] | None: + """Refresh OAuth (if needed) and return a server_config with auth headers. + + Returns ``None`` if the connector cannot be probed (missing url, + decrypt failure, no refresh token, etc.). + """ + cfg = connector.config or {} + server_config = cfg.get("server_config", {}) + if not isinstance(server_config, dict): + return None + + if cfg.get("mcp_oauth"): + async with async_session_maker() as session: + attached = await session.get(SearchSourceConnector, connector.id) + if attached is None: + return None + refreshed = await _maybe_refresh_mcp_oauth_token( + session, + attached, + attached.config or {}, + server_config, + ) + attached_cfg = attached.config or {} + server_config = _inject_oauth_headers(attached_cfg, refreshed) + if server_config is None: + return None + return server_config + + +# --------------------------------------------------------------------------- +# The actual probe +# --------------------------------------------------------------------------- + + +def _classify_error(exc: BaseException) -> tuple[str, str]: + """Return ``(short_label, human_message)`` for a failed call.""" + name = type(exc).__name__ + msg = str(exc) or repr(exc) + if isinstance(exc, asyncio.TimeoutError): + return "timeout", f"call exceeded {PER_CALL_TIMEOUT_SECONDS}s" + if "404" in msg or "Not Found" in msg or "session" in msg.lower(): + return "session_expired", msg + if "401" in msg or "Unauthorized" in msg: + return "auth_401", msg + if "ClosedResourceError" in name or "Closed" in name: + return "stream_closed", msg + if "Connection" in name or "ConnectError" in name: + return "connection_error", msg + return name, msg + + +async def _probe_one_connector( + connector: SearchSourceConnector, + intervals: list[int], +) -> ConnectorProbeResult: + """Open a single long-lived session, call ``list_tools`` at each interval.""" + connector_type = ( + connector.connector_type.value + if hasattr(connector.connector_type, "value") + else str(connector.connector_type) + ) + server_config = await _resolve_authed_server_config(connector) + if server_config is None: + return ConnectorProbeResult( + connector_id=connector.id, + connector_name=connector.name, + connector_type=connector_type, + url="(unresolved)", + init_latency_seconds=None, + first_call_latency_seconds=None, + fatal_error="failed_to_resolve_server_config", + ) + + url = server_config.get("url") + headers = server_config.get("headers", {}) + if not url: + return ConnectorProbeResult( + connector_id=connector.id, + connector_name=connector.name, + connector_type=connector_type, + url="(missing)", + init_latency_seconds=None, + first_call_latency_seconds=None, + fatal_error="missing_url", + ) + + transport = server_config.get("transport", "streamable-http") + if transport not in ("streamable-http", "http", "sse"): + return ConnectorProbeResult( + connector_id=connector.id, + connector_name=connector.name, + connector_type=connector_type, + url=url, + init_latency_seconds=None, + first_call_latency_seconds=None, + fatal_error=f"unsupported_transport:{transport}", + ) + + result = ConnectorProbeResult( + connector_id=connector.id, + connector_name=connector.name, + connector_type=connector_type, + url=url, + init_latency_seconds=None, + first_call_latency_seconds=None, + ) + + open_started = time.perf_counter() + last_call_at: float | None = None + + # Manually drive the context-manager protocol so the session lives + # across our sleep intervals. ``streamable_http_client`` spawns a + # background task for the SSE receive loop; ``ClientSession`` spawns + # another for request multiplexing. We must close them in reverse order. + http_client = httpx.AsyncClient(headers=headers, timeout=PER_CALL_TIMEOUT_SECONDS) + transport_cm = None + session_cm = None + session = None + try: + transport_cm = streamable_http_client(url, http_client=http_client) + read, write, _ = await transport_cm.__aenter__() + session_cm = ClientSession(read, write) + session = await session_cm.__aenter__() + + init_start = time.perf_counter() + await asyncio.wait_for(session.initialize(), timeout=PER_CALL_TIMEOUT_SECONDS) + result.init_latency_seconds = time.perf_counter() - init_start + logger.info( + "[connector=%s name=%r] init=%.3fs", + connector.id, + connector.name, + result.init_latency_seconds, + ) + + first_call_start = time.perf_counter() + first_response = await asyncio.wait_for( + session.list_tools(), timeout=PER_CALL_TIMEOUT_SECONDS + ) + result.first_call_latency_seconds = time.perf_counter() - first_call_start + last_call_at = time.perf_counter() + logger.info( + "[connector=%s name=%r] first_call=%.3fs tools=%d", + connector.id, + connector.name, + result.first_call_latency_seconds, + len(first_response.tools), + ) + + for interval in intervals: + target_elapsed = open_started + ( + result.init_latency_seconds + result.first_call_latency_seconds + interval + ) + sleep_for = max(0.0, target_elapsed - time.perf_counter()) + await asyncio.sleep(sleep_for) + + call_start = time.perf_counter() + elapsed_since_open = call_start - open_started + elapsed_since_last = call_start - (last_call_at or call_start) + try: + response = await asyncio.wait_for( + session.list_tools(), timeout=PER_CALL_TIMEOUT_SECONDS + ) + latency = time.perf_counter() - call_start + last_call_at = time.perf_counter() + checkpoint = CheckpointResult( + idle_seconds_target=interval, + elapsed_since_open_seconds=round(elapsed_since_open, 3), + elapsed_since_last_call_seconds=round(elapsed_since_last, 3), + success=True, + latency_seconds=round(latency, 3), + tools_returned=len(response.tools), + error_type=None, + error_message=None, + ) + logger.info( + "[connector=%s t=+%ds] OK %.3fs (tools=%d)", + connector.id, + interval, + latency, + len(response.tools), + ) + result.checkpoints.append(checkpoint) + except Exception as exc: # noqa: BLE001 + label, msg = _classify_error(exc) + latency_at_failure = time.perf_counter() - call_start + checkpoint = CheckpointResult( + idle_seconds_target=interval, + elapsed_since_open_seconds=round(elapsed_since_open, 3), + elapsed_since_last_call_seconds=round(elapsed_since_last, 3), + success=False, + latency_seconds=round(latency_at_failure, 3), + tools_returned=None, + error_type=label, + error_message=msg[:300], + ) + logger.warning( + "[connector=%s t=+%ds] FAILED %s after %.3fs: %s", + connector.id, + interval, + label, + latency_at_failure, + msg[:200], + ) + result.checkpoints.append(checkpoint) + # Session is presumed dead — further checkpoints would all + # fail the same way and just waste wall time. + break + + except Exception as exc: # noqa: BLE001 + label, msg = _classify_error(exc) + result.fatal_error = f"{label}: {msg[:200]}" + logger.exception( + "[connector=%s] fatal during open/init: %s", + connector.id, + exc, + ) + finally: + if session_cm is not None: + try: + await session_cm.__aexit__(None, None, None) + except Exception: + pass + if transport_cm is not None: + try: + await transport_cm.__aexit__(None, None, None) + except Exception: + pass + try: + await http_client.aclose() + except Exception: + pass + + return result + + +# --------------------------------------------------------------------------- +# Reporting +# --------------------------------------------------------------------------- + + +def _render_table(results: list[ConnectorProbeResult]) -> str: + """Pretty-print a per-connector summary suitable for the terminal.""" + lines: list[str] = [] + lines.append("=" * 100) + lines.append("MCP Session Lifetime Probe Results") + lines.append("=" * 100) + + for result in results: + lines.append("") + lines.append( + f"Connector {result.connector_id} | {result.connector_type} | " + f"{result.connector_name!r}" + ) + lines.append(f" url: {result.url}") + if result.fatal_error: + lines.append(f" FATAL: {result.fatal_error}") + continue + lines.append( + f" init handshake: " + f"{result.init_latency_seconds:.3f}s" + if result.init_latency_seconds is not None + else " init handshake: (failed)" + ) + lines.append( + f" first list_tools (cold): " + f"{result.first_call_latency_seconds:.3f}s" + if result.first_call_latency_seconds is not None + else " first list_tools: (failed)" + ) + if not result.checkpoints: + lines.append(" (no idle checkpoints recorded)") + continue + lines.append( + f" {'idle_s':>8} | {'since_last':>10} | {'outcome':>16} | " + f"{'latency':>9} | {'tools':>5}" + ) + for cp in result.checkpoints: + outcome = "OK" if cp.success else (cp.error_type or "FAIL") + latency = f"{cp.latency_seconds:.3f}s" if cp.latency_seconds is not None else "-" + tools = str(cp.tools_returned) if cp.tools_returned is not None else "-" + lines.append( + f" {cp.idle_seconds_target:>8} | " + f"{cp.elapsed_since_last_call_seconds:>10.1f} | " + f"{outcome:>16} | " + f"{latency:>9} | " + f"{tools:>5}" + ) + + lines.append("") + lines.append("=" * 100) + lines.append("Summary") + lines.append("=" * 100) + survived: dict[int, list[int]] = {} + for result in results: + for cp in result.checkpoints: + if cp.success: + survived.setdefault(cp.idle_seconds_target, []).append( + result.connector_id + ) + if survived: + for interval in sorted(survived): + ids = sorted(survived[interval]) + lines.append( + f" Idle {interval:>5}s: {len(ids)}/{len(results)} connectors " + f"survived ({ids})" + ) + else: + lines.append(" (no successful checkpoints)") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def _parse_int_list(value: str) -> list[int]: + return [int(x) for x in value.split(",") if x.strip()] + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Probe MCP server session lifetime (read-only)", + ) + parser.add_argument( + "--connectors", + type=_parse_int_list, + default=None, + help="Comma-separated connector IDs to probe. Default: all MCP connectors.", + ) + parser.add_argument( + "--intervals", + type=_parse_int_list, + default=None, + help="Comma-separated idle intervals in seconds. " + f"Default: {DEFAULT_INTERVALS_SECONDS}", + ) + parser.add_argument( + "--quick", + action="store_true", + help=f"Short run (intervals={QUICK_INTERVALS_SECONDS}) for fast iteration.", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Optional path for the raw JSON results.", + ) + return parser.parse_args() + + +async def _async_main() -> int: + args = _parse_args() + if args.intervals is not None: + intervals = args.intervals + elif args.quick: + intervals = QUICK_INTERVALS_SECONDS + else: + intervals = DEFAULT_INTERVALS_SECONDS + + longest = max(intervals) if intervals else 0 + logger.info( + "Probing intervals=%s (longest=%ds, ~%dmin total wall time)", + intervals, + longest, + (longest + 30) // 60, + ) + + connectors = await _fetch_connectors(args.connectors) + if not connectors: + logger.error("No MCP connectors found to probe.") + return 2 + logger.info( + "Probing %d connector(s): %s", + len(connectors), + [f"{c.id}:{c.name}" for c in connectors], + ) + + started_at = time.time() + results = await asyncio.gather( + *[_probe_one_connector(c, intervals) for c in connectors], + return_exceptions=False, + ) + elapsed = time.time() - started_at + logger.info("All probes complete in %.1fs", elapsed) + + table = _render_table(results) + print(table) + + output_path = ( + args.output + or f"mcp_session_probe_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + ) + with open(output_path, "w", encoding="utf-8") as fp: + json.dump( + { + "started_at": datetime.fromtimestamp(started_at).isoformat(), + "elapsed_seconds": round(elapsed, 1), + "intervals_tested": intervals, + "results": [asdict(r) for r in results], + }, + fp, + indent=2, + ) + logger.info("Raw results saved to %s", output_path) + return 0 + + +def main() -> None: + try: + exit_code = asyncio.run(_async_main()) + except KeyboardInterrupt: + logger.warning("Interrupted by user") + exit_code = 130 + sys.exit(exit_code) + + +if __name__ == "__main__": + main()