mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 23:02:39 +02:00
add error handling to mcp_tool python files
This commit is contained in:
parent
18ce599c81
commit
f78c2a685e
2 changed files with 110 additions and 39 deletions
|
|
@ -4,6 +4,7 @@ This module provides a client for communicating with MCP servers via stdio trans
|
|||
It handles server lifecycle management, tool discovery, and tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -14,6 +15,11 @@ from mcp.client.stdio import StdioServerParameters, stdio_client
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 1.0 # seconds
|
||||
RETRY_BACKOFF = 2.0 # exponential backoff multiplier
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Client for communicating with an MCP server."""
|
||||
|
|
@ -35,44 +41,86 @@ class MCPClient:
|
|||
self.session: ClientSession | None = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||
"""Connect to the MCP server and manage its lifecycle.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of connection retry attempts
|
||||
|
||||
Yields:
|
||||
ClientSession: Active MCP session for making requests
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all connection attempts fail
|
||||
|
||||
"""
|
||||
try:
|
||||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
last_error = None
|
||||
delay = RETRY_DELAY
|
||||
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
logger.info(
|
||||
"Connected to MCP server: %s %s",
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
"Connected to MCP server on attempt %d: %s %s",
|
||||
attempt + 1,
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Connected to MCP server: %s %s",
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
yield session
|
||||
return # Success, exit retry loop
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
yield session
|
||||
await asyncio.sleep(delay)
|
||||
delay *= RETRY_BACKOFF # Exponential backoff
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to connect to MCP server after %d attempts: %s",
|
||||
max_retries,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.session = None
|
||||
logger.info("Disconnected from MCP server: %s", self.command)
|
||||
# All retries exhausted
|
||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg) from last_error
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]]:
|
||||
"""List all tools available from the MCP server.
|
||||
|
|
|
|||
|
|
@ -90,16 +90,22 @@ async def _create_mcp_tool_from_definition(
|
|||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client."""
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
try:
|
||||
# Connect to server and call tool
|
||||
# Connect to server and call tool (connect has built-in retry logic)
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
# Connection failures after all retries
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
except Exception as e:
|
||||
error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
|
||||
# Tool execution or other errors
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
|
|
@ -146,21 +152,38 @@ async def load_mcp_tools(
|
|||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
try:
|
||||
# Extract single server config
|
||||
# Early validation: Extract and validate connector config
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
|
||||
if not server_config:
|
||||
logger.warning(f"MCP connector {connector.id} missing server_config, skipping")
|
||||
# Validate server_config exists and is a dict
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate required command field
|
||||
command = server_config.get("command")
|
||||
args = server_config.get("args", [])
|
||||
env = server_config.get("env", {})
|
||||
|
||||
if not command:
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} missing command, skipping"
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate args field (must be list if present)
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate env field (must be dict if present)
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue