add error handling to mcp_tool python files

This commit is contained in:
Manoj Aggarwal 2026-01-17 09:18:46 -08:00
parent 18ce599c81
commit f78c2a685e
2 changed files with 110 additions and 39 deletions

View file

@ -4,6 +4,7 @@ This module provides a client for communicating with MCP servers via stdio trans
It handles server lifecycle management, tool discovery, and tool execution. It handles server lifecycle management, tool discovery, and tool execution.
""" """
import asyncio
import logging import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -14,6 +15,11 @@ from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Retry configuration for MCPClient.connect()
# Default for connect()'s max_retries; counts total connection attempts, including the first
MAX_RETRIES = 3
# Initial wait before the first retry
RETRY_DELAY = 1.0 # seconds
# The wait is multiplied by this factor after each failed attempt
RETRY_BACKOFF = 2.0 # exponential backoff multiplier
class MCPClient: class MCPClient:
"""Client for communicating with an MCP server.""" """Client for communicating with an MCP server."""
@ -35,44 +41,86 @@ class MCPClient:
self.session: ClientSession | None = None self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self, max_retries: int = MAX_RETRIES):
    """Connect to the MCP server and manage its lifecycle.

    Spawns the server process over stdio, opens a ``ClientSession`` on its
    streams and initializes it. Only failures that happen *before* the
    session is handed to the caller are retried (with exponential backoff).
    Exceptions raised inside the caller's ``async with`` body, or while
    tearing the connection down, propagate unchanged and never trigger a
    reconnect: an ``asynccontextmanager`` generator must yield exactly once.

    Args:
        max_retries: Maximum number of connection attempts (the first try
            counts as one). Values below 1 are treated as 1 so at least
            one attempt is always made.

    Yields:
        ClientSession: Active MCP session for making requests

    Raises:
        RuntimeError: If all connection attempts fail. The last underlying
            error is attached as ``__cause__``.
    """
    attempts = max(1, max_retries)
    delay = RETRY_DELAY
    last_error: Exception | None = None

    # Merge env vars with current environment (same for every attempt)
    server_env = os.environ.copy()
    server_env.update(self.env)

    for attempt in range(attempts):
        # Flips to True once the session is initialized; from then on any
        # exception belongs to the caller's body or teardown, not to the
        # connection phase.
        connected = False
        try:
            # Create server parameters with env
            server_params = StdioServerParameters(
                command=self.command, args=self.args, env=server_env
            )

            # Spawn server process and create session
            # Note: Cannot combine these context managers because ClientSession
            # needs the read/write streams from stdio_client
            async with stdio_client(server=server_params) as (read, write):  # noqa: SIM117
                async with ClientSession(read, write) as session:
                    # Initialize the connection
                    await session.initialize()
                    self.session = session
                    connected = True

                    if attempt > 0:
                        logger.info(
                            "Connected to MCP server on attempt %d: %s %s",
                            attempt + 1,
                            self.command,
                            " ".join(self.args),
                        )
                    else:
                        logger.info(
                            "Connected to MCP server: %s %s",
                            self.command,
                            " ".join(self.args),
                        )
                    yield session
                    return  # Success, exit retry loop

        except Exception as e:
            if connected:
                # Raised by the caller's body or during teardown. Retrying
                # here would yield a second time, which contextlib rejects,
                # and would mask the real error.
                raise
            last_error = e
            if attempt < attempts - 1:
                logger.warning(
                    "MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
                    attempt + 1,
                    attempts,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
                delay *= RETRY_BACKOFF  # Exponential backoff
        finally:
            self.session = None

    # All retries exhausted: log once, with the traceback of the last failure
    error_msg = (
        f"Failed to connect to MCP server '{self.command}' "
        f"after {attempts} attempts: {last_error}"
    )
    logger.error(error_msg, exc_info=last_error)
    raise RuntimeError(error_msg) from last_error
async def list_tools(self) -> list[dict[str, Any]]: async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server. """List all tools available from the MCP server.

View file

@ -90,16 +90,22 @@ async def _create_mcp_tool_from_definition(
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
    """Execute the MCP tool call via the client with retry support.

    Opens a connection through ``mcp_client.connect()`` (which retries
    connection failures itself), invokes ``tool_name`` with ``kwargs`` and
    returns the result as a string. Errors are never raised to the caller;
    they are logged and returned as an ``"Error: ..."`` string.

    A ``RuntimeError`` is reported as a connection failure only when it was
    raised before the connection was established. A ``RuntimeError`` coming
    from the tool call itself is reported as an execution failure.
    """
    logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")

    # Tracks whether connect() handed us a session, so the handlers below
    # can tell connection failures apart from tool execution failures.
    connected = False
    try:
        # Connect to server and call tool (connect has built-in retry logic)
        async with mcp_client.connect():
            connected = True
            result = await mcp_client.call_tool(tool_name, kwargs)
            return str(result)
    except RuntimeError as e:
        if connected:
            # Raised by the tool call or teardown, not by the connection phase
            error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
            logger.exception(error_msg)
        else:
            # Connection failures after all retries
            error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
            logger.error(error_msg)
        return f"Error: {error_msg}"
    except Exception as e:
        # Tool execution or other errors
        error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
        logger.exception(error_msg)
        return f"Error: {error_msg}"
@ -146,21 +152,38 @@ async def load_mcp_tools(
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
for connector in result.scalars(): for connector in result.scalars():
try: try:
# Extract single server config # Early validation: Extract and validate connector config
config = connector.config or {} config = connector.config or {}
server_config = config.get("server_config", {}) server_config = config.get("server_config", {})
if not server_config: # Validate server_config exists and is a dict
logger.warning(f"MCP connector {connector.id} missing server_config, skipping") if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
)
continue continue
# Validate required command field
command = server_config.get("command") command = server_config.get("command")
args = server_config.get("args", []) if not command or not isinstance(command, str):
env = server_config.get("env", {})
if not command:
logger.warning( logger.warning(
f"MCP connector {connector.id} missing command, skipping" f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping"
)
continue
# Validate args field (must be list if present)
args = server_config.get("args", [])
if not isinstance(args, list):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping"
)
continue
# Validate env field (must be dict if present)
env = server_config.get("env", {})
if not isinstance(env, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping"
) )
continue continue