mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-08 15:22:39 +02:00
add error handling to mcp_tool python files
This commit is contained in:
parent
18ce599c81
commit
f78c2a685e
2 changed files with 110 additions and 39 deletions
|
|
@ -4,6 +4,7 @@ This module provides a client for communicating with MCP servers via stdio trans
|
||||||
It handles server lifecycle management, tool discovery, and tool execution.
|
It handles server lifecycle management, tool discovery, and tool execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
@ -14,6 +15,11 @@ from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
MAX_RETRIES = 3
|
||||||
|
RETRY_DELAY = 1.0 # seconds
|
||||||
|
RETRY_BACKOFF = 2.0 # exponential backoff multiplier
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
class MCPClient:
|
||||||
"""Client for communicating with an MCP server."""
|
"""Client for communicating with an MCP server."""
|
||||||
|
|
@ -35,44 +41,86 @@ class MCPClient:
|
||||||
self.session: ClientSession | None = None
|
self.session: ClientSession | None = None
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def connect(self):
|
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||||
"""Connect to the MCP server and manage its lifecycle.
|
"""Connect to the MCP server and manage its lifecycle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_retries: Maximum number of connection retry attempts
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
ClientSession: Active MCP session for making requests
|
ClientSession: Active MCP session for making requests
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If all connection attempts fail
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
last_error = None
|
||||||
# Merge env vars with current environment
|
delay = RETRY_DELAY
|
||||||
server_env = os.environ.copy()
|
|
||||||
server_env.update(self.env)
|
|
||||||
|
|
||||||
# Create server parameters with env
|
for attempt in range(max_retries):
|
||||||
server_params = StdioServerParameters(
|
try:
|
||||||
command=self.command, args=self.args, env=server_env
|
# Merge env vars with current environment
|
||||||
)
|
server_env = os.environ.copy()
|
||||||
|
server_env.update(self.env)
|
||||||
|
|
||||||
# Spawn server process and create session
|
# Create server parameters with env
|
||||||
# Note: Cannot combine these context managers because ClientSession
|
server_params = StdioServerParameters(
|
||||||
# needs the read/write streams from stdio_client
|
command=self.command, args=self.args, env=server_env
|
||||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
)
|
||||||
async with ClientSession(read, write) as session:
|
|
||||||
# Initialize the connection
|
# Spawn server process and create session
|
||||||
await session.initialize()
|
# Note: Cannot combine these context managers because ClientSession
|
||||||
self.session = session
|
# needs the read/write streams from stdio_client
|
||||||
logger.info(
|
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||||
"Connected to MCP server: %s %s",
|
async with ClientSession(read, write) as session:
|
||||||
self.command,
|
# Initialize the connection
|
||||||
" ".join(self.args),
|
await session.initialize()
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
if attempt > 0:
|
||||||
|
logger.info(
|
||||||
|
"Connected to MCP server on attempt %d: %s %s",
|
||||||
|
attempt + 1,
|
||||||
|
self.command,
|
||||||
|
" ".join(self.args),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Connected to MCP server: %s %s",
|
||||||
|
self.command,
|
||||||
|
" ".join(self.args),
|
||||||
|
)
|
||||||
|
yield session
|
||||||
|
return # Success, exit retry loop
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning(
|
||||||
|
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||||
|
attempt + 1,
|
||||||
|
max_retries,
|
||||||
|
e,
|
||||||
|
delay,
|
||||||
)
|
)
|
||||||
yield session
|
await asyncio.sleep(delay)
|
||||||
|
delay *= RETRY_BACKOFF # Exponential backoff
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Failed to connect to MCP server after %d attempts: %s",
|
||||||
|
max_retries,
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self.session = None
|
||||||
|
|
||||||
except Exception as e:
|
# All retries exhausted
|
||||||
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
|
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||||
raise
|
if last_error:
|
||||||
finally:
|
error_msg += f": {last_error}"
|
||||||
self.session = None
|
logger.error(error_msg)
|
||||||
logger.info("Disconnected from MCP server: %s", self.command)
|
raise RuntimeError(error_msg) from last_error
|
||||||
|
|
||||||
async def list_tools(self) -> list[dict[str, Any]]:
|
async def list_tools(self) -> list[dict[str, Any]]:
|
||||||
"""List all tools available from the MCP server.
|
"""List all tools available from the MCP server.
|
||||||
|
|
|
||||||
|
|
@ -90,16 +90,22 @@ async def _create_mcp_tool_from_definition(
|
||||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||||
|
|
||||||
async def mcp_tool_call(**kwargs) -> str:
|
async def mcp_tool_call(**kwargs) -> str:
|
||||||
"""Execute the MCP tool call via the client."""
|
"""Execute the MCP tool call via the client with retry support."""
|
||||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Connect to server and call tool
|
# Connect to server and call tool (connect has built-in retry logic)
|
||||||
async with mcp_client.connect():
|
async with mcp_client.connect():
|
||||||
result = await mcp_client.call_tool(tool_name, kwargs)
|
result = await mcp_client.call_tool(tool_name, kwargs)
|
||||||
return str(result)
|
return str(result)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# Connection failures after all retries
|
||||||
|
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return f"Error: {error_msg}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
|
# Tool execution or other errors
|
||||||
|
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||||
logger.exception(error_msg)
|
logger.exception(error_msg)
|
||||||
return f"Error: {error_msg}"
|
return f"Error: {error_msg}"
|
||||||
|
|
||||||
|
|
@ -146,21 +152,38 @@ async def load_mcp_tools(
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
for connector in result.scalars():
|
for connector in result.scalars():
|
||||||
try:
|
try:
|
||||||
# Extract single server config
|
# Early validation: Extract and validate connector config
|
||||||
config = connector.config or {}
|
config = connector.config or {}
|
||||||
server_config = config.get("server_config", {})
|
server_config = config.get("server_config", {})
|
||||||
|
|
||||||
if not server_config:
|
# Validate server_config exists and is a dict
|
||||||
logger.warning(f"MCP connector {connector.id} missing server_config, skipping")
|
if not server_config or not isinstance(server_config, dict):
|
||||||
|
logger.warning(
|
||||||
|
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Validate required command field
|
||||||
command = server_config.get("command")
|
command = server_config.get("command")
|
||||||
args = server_config.get("args", [])
|
if not command or not isinstance(command, str):
|
||||||
env = server_config.get("env", {})
|
|
||||||
|
|
||||||
if not command:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"MCP connector {connector.id} missing command, skipping"
|
f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate args field (must be list if present)
|
||||||
|
args = server_config.get("args", [])
|
||||||
|
if not isinstance(args, list):
|
||||||
|
logger.warning(
|
||||||
|
f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate env field (must be dict if present)
|
||||||
|
env = server_config.get("env", {})
|
||||||
|
if not isinstance(env, dict):
|
||||||
|
logger.warning(
|
||||||
|
f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue