mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-27 19:25:15 +02:00
563 lines
19 KiB
Python
563 lines
19 KiB
Python
"""Probe MCP server session lifetime / staleness behavior — read-only.

Goal
----
Empirically answer two questions for our actual third-party MCP servers
(Atlassian, Linear, Slack, ClickUp, Airtable, ...):

1. How expensive is the initial ``initialize`` handshake (``init=`` cost)?
2. How long can a ``ClientSession`` sit idle and still survive a
   subsequent ``list_tools()`` call?

This script informs the design choice between

* per-call sessions (current, ~1s init tax per call),
* per-turn session reuse (LangChain-style, holds a session for the
  duration of a chat turn),
* a long-lived session pool (IBM-style, sessions reused across turns).

The probe is read-only: it only ever calls ``session.list_tools()``,
which is the safest MCP method. No tool calls against user data are
performed.

Usage
-----
Run from the repo root or from ``surfsense_backend/``::

    uv run python -m scripts.probe_mcp_session_lifetime
    uv run python -m scripts.probe_mcp_session_lifetime --quick
    uv run python -m scripts.probe_mcp_session_lifetime --connectors 7,19,20
    uv run python -m scripts.probe_mcp_session_lifetime --intervals 5,30,60,300

Output
------
* Live progress to stderr (``[connector=7 t=+30s] OK 0.142s``).
* Final per-connector table to stdout.
* Raw results JSON to ``./mcp_session_probe_<timestamp>.json``.

The default test reaches 1800s of idle (~30 min). Use ``--quick`` to
stop at 60s for fast iteration. All connectors probe concurrently so
total wall-clock time equals the longest interval, not the sum.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from dataclasses import asdict, dataclass, field
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
# Put the parent of this script's directory at the front of ``sys.path``
# (presumably so the ``app.*`` imports further down resolve when the file is
# run as a script — confirm against the repo layout).
_HERE = os.path.dirname(os.path.abspath(__file__))
_BACKEND_ROOT = os.path.dirname(_HERE)
if _BACKEND_ROOT not in sys.path:
    sys.path.insert(0, _BACKEND_ROOT)
|
|
|
|
import httpx # noqa: E402
|
|
from mcp import ClientSession # noqa: E402
|
|
from mcp.client.streamable_http import streamable_http_client # noqa: E402
|
|
from sqlalchemy import cast, select # noqa: E402
|
|
from sqlalchemy.dialects.postgresql import JSONB # noqa: E402
|
|
|
|
from app.agents.new_chat.tools.mcp_tool import ( # noqa: E402
|
|
_inject_oauth_headers,
|
|
_maybe_refresh_mcp_oauth_token,
|
|
)
|
|
from app.db import SearchSourceConnector, async_session_maker # noqa: E402
|
|
|
|
# Root logging goes to stderr at WARNING so that stdout carries only the
# final results table.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    stream=sys.stderr,
)
# Only errors are let through from the HTTP / MCP client libraries.
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("mcp").setLevel(logging.ERROR)
# The probe's own logger reports live progress at INFO.
logger = logging.getLogger("mcp_probe")
logger.setLevel(logging.INFO)


# Idle targets in seconds: the default schedule and the ``--quick`` schedule.
DEFAULT_INTERVALS_SECONDS = [5, 30, 60, 300, 900, 1800]
QUICK_INTERVALS_SECONDS = [5, 30, 60]
# Upper bound for each awaited MCP call; also used as the HTTP client timeout.
PER_CALL_TIMEOUT_SECONDS = 60.0
|
|
|
|
|
|
@dataclass
|
|
class CheckpointResult:
|
|
"""One ``list_tools()`` call against a long-lived session."""
|
|
|
|
idle_seconds_target: int
|
|
elapsed_since_open_seconds: float
|
|
elapsed_since_last_call_seconds: float
|
|
success: bool
|
|
latency_seconds: float | None
|
|
tools_returned: int | None
|
|
error_type: str | None
|
|
error_message: str | None
|
|
|
|
|
|
@dataclass
class ConnectorProbeResult:
    """Everything recorded while probing a single connector."""

    # Identity of the probed connector.
    connector_id: int
    connector_name: str
    connector_type: str
    # Server URL, or a placeholder such as "(missing)" when none was resolved.
    url: str
    # Duration of ``session.initialize()``; None if it never completed.
    init_latency_seconds: float | None
    # Duration of the first ``list_tools()`` call; None if it never completed.
    first_call_latency_seconds: float | None
    # One entry per idle checkpoint attempted, in the order attempted.
    checkpoints: list[CheckpointResult] = field(default_factory=list)
    # Set when the probe was skipped or aborted outside a checkpoint.
    fatal_error: str | None = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connector loading + auth
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _fetch_connectors(
    connector_ids: list[int] | None,
) -> list[SearchSourceConnector]:
    """Load the connectors to probe from the database.

    A connector qualifies when its ``config`` (cast to JSONB) has a
    ``server_config`` key. When ``connector_ids`` is non-empty the query is
    further restricted to those IDs, and any requested ID that did not come
    back is reported with a warning.
    """
    query = select(SearchSourceConnector).filter(
        cast(SearchSourceConnector.config, JSONB).has_key("server_config"),
    )
    if connector_ids:
        query = query.filter(SearchSourceConnector.id.in_(connector_ids))

    async with async_session_maker() as session:
        rows = await session.execute(query)
        connectors = list(rows.scalars())

    if not connector_ids:
        return connectors

    known_ids = {c.id for c in connectors}
    missing = [cid for cid in connector_ids if cid not in known_ids]
    if missing:
        logger.warning("Requested connector IDs not found: %s", missing)
    return connectors
|
|
|
|
|
|
async def _resolve_authed_server_config(
    connector: SearchSourceConnector,
) -> dict[str, Any] | None:
    """Return the connector's ``server_config``, with OAuth applied if configured.

    Without an ``mcp_oauth`` entry in the connector config, the stored
    ``server_config`` is returned unchanged. Otherwise the connector is
    re-loaded in a fresh DB session, passed through
    ``_maybe_refresh_mcp_oauth_token`` and the outcome is handed to
    ``_inject_oauth_headers``, whose return value is the result.

    Returns ``None`` when ``server_config`` is not a dict, when the connector
    row can no longer be loaded, or when ``_inject_oauth_headers`` yields
    ``None``.
    """
    connector_cfg = connector.config or {}
    server_config = connector_cfg.get("server_config", {})
    if not isinstance(server_config, dict):
        return None

    if not connector_cfg.get("mcp_oauth"):
        return server_config

    async with async_session_maker() as session:
        attached = await session.get(SearchSourceConnector, connector.id)
        if attached is None:
            return None
        refreshed = await _maybe_refresh_mcp_oauth_token(
            session,
            attached,
            attached.config or {},
            server_config,
        )
        # Re-read the config after the refresh call rather than reusing the
        # value passed into it.
        return _inject_oauth_headers(attached.config or {}, refreshed)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# The actual probe
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _classify_error(exc: BaseException) -> tuple[str, str]:
    """Map a failed call's exception to ``(short_label, human_message)``.

    Timeouts are recognised by type. Everything else is matched, in order,
    on substrings of the message and then of the exception class name; when
    nothing matches, the class name itself is the label.
    """
    exc_name = type(exc).__name__
    text = str(exc) or repr(exc)

    if isinstance(exc, asyncio.TimeoutError):
        return "timeout", f"call exceeded {PER_CALL_TIMEOUT_SECONDS}s"

    if "404" in text or "Not Found" in text or "session" in text.lower():
        label = "session_expired"
    elif "401" in text or "Unauthorized" in text:
        label = "auth_401"
    elif "Closed" in exc_name:
        # Also covers "ClosedResourceError", which contains "Closed".
        label = "stream_closed"
    elif "Connection" in exc_name or "ConnectError" in exc_name:
        label = "connection_error"
    else:
        label = exc_name
    return label, text
|
|
|
|
|
|
async def _probe_one_connector(
    connector: SearchSourceConnector,
    intervals: list[int],
) -> ConnectorProbeResult:
    """Open one long-lived session and call ``list_tools`` at each interval.

    Args:
        connector: The connector whose MCP server is probed.
        intervals: Idle targets in seconds. Each checkpoint is scheduled at a
            cumulative offset from the start of the probe (plus the measured
            init and first-call latencies), not relative to the previous
            checkpoint, so the values are expected in ascending order.

    Returns:
        A ``ConnectorProbeResult``. Failures derived from ``Exception`` are
        not raised: configuration problems and errors during
        open / initialize / first call are reported through ``fatal_error``,
        and a failing checkpoint is recorded and ends the probe early.
    """
    connector_type = (
        connector.connector_type.value
        if hasattr(connector.connector_type, "value")
        else str(connector.connector_type)
    )

    def _new_result(
        url: str, fatal_error: str | None = None
    ) -> ConnectorProbeResult:
        # Single construction site for every outcome of this probe.
        return ConnectorProbeResult(
            connector_id=connector.id,
            connector_name=connector.name,
            connector_type=connector_type,
            url=url,
            init_latency_seconds=None,
            first_call_latency_seconds=None,
            fatal_error=fatal_error,
        )

    async def _close_quietly(what: str, close: Any) -> None:
        # Teardown stays best-effort (the session may already be dead), but a
        # failure is logged instead of vanishing without a trace.
        try:
            await close()
        except Exception as exc:  # noqa: BLE001
            logger.debug(
                "[connector=%s] ignored error while closing %s: %r",
                connector.id,
                what,
                exc,
            )

    server_config = await _resolve_authed_server_config(connector)
    if server_config is None:
        return _new_result("(unresolved)", "failed_to_resolve_server_config")

    url = server_config.get("url")
    headers = server_config.get("headers", {})
    if not url:
        return _new_result("(missing)", "missing_url")

    transport = server_config.get("transport", "streamable-http")
    # NOTE(review): "sse" passes this check, yet the session below is always
    # opened with ``streamable_http_client`` — confirm that is intended.
    if transport not in ("streamable-http", "http", "sse"):
        return _new_result(url, f"unsupported_transport:{transport}")

    result = _new_result(url)
    open_started = time.perf_counter()

    # Manually drive the context-manager protocol so the session lives
    # across our sleep intervals. ``streamable_http_client`` spawns a
    # background task for the SSE receive loop; ``ClientSession`` spawns
    # another for request multiplexing. We must close them in reverse order.
    http_client = httpx.AsyncClient(headers=headers, timeout=PER_CALL_TIMEOUT_SECONDS)
    transport_cm = None
    session_cm = None
    try:
        transport_cm = streamable_http_client(url, http_client=http_client)
        read, write, _ = await transport_cm.__aenter__()
        session_cm = ClientSession(read, write)
        session = await session_cm.__aenter__()

        init_start = time.perf_counter()
        await asyncio.wait_for(session.initialize(), timeout=PER_CALL_TIMEOUT_SECONDS)
        result.init_latency_seconds = time.perf_counter() - init_start
        logger.info(
            "[connector=%s name=%r] init=%.3fs",
            connector.id,
            connector.name,
            result.init_latency_seconds,
        )

        first_call_start = time.perf_counter()
        first_response = await asyncio.wait_for(
            session.list_tools(), timeout=PER_CALL_TIMEOUT_SECONDS
        )
        result.first_call_latency_seconds = time.perf_counter() - first_call_start
        last_call_at = time.perf_counter()
        logger.info(
            "[connector=%s name=%r] first_call=%.3fs tools=%d",
            connector.id,
            connector.name,
            result.first_call_latency_seconds,
            len(first_response.tools),
        )

        for interval in intervals:
            # Absolute deadline for this checkpoint, measured from open.
            target_elapsed = open_started + (
                result.init_latency_seconds + result.first_call_latency_seconds + interval
            )
            await asyncio.sleep(max(0.0, target_elapsed - time.perf_counter()))

            call_start = time.perf_counter()
            elapsed_since_open = round(call_start - open_started, 3)
            elapsed_since_last = round(call_start - last_call_at, 3)
            try:
                response = await asyncio.wait_for(
                    session.list_tools(), timeout=PER_CALL_TIMEOUT_SECONDS
                )
                latency = time.perf_counter() - call_start
                last_call_at = time.perf_counter()
                tool_count = len(response.tools)
                logger.info(
                    "[connector=%s t=+%ds] OK %.3fs (tools=%d)",
                    connector.id,
                    interval,
                    latency,
                    tool_count,
                )
                result.checkpoints.append(
                    CheckpointResult(
                        idle_seconds_target=interval,
                        elapsed_since_open_seconds=elapsed_since_open,
                        elapsed_since_last_call_seconds=elapsed_since_last,
                        success=True,
                        latency_seconds=round(latency, 3),
                        tools_returned=tool_count,
                        error_type=None,
                        error_message=None,
                    )
                )
            except Exception as exc:  # noqa: BLE001
                label, msg = _classify_error(exc)
                latency_at_failure = time.perf_counter() - call_start
                logger.warning(
                    "[connector=%s t=+%ds] FAILED %s after %.3fs: %s",
                    connector.id,
                    interval,
                    label,
                    latency_at_failure,
                    msg[:200],
                )
                result.checkpoints.append(
                    CheckpointResult(
                        idle_seconds_target=interval,
                        elapsed_since_open_seconds=elapsed_since_open,
                        elapsed_since_last_call_seconds=elapsed_since_last,
                        success=False,
                        latency_seconds=round(latency_at_failure, 3),
                        tools_returned=None,
                        error_type=label,
                        error_message=msg[:300],
                    )
                )
                # Session is presumed dead — further checkpoints would all
                # fail the same way and just waste wall time.
                break

    except Exception as exc:  # noqa: BLE001
        label, msg = _classify_error(exc)
        result.fatal_error = f"{label}: {msg[:200]}"
        logger.exception(
            "[connector=%s] fatal during open/init: %s",
            connector.id,
            exc,
        )
    finally:
        # Reverse order of acquisition: session, then transport, then client.
        if session_cm is not None:
            await _close_quietly(
                "client session", lambda: session_cm.__aexit__(None, None, None)
            )
        if transport_cm is not None:
            await _close_quietly(
                "transport", lambda: transport_cm.__aexit__(None, None, None)
            )
        await _close_quietly("http client", http_client.aclose)

    return result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reporting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _render_table(results: list[ConnectorProbeResult]) -> str:
    """Format the probe results as a plain-text report for the terminal.

    The report has one section per connector (identity, URL, then either the
    fatal error or the latencies followed by a checkpoint table) and a closing
    summary listing, per idle target, which connectors had a successful
    checkpoint.
    """
    rule = "=" * 100
    out: list[str] = [rule, "MCP Session Lifetime Probe Results", rule]

    for probe in results:
        out.append("")
        out.append(
            f"Connector {probe.connector_id} | {probe.connector_type} | "
            f"{probe.connector_name!r}"
        )
        out.append(f" url: {probe.url}")

        if probe.fatal_error:
            out.append(f" FATAL: {probe.fatal_error}")
            continue

        init_latency = probe.init_latency_seconds
        if init_latency is None:
            out.append(" init handshake: (failed)")
        else:
            out.append(f" init handshake: {init_latency:.3f}s")

        first_latency = probe.first_call_latency_seconds
        if first_latency is None:
            out.append(" first list_tools: (failed)")
        else:
            out.append(f" first list_tools (cold): {first_latency:.3f}s")

        if not probe.checkpoints:
            out.append(" (no idle checkpoints recorded)")
            continue

        out.append(
            f" {'idle_s':>8} | {'since_last':>10} | {'outcome':>16} | "
            f"{'latency':>9} | {'tools':>5}"
        )
        for cp in probe.checkpoints:
            if cp.success:
                outcome = "OK"
            else:
                outcome = cp.error_type or "FAIL"
            latency = "-" if cp.latency_seconds is None else f"{cp.latency_seconds:.3f}s"
            tools = "-" if cp.tools_returned is None else str(cp.tools_returned)
            cells = (
                f"{cp.idle_seconds_target:>8}",
                f"{cp.elapsed_since_last_call_seconds:>10.1f}",
                f"{outcome:>16}",
                f"{latency:>9}",
                f"{tools:>5}",
            )
            out.append(" " + " | ".join(cells))

    out.extend(("", rule, "Summary", rule))

    # idle target -> IDs of the connectors whose checkpoint succeeded.
    survivors: dict[int, list[int]] = {}
    for probe in results:
        for cp in probe.checkpoints:
            if not cp.success:
                continue
            survivors.setdefault(cp.idle_seconds_target, []).append(probe.connector_id)

    if not survivors:
        out.append(" (no successful checkpoints)")
    for target in sorted(survivors):
        ids = sorted(survivors[target])
        out.append(
            f" Idle {target:>5}s: {len(ids)}/{len(results)} connectors "
            f"survived ({ids})"
        )
    return "\n".join(out)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string of integers, skipping blank items.

    Raises ``ValueError`` (from ``int``) if a non-blank item is not an integer.
    """
    parsed: list[int] = []
    for token in value.split(","):
        if not token.strip():
            continue
        # ``int`` tolerates surrounding whitespace, so no explicit strip here.
        parsed.append(int(token))
    return parsed
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Define the command-line options and parse the process arguments.

    Options: ``--connectors`` and ``--intervals`` (comma-separated integer
    lists, default ``None``), ``--quick`` (flag) and ``--output`` (path for
    the JSON results, default ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Probe MCP server session lifetime (read-only)",
    )

    # The two list-valued options share the same parser and default.
    int_list_options = (
        (
            "--connectors",
            "Comma-separated connector IDs to probe. Default: all MCP connectors.",
        ),
        (
            "--intervals",
            f"Comma-separated idle intervals in seconds. Default: {DEFAULT_INTERVALS_SECONDS}",
        ),
    )
    for flag, help_text in int_list_options:
        parser.add_argument(flag, type=_parse_int_list, default=None, help=help_text)

    parser.add_argument(
        "--quick",
        action="store_true",
        help=f"Short run (intervals={QUICK_INTERVALS_SECONDS}) for fast iteration.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional path for the raw JSON results.",
    )
    return parser.parse_args()
|
|
|
|
|
|
async def _async_main() -> int:
    """Run the whole probe: parse args, probe all connectors, report.

    The interval schedule comes from ``--intervals`` if given, else the quick
    schedule under ``--quick``, else the default schedule. All connectors are
    probed concurrently; afterwards the table is printed to stdout and the raw
    results are written as JSON to ``--output`` or a timestamped file in the
    current directory.

    Returns:
        The process exit code: ``0`` after a completed run, ``2`` when there
        is no connector to probe.
    """
    args = _parse_args()
    if args.intervals is not None:
        requested = args.intervals
    elif args.quick:
        requested = QUICK_INTERVALS_SECONDS
    else:
        requested = DEFAULT_INTERVALS_SECONDS
    # Checkpoints are scheduled at cumulative offsets from the start of each
    # probe, so an out-of-order or repeated target would fire immediately and
    # be reported under a misleading idle label. Normalise the schedule.
    intervals = sorted(set(requested))

    longest = intervals[-1] if intervals else 0
    logger.info(
        "Probing intervals=%s (longest=%ds, ~%dmin total wall time)",
        intervals,
        longest,
        (longest + 30) // 60,
    )

    connectors = await _fetch_connectors(args.connectors)
    if not connectors:
        logger.error("No MCP connectors found to probe.")
        return 2
    logger.info(
        "Probing %d connector(s): %s",
        len(connectors),
        [f"{c.id}:{c.name}" for c in connectors],
    )

    started_at = time.time()
    results = await asyncio.gather(
        *[_probe_one_connector(c, intervals) for c in connectors],
        return_exceptions=False,
    )
    elapsed = time.time() - started_at
    logger.info("All probes complete in %.1fs", elapsed)

    print(_render_table(results))

    output_path = (
        args.output
        or f"mcp_session_probe_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    payload = {
        "started_at": datetime.fromtimestamp(started_at).isoformat(),
        "elapsed_seconds": round(elapsed, 1),
        "intervals_tested": intervals,
        "results": [asdict(r) for r in results],
    }
    with open(output_path, "w", encoding="utf-8") as fp:
        json.dump(payload, fp, indent=2)
    logger.info("Raw results saved to %s", output_path)
    return 0
|
|
|
|
|
|
def main() -> None:
    """Run the async entry point and exit with its return code.

    A ``KeyboardInterrupt`` is logged and turned into exit code 130.
    """
    # Stays at 130 only if the run is interrupted before it returns a code.
    exit_code = 130
    try:
        exit_code = asyncio.run(_async_main())
    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
|