SurfSense/surfsense_backend/app/agents/shared/tools/mcp_client.py
CREDO23 aab95b9130 refactor(agents): move tools package to app/agents/shared (slice 6)
Relocate the entire new_chat/tools/ package (62 files incl. registry, hitl, MCP
cluster, and all connector subpackages: gmail/slack/discord/teams/drive/etc.)
to the shared kernel. The package turned out to be a clean cohesive cluster:
its only references to non-tools new_chat modules were comments, and its
middleware deps were already flipped to shared in slice 5c.

Flip 33 live importers (multi-agent, flows, routes, services, anonymous_agent,
tests). Re-export shims remain for the frozen single-agent stack: a package
__init__ mirroring the public surface (new_chat.__init__ imports it) plus
invalid_tool + registry submodule shims (chat_deepagent imports those).

Resolves slice 5c's two transient back-edges: shared/middleware/action_log
(TYPE_CHECKING ToolDefinition) and tool_call_repair (local INVALID_TOOL_NAME)
now point at app.agents.shared.tools.
2026-06-04 13:11:56 +02:00

326 lines
11 KiB
Python

"""MCP Client Wrapper.
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
It handles server lifecycle management, tool discovery, and tool execution.
"""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
logger = logging.getLogger(__name__)
# Retry configuration used by MCPClient.connect(); it governs the connection
# phase only (nothing is retried once a session has been yielded).
MAX_RETRIES = 3  # default number of connection attempts
RETRY_DELAY = 1.0  # seconds to sleep before the first retry
RETRY_BACKOFF = 2.0  # exponential backoff multiplier applied to the delay after each failed attempt
class MCPClient:
    """Client for communicating with an MCP server over the stdio transport.

    The server is started from ``command`` + ``args`` when :meth:`connect` is
    entered. While the context is active, ``self.session`` holds the live
    ``ClientSession`` that :meth:`list_tools` and :meth:`call_tool` use; it is
    reset to ``None`` when the context exits or a connection attempt fails.
    """

    def __init__(
        self, command: str, args: list[str], env: dict[str, str] | None = None
    ):
        """Initialize MCP client.

        Args:
            command: Command to spawn the MCP server (e.g., "uvx", "node")
            args: Arguments for the command (e.g., ["mcp-server-git"])
            env: Optional environment variables for the server process. They
                are layered on top of the current process environment when
                connecting.
        """
        self.command = command
        self.args = args
        self.env = env or {}
        self.session: ClientSession | None = None

    @asynccontextmanager
    async def connect(self, max_retries: int = MAX_RETRIES):
        """Connect to the MCP server and manage its lifecycle.

        Retries only apply to the **connection** phase (spawning the process,
        initialising the session). Once the session is yielded to the caller,
        any exception raised by the caller propagates normally -- the context
        manager will NOT retry after ``yield``.

        A previous implementation wrapped both connection AND yield inside the
        retry loop. Because ``@asynccontextmanager`` only allows a single
        ``yield``, a failure after yield caused the generator to attempt a
        second yield on retry, triggering
        ``RuntimeError("generator didn't stop after athrow()")`` and orphaning
        the stdio subprocess.

        Args:
            max_retries: Maximum number of connection attempts. With a value
                below 1 no attempt is made and ``RuntimeError`` is raised
                immediately.

        Yields:
            ClientSession: Active MCP session for making requests

        Raises:
            RuntimeError: If all connection attempts fail
        """
        last_error = None
        delay = RETRY_DELAY
        # Becomes True once the session is initialised. From then on every
        # exception belongs to the caller (or to teardown) and must be
        # re-raised instead of triggering another attempt.
        connected = False

        for attempt in range(max_retries):
            try:
                # The server inherits the current environment, with the
                # client-specific variables taking precedence.
                server_env = os.environ.copy()
                server_env.update(self.env)
                server_params = StdioServerParameters(
                    command=self.command, args=self.args, env=server_env
                )
                async with stdio_client(server=server_params) as (read, write):  # noqa: SIM117
                    async with ClientSession(read, write) as session:
                        await session.initialize()
                        self.session = session
                        connected = True
                        if attempt > 0:
                            logger.info(
                                "Connected to MCP server on attempt %d: %s %s",
                                attempt + 1,
                                self.command,
                                " ".join(self.args),
                            )
                        else:
                            logger.info(
                                "Connected to MCP server: %s %s",
                                self.command,
                                " ".join(self.args),
                            )
                        try:
                            yield session
                        finally:
                            self.session = None
                        # Normal exit of the caller's block: leave the retry
                        # loop so the generator never yields a second time.
                        return
            except Exception as e:
                self.session = None
                if connected:
                    # Failure after the session was handed out: not retryable.
                    raise
                last_error = e
                if attempt < max_retries - 1:
                    logger.warning(
                        "MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
                        attempt + 1,
                        max_retries,
                        e,
                        delay,
                    )
                    await asyncio.sleep(delay)
                    delay *= RETRY_BACKOFF
                else:
                    logger.error(
                        "Failed to connect to MCP server after %d attempts: %s",
                        max_retries,
                        e,
                        exc_info=True,
                    )

        error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
        if last_error is not None:
            error_msg += f": {last_error}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from last_error

    async def list_tools(self) -> list[dict[str, Any]]:
        """List all tools available from the MCP server.

        Returns:
            List of tool definitions, each a dict with ``name``,
            ``description`` (empty string when the server gives none) and
            ``input_schema`` (empty dict when the tool has no ``inputSchema``
            attribute).

        Raises:
            RuntimeError: If not connected to server
        """
        if not self.session:
            raise RuntimeError(
                "Not connected to MCP server. Use 'async with client.connect():'"
            )
        try:
            # Call tools/list RPC method
            response = await self.session.list_tools()
            tools = [
                {
                    "name": tool.name,
                    "description": tool.description or "",
                    "input_schema": getattr(tool, "inputSchema", {}),
                }
                for tool in response.tools
            ]
            logger.info("Listed %d tools from MCP server", len(tools))
            return tools
        except Exception as e:
            logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
            raise

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float = 60.0,
    ) -> Any:
        """Call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool
            timeout: Maximum seconds to wait for the tool to respond

        Returns:
            The tool's content items rendered as text and joined with
            newlines. Instead of raising, an explanatory string is returned
            when the call times out, when the server's structured content is
            rejected as invalid, or when a ``ValueError``/``TypeError``/
            ``AttributeError``/``KeyError`` occurs while calling the tool or
            processing its response.

        Raises:
            RuntimeError: If not connected to server, or for any other
                ``RuntimeError`` raised by the call itself
        """
        if not self.session:
            raise RuntimeError(
                "Not connected to MCP server. Use 'async with client.connect():'"
            )
        try:
            logger.info(
                "Calling MCP tool '%s' with arguments: %s", tool_name, arguments
            )
            response = await asyncio.wait_for(
                self.session.call_tool(tool_name, arguments=arguments),
                timeout=timeout,
            )
            parts = []
            for content in response.content:
                if hasattr(content, "text"):
                    parts.append(content.text)
                elif hasattr(content, "data"):
                    parts.append(str(content.data))
                else:
                    parts.append(str(content))
            # Joining an empty list already yields "".
            result_str = "\n".join(parts)
            if getattr(response, "isError", False):
                # The server flagged the result as a tool-level error; the
                # text is still returned, but it must not be logged as success.
                logger.warning(
                    "MCP tool '%s' reported an error: %s", tool_name, result_str[:200]
                )
            else:
                logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
            return result_str
        # asyncio.TimeoutError is only an alias of the builtin TimeoutError
        # from Python 3.11 on; naming both keeps the timeout handled on 3.10.
        except (TimeoutError, asyncio.TimeoutError):  # noqa: UP041
            logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout)
            return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
        except RuntimeError as e:
            if "Invalid structured content" in str(e):
                logger.warning(
                    "MCP server returned data not matching its schema, but continuing: %s",
                    e,
                )
                return "Operation completed (server returned unexpected format)"
            raise
        except (ValueError, TypeError, AttributeError, KeyError) as e:
            logger.error(
                "Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
            )
            return f"Error calling tool: {e!s}"
async def test_mcp_connection(
    command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
    """Probe a stdio MCP server: connect, list its tools, and report the result.

    Args:
        command: Command to spawn the MCP server
        args: Arguments for the command
        env: Optional environment variables

    Returns:
        Dict with ``status`` ("success" or "error"), a human-readable
        ``message`` and the ``tools`` list (empty on error). Only
        ``RuntimeError``, ``ConnectionError``, ``TimeoutError`` and ``OSError``
        are converted into an error result; any other exception propagates.
    """
    probe = MCPClient(command, args, env)
    try:
        async with probe.connect():
            discovered = await probe.list_tools()
    except (RuntimeError, ConnectionError, TimeoutError, OSError) as exc:
        return {
            "status": "error",
            "message": f"Failed to connect: {exc!s}",
            "tools": [],
        }
    # Reached only when connecting, listing and teardown all succeeded.
    return {
        "status": "success",
        "message": f"Connected successfully. Found {len(discovered)} tools.",
        "tools": discovered,
    }
async def test_mcp_http_connection(
    url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
) -> dict[str, Any]:
    """Probe an HTTP MCP server: connect, list its tools, and report the result.

    Args:
        url: URL of the MCP server
        headers: Optional HTTP headers for authentication
        transport: Transport type ("streamable-http", "http", or "sse"). The
            value is only logged here; every transport goes through the
            streamable HTTP client.

    Returns:
        Dict with ``status`` ("success" or "error"), a human-readable
        ``message`` and the ``tools`` list (empty on error). Any ``Exception``
        is logged and reported as an error result rather than raised.
    """
    try:
        logger.info(
            "Testing HTTP MCP connection to: %s (transport: %s)", url, transport
        )
        request_headers = headers or {}
        # Use streamable HTTP client for all HTTP-based transports
        async with (
            streamablehttp_client(url, headers=request_headers) as (
                reader,
                writer,
                _,
            ),
            ClientSession(reader, writer) as session,
        ):
            await session.initialize()
            listing = await session.list_tools()
            discovered = [
                {
                    "name": entry.name,
                    "description": entry.description or "",
                    "input_schema": getattr(entry, "inputSchema", {}),
                }
                for entry in listing.tools
            ]
            logger.info(
                "HTTP MCP connection successful. Found %d tools.", len(discovered)
            )
            return {
                "status": "success",
                "message": f"Connected successfully. Found {len(discovered)} tools.",
                "tools": discovered,
            }
    except Exception as exc:
        logger.error("Failed to connect to HTTP MCP server: %s", exc, exc_info=True)
        return {
            "status": "error",
            "message": f"Failed to connect: {exc!s}",
            "tools": [],
        }