From e6297e2e40ac9d30aff4610010f97ecd3ed52640 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 8 May 2026 00:14:43 +0530 Subject: [PATCH] test(backend): add shared MCP runtime E2E fake --- .../tests/e2e/fakes/mcp_runtime.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 surfsense_backend/tests/e2e/fakes/mcp_runtime.py diff --git a/surfsense_backend/tests/e2e/fakes/mcp_runtime.py b/surfsense_backend/tests/e2e/fakes/mcp_runtime.py new file mode 100644 index 000000000..821b7c839 --- /dev/null +++ b/surfsense_backend/tests/e2e/fakes/mcp_runtime.py @@ -0,0 +1,146 @@ +"""Shared strict MCP streamable-HTTP runtime fake for E2E tests.""" + +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +ListToolsFn = Callable[[], Any | Awaitable[Any]] +CallToolFn = Callable[[str, dict[str, Any]], Any | Awaitable[Any]] + + +@dataclass(frozen=True) +class _RuntimeHandler: + expected_bearer: str + list_tools: ListToolsFn + call_tool: CallToolFn + + +_HANDLERS: dict[str, _RuntimeHandler] = {} + + +class _StrictFakeMixin: + _component_name: str = "" + + def __getattr__(self, name: str) -> Any: + raise NotImplementedError( + f"E2E MCP runtime fake missing surface: {self._component_name}.{name!r}. " + "Add it to surfsense_backend/tests/e2e/fakes/mcp_runtime.py." + ) + + +class _FakeEndpoint(_StrictFakeMixin): + _component_name = "streamablehttp_endpoint" + + def __init__(self, url: str, handler: _RuntimeHandler): + self.url = url + self.handler = handler + + +class _FakeStreamableHttpClient(_StrictFakeMixin): + _component_name = "streamablehttp_client" + + def __init__( + self, url: str, *, headers: dict[str, str] | None = None, **kwargs: Any + ): + del kwargs + handler = _HANDLERS.get(url) + if handler is None: + raise NotImplementedError(f"Unexpected MCP streamable-http url={url!r}") + + auth = (headers or {}).get("Authorization") + expected = f"Bearer {handler.expected_bearer}" + if auth != expected: + raise ValueError( + f"Unexpected MCP Authorization header for {url!r}: {auth!r}" + ) + + self.url = url + self.headers = headers or {} + self.handler = handler + + async def __aenter__(self) -> tuple[_FakeEndpoint, _FakeEndpoint, None]: + return _FakeEndpoint(self.url, self.handler), _FakeEndpoint( + self.url, self.handler + ), None + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + del exc_type, exc, tb + + +class _FakeClientSession(_StrictFakeMixin): + _component_name = "ClientSession" + + def __init__(self, read: _FakeEndpoint, write: _FakeEndpoint): + if read.handler is not write.handler: + raise ValueError("MCP fake received mismatched read/write endpoints.") + self.read = read + self.write = write + self.handler = read.handler + + async def __aenter__(self) -> _FakeClientSession: + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + del exc_type, exc, tb + + async def initialize(self) -> None: + return None + + async def list_tools(self) -> SimpleNamespace: + result = self.handler.list_tools() + if inspect.isawaitable(result): + result = await result + return result + + async def call_tool( + self, tool_name: str, *, arguments: dict[str, Any] | None = None + ) -> SimpleNamespace: + result = self.handler.call_tool(tool_name, arguments or {}) + if inspect.isawaitable(result): + result = await result + return result + + +def _fake_streamablehttp_client( + url: str, *, headers: dict[str, str] | None = None, **kwargs: Any +) -> _FakeStreamableHttpClient: + return _FakeStreamableHttpClient(url, headers=headers, **kwargs) + + +def register( + *, + url: str, + expected_bearer: str, + list_tools: ListToolsFn, + call_tool: CallToolFn, +) -> None: + """Register a fake streamable-HTTP MCP server by canonical MCP URL.""" + existing = _HANDLERS.get(url) + handler = _RuntimeHandler( + expected_bearer=expected_bearer, + list_tools=list_tools, + call_tool=call_tool, + ) + if existing is not None and existing != handler: + raise ValueError(f"MCP runtime fake handler already registered for {url!r}.") + _HANDLERS[url] = handler + + +def install(active_patches: list[Any]) -> None: + """Patch production MCP streamable-HTTP boundaries exactly once.""" + targets = [ + ( + "app.agents.new_chat.tools.mcp_tool.streamablehttp_client", + _fake_streamablehttp_client, + ), + ("app.agents.new_chat.tools.mcp_tool.ClientSession", _FakeClientSession), + ] + for target, replacement in targets: + p = patch(target, replacement) + p.start() + active_patches.append(p)