mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Relocate the entire new_chat/tools/ package (62 files incl. registry, hitl, MCP cluster, and all connector subpackages: gmail/slack/discord/teams/drive/etc.) to the shared kernel. The package turned out to be a clean cohesive cluster: its only references to non-tools new_chat modules were comments, and its middleware deps were already flipped to shared in slice 5c. Flip 33 live importers (multi-agent, flows, routes, services, anonymous_agent, tests). Re-export shims remain for the frozen single-agent stack: a package __init__ mirroring the public surface (new_chat.__init__ imports it) plus invalid_tool + registry submodule shims (chat_deepagent imports those). Resolves slice 5c's two transient back-edges: shared/middleware/action_log (TYPE_CHECKING ToolDefinition) and tool_call_repair (local INVALID_TOOL_NAME) now point at app.agents.shared.tools.
326 lines
11 KiB
Python
326 lines
11 KiB
Python
"""MCP Client Wrapper.
|
|
|
|
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
|
|
It handles server lifecycle management, tool discovery, and tool execution.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any
|
|
|
|
from mcp import ClientSession
|
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
|
|
logger = logging.getLogger(__name__)

# Connection-retry policy consumed by ``MCPClient.connect``: at most
# MAX_RETRIES attempts; after a failed attempt the client sleeps RETRY_DELAY
# seconds, and the sleep is multiplied by RETRY_BACKOFF after every further
# failure (1s, 2s, ... with the defaults below).
MAX_RETRIES: int = 3
RETRY_DELAY: float = 1.0  # seconds to wait before the first retry
RETRY_BACKOFF: float = 2.0  # exponential backoff multiplier
|
|
|
|
|
|
class MCPClient:
    """Client for communicating with an MCP server over stdio.

    The server process is spawned from ``command``/``args`` inside
    :meth:`connect`, which is an async context manager. :meth:`list_tools`
    and :meth:`call_tool` may only be used while that context is active.
    """

    def __init__(
        self, command: str, args: list[str], env: dict[str, str] | None = None
    ):
        """Initialize MCP client.

        Args:
            command: Command to spawn the MCP server (e.g., "uvx", "node")
            args: Arguments for the command (e.g., ["mcp-server-git"])
            env: Optional environment variables for the server process. They
                are layered on top of a copy of ``os.environ`` when the
                process is spawned, overriding inherited values.
        """
        self.command = command
        self.args = args
        self.env = env or {}
        # Set by ``connect()`` once the session is initialised; reset to None
        # when that context exits or a connection attempt fails.
        self.session: ClientSession | None = None

    @asynccontextmanager
    async def connect(self, max_retries: int = MAX_RETRIES):
        """Connect to the MCP server and manage its lifecycle.

        Retries only apply to the **connection** phase (spawning the process,
        initialising the session). Once the session is yielded to the caller,
        any exception raised by the caller propagates normally -- the context
        manager will NOT retry after ``yield``.

        ``@asynccontextmanager`` permits exactly one ``yield``. Retrying after
        the caller's body has started would attempt a second one, raising
        ``RuntimeError("generator didn't stop after athrow()")`` and orphaning
        the stdio subprocess. The ``connected`` flag below prevents that.

        Args:
            max_retries: Maximum number of connection attempts. Between failed
                attempts the client sleeps ``RETRY_DELAY`` seconds, multiplied
                by ``RETRY_BACKOFF`` after each further failure.

        Yields:
            ClientSession: Active MCP session for making requests

        Raises:
            RuntimeError: If all connection attempts fail (chained from the
                last underlying error).
        """
        last_error: Exception | None = None
        delay = RETRY_DELAY
        connected = False

        for attempt in range(1, max_retries + 1):
            try:
                # The child inherits our environment plus the configured overrides.
                server_env = {**os.environ, **self.env}
                server_params = StdioServerParameters(
                    command=self.command, args=self.args, env=server_env
                )

                async with (
                    stdio_client(server=server_params) as (read, write),
                    ClientSession(read, write) as session,
                ):
                    await session.initialize()
                    self.session = session
                    connected = True

                    if attempt > 1:
                        logger.info(
                            "Connected to MCP server on attempt %d: %s %s",
                            attempt,
                            self.command,
                            " ".join(self.args),
                        )
                    else:
                        logger.info(
                            "Connected to MCP server: %s %s",
                            self.command,
                            " ".join(self.args),
                        )

                    try:
                        yield session
                    finally:
                        self.session = None
                    return

            except Exception as e:
                self.session = None
                if connected:
                    # The session was already handed to the caller (or its
                    # teardown failed): retrying would need a second ``yield``.
                    raise

                last_error = e
                if attempt < max_retries:
                    logger.warning(
                        "MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
                        attempt,
                        max_retries,
                        e,
                        delay,
                    )
                    await asyncio.sleep(delay)
                    delay *= RETRY_BACKOFF

        error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
        if last_error is not None:
            error_msg += f": {last_error}"
        # Log the terminal failure once, with the last underlying traceback.
        logger.error(error_msg, exc_info=last_error)
        raise RuntimeError(error_msg) from last_error

    async def list_tools(self) -> list[dict[str, Any]]:
        """List all tools available from the MCP server.

        Returns:
            List of dicts with ``name``, ``description`` (empty string when
            the server gives none) and ``input_schema`` (``{}`` when the tool
            object has no ``inputSchema`` attribute).

        Raises:
            RuntimeError: If not connected to server
        """
        session = self._require_session()

        try:
            response = await session.list_tools()
            tools = [
                {
                    "name": tool.name,
                    "description": tool.description or "",
                    "input_schema": getattr(tool, "inputSchema", {}),
                }
                for tool in response.tools
            ]
        except Exception as e:
            logger.exception("Failed to list tools from MCP server: %s", e)
            raise

        logger.info("Listed %d tools from MCP server", len(tools))
        return tools

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float = 60.0,
    ) -> Any:
        """Call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool
            timeout: Maximum seconds to wait for the tool to respond

        Returns:
            The tool's content items rendered as text and joined with
            newlines. Timeouts, "Invalid structured content" schema
            mismatches, and value/type/attribute/key errors are reported as a
            descriptive string instead of raising.

        Raises:
            RuntimeError: If not connected to server, or any other
                RuntimeError raised by the session.
        """
        session = self._require_session()

        try:
            logger.info(
                "Calling MCP tool '%s' with arguments: %s", tool_name, arguments
            )

            response = await asyncio.wait_for(
                session.call_tool(tool_name, arguments=arguments),
                timeout=timeout,
            )

            result_str = "\n".join(
                self._content_to_text(content) for content in response.content
            )

            # Per the MCP spec, ``isError`` flags a tool-level failure whose
            # details are in the content; return the text either way, but do
            # not log it as a success.
            if getattr(response, "isError", False):
                logger.warning(
                    "MCP tool '%s' reported an error: %s", tool_name, result_str[:200]
                )
            else:
                logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
            return result_str

        except TimeoutError:
            # ``asyncio.wait_for`` raises the builtin TimeoutError on Python 3.11+.
            logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout)
            return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
        except RuntimeError as e:
            if "Invalid structured content" in str(e):
                logger.warning(
                    "MCP server returned data not matching its schema, but continuing: %s",
                    e,
                )
                return "Operation completed (server returned unexpected format)"
            raise
        except (ValueError, TypeError, AttributeError, KeyError) as e:
            logger.error(
                "Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
            )
            return f"Error calling tool: {e!s}"

    def _require_session(self):
        """Return the active session, or raise if ``connect()`` is not active."""
        if self.session is None:
            raise RuntimeError(
                "Not connected to MCP server. Use 'async with client.connect():'"
            )
        return self.session

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Render one item of a tool result's ``content`` list as text.

        Items with a ``text`` attribute are used as-is, items with ``data``
        are stringified, and anything else falls back to ``str(item)``.
        """
        if hasattr(content, "text"):
            return content.text
        if hasattr(content, "data"):
            return str(content.data)
        return str(content)
|
|
|
|
|
async def test_mcp_connection(
    command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
    """Test connection to an MCP server via stdio and fetch available tools.

    Like ``test_mcp_http_connection``, this always reports the outcome as a
    status dict rather than raising. Connection-phase failures arrive as the
    ``RuntimeError`` raised by ``MCPClient.connect``. Failures after the
    session is established (e.g. while listing tools, or during teardown) can
    be any exception type, so the boundary catches ``Exception``.

    Args:
        command: Command to spawn the MCP server
        args: Arguments for the command
        env: Optional environment variables

    Returns:
        Dict with ``status`` ("success"/"error"), a human-readable
        ``message`` and the list of available ``tools`` (empty on error).
    """
    client = MCPClient(command, args, env)

    try:
        async with client.connect():
            tools = await client.list_tools()
            return {
                "status": "success",
                "message": f"Connected successfully. Found {len(tools)} tools.",
                "tools": tools,
            }
    except Exception as e:
        logger.error("Failed to connect to stdio MCP server: %s", e, exc_info=True)
        return {
            "status": "error",
            "message": f"Failed to connect: {e!s}",
            "tools": [],
        }
|
|
|
|
|
|
async def test_mcp_http_connection(
    url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
) -> dict[str, Any]:
    """Probe an HTTP-hosted MCP server and report the tools it exposes.

    Any exception during connect, initialise, or tool listing is logged and
    turned into an ``"error"`` status dict; this function does not raise.

    Args:
        url: URL of the MCP server
        headers: Optional HTTP headers (e.g. for authentication)
        transport: Transport label ("streamable-http", "http", or "sse").
            It is only written to the log; every HTTP flavour is dialled with
            the streamable-HTTP client.

    Returns:
        Dict with ``status``, a human-readable ``message`` and the discovered
        ``tools`` (each with ``name``, ``description``, ``input_schema``).
    """
    try:
        logger.info(
            "Testing HTTP MCP connection to: %s (transport: %s)", url, transport
        )

        # NOTE(review): an SSE-only server may not speak streamable HTTP;
        # confirm that routing "sse" through this client is intended.
        async with (
            streamablehttp_client(url, headers=headers or {}) as (
                read_stream,
                write_stream,
                _session_id,
            ),
            ClientSession(read_stream, write_stream) as session,
        ):
            await session.initialize()

            listing = await session.list_tools()
            discovered = [
                {
                    "name": spec.name,
                    "description": spec.description or "",
                    "input_schema": getattr(spec, "inputSchema", {}),
                }
                for spec in listing.tools
            ]
            found = len(discovered)

            logger.info("HTTP MCP connection successful. Found %d tools.", found)
            return {
                "status": "success",
                "message": f"Connected successfully. Found {found} tools.",
                "tools": discovered,
            }

    except Exception as exc:
        logger.error("Failed to connect to HTTP MCP server: %s", exc, exc_info=True)
        return {
            "status": "error",
            "message": f"Failed to connect: {exc!s}",
            "tools": [],
        }
|