add error handling to mcp_tool python files

This commit is contained in:
Manoj Aggarwal 2026-01-17 09:18:46 -08:00
parent 18ce599c81
commit f78c2a685e
2 changed files with 110 additions and 39 deletions

View file

@ -4,6 +4,7 @@ This module provides a client for communicating with MCP servers via stdio trans
It handles server lifecycle management, tool discovery, and tool execution.
"""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
@ -14,6 +15,11 @@ from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__)
# Retry configuration used by MCPClient.connect()
MAX_RETRIES = 3  # default for connect()'s max_retries: total connection attempts
RETRY_DELAY = 1.0  # seconds to wait after the first failed attempt
RETRY_BACKOFF = 2.0  # exponential backoff multiplier applied to the delay after each failed attempt
class MCPClient:
"""Client for communicating with an MCP server."""
@ -35,44 +41,86 @@ class MCPClient:
self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self, max_retries: int = MAX_RETRIES):
    """Connect to the MCP server and manage its lifecycle.

    Only the connection phase (spawning the server process, opening the
    session and initializing it) is retried, with exponential backoff.
    Once a session has been yielded, an exception raised inside the
    caller's ``async with`` body is NOT treated as a connection failure:
    the session and server process are torn down and the exception
    propagates unchanged. (An async context manager generator may only
    yield once, so retrying around the ``yield`` is not possible.)

    Args:
        max_retries: Maximum number of connection attempts

    Yields:
        ClientSession: Active MCP session for making requests

    Raises:
        RuntimeError: If all connection attempts fail
    """
    # Function-scope import so this method does not depend on a change to
    # the module's import block.
    from contextlib import AsyncExitStack

    last_error = None
    delay = RETRY_DELAY

    for attempt in range(max_retries):
        try:
            # Merge env vars with current environment
            server_env = os.environ.copy()
            server_env.update(self.env)

            # Create server parameters with env
            server_params = StdioServerParameters(
                command=self.command, args=self.args, env=server_env
            )

            # Spawn server process and create session. The contexts are
            # entered on a per-attempt stack: if anything below fails, the
            # stack unwinds and cleans up the half-open connection; if
            # everything succeeds, pop_all() transfers ownership so the
            # connection stays open after this ``async with`` exits.
            async with AsyncExitStack() as attempt_stack:
                read, write = await attempt_stack.enter_async_context(
                    stdio_client(server=server_params)
                )
                session = await attempt_stack.enter_async_context(
                    ClientSession(read, write)
                )
                # Initialize the connection
                await session.initialize()
                connection = attempt_stack.pop_all()
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                logger.warning(
                    "MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
                    attempt + 1,
                    max_retries,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
                delay *= RETRY_BACKOFF  # Exponential backoff
            else:
                logger.error(
                    "Failed to connect to MCP server after %d attempts: %s",
                    max_retries,
                    e,
                    exc_info=True,
                )
        else:
            break  # Connected; leave the retry loop before yielding
    else:
        # All retries exhausted (also reached when max_retries <= 0)
        error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
        if last_error:
            error_msg += f": {last_error}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from last_error

    try:
        # ``connection`` now owns the stdio transport and the session; it
        # closes them (in reverse order) when the caller's block exits,
        # whether normally or with an exception.
        async with connection:
            self.session = session
            if attempt > 0:
                logger.info(
                    "Connected to MCP server on attempt %d: %s %s",
                    attempt + 1,
                    self.command,
                    " ".join(self.args),
                )
            else:
                logger.info(
                    "Connected to MCP server: %s %s",
                    self.command,
                    " ".join(self.args),
                )
            yield session
    finally:
        self.session = None
        logger.info("Disconnected from MCP server: %s", self.command)
async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server.

View file

@ -90,16 +90,22 @@ async def _create_mcp_tool_from_definition(
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
    """Execute the MCP tool call via the client with retry support.

    Opens a connection through ``mcp_client.connect()``, invokes the tool
    named ``tool_name`` with the keyword arguments received, and returns
    the result converted with ``str()``.

    Failures are never raised to the caller; they are logged and reported
    as a string starting with ``"Error: "``.

    Returns:
        The stringified tool result, or an ``"Error: ..."`` message.
    """
    logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")

    try:
        # Connect to server and call tool (connect has built-in retry logic)
        async with mcp_client.connect():
            result = await mcp_client.call_tool(tool_name, kwargs)
            return str(result)
    except RuntimeError as e:
        # Connection failures after all retries.
        # NOTE(review): a RuntimeError raised by call_tool itself would also
        # land here and be reported as a connection failure.
        error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
        logger.error(error_msg)
        return f"Error: {error_msg}"
    except Exception as e:
        # Tool execution or other errors
        error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
        logger.exception(error_msg)
        return f"Error: {error_msg}"
@ -146,21 +152,38 @@ async def load_mcp_tools(
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
# Extract single server config
# Early validation: Extract and validate connector config
config = connector.config or {}
server_config = config.get("server_config", {})
if not server_config:
logger.warning(f"MCP connector {connector.id} missing server_config, skipping")
# Validate server_config exists and is a dict
if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
)
continue
# Validate required command field
command = server_config.get("command")
args = server_config.get("args", [])
env = server_config.get("env", {})
if not command:
if not command or not isinstance(command, str):
logger.warning(
f"MCP connector {connector.id} missing command, skipping"
f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping"
)
continue
# Validate args field (must be list if present)
args = server_config.get("args", [])
if not isinstance(args, list):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping"
)
continue
# Validate env field (must be dict if present)
env = server_config.get("env", {})
if not isinstance(env, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping"
)
continue