Add error handling to mcp_tool Python files

This commit is contained in:
Manoj Aggarwal 2026-01-17 09:18:46 -08:00
parent 18ce599c81
commit f78c2a685e
2 changed files with 110 additions and 39 deletions

View file

@ -4,6 +4,7 @@ This module provides a client for communicating with MCP servers via stdio trans
It handles server lifecycle management, tool discovery, and tool execution. It handles server lifecycle management, tool discovery, and tool execution.
""" """
import asyncio
import logging import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -14,6 +15,11 @@ from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Retry configuration used by MCPClient.connect()
MAX_RETRIES = 3 # default for connect(max_retries=...); total number of connection attempts
RETRY_DELAY = 1.0 # seconds to wait before the first retry
RETRY_BACKOFF = 2.0 # exponential backoff multiplier; the delay is multiplied by this after each failed attempt
class MCPClient: class MCPClient:
"""Client for communicating with an MCP server.""" """Client for communicating with an MCP server."""
@ -35,13 +41,23 @@ class MCPClient:
self.session: ClientSession | None = None self.session: ClientSession | None = None
@asynccontextmanager @asynccontextmanager
async def connect(self): async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle. """Connect to the MCP server and manage its lifecycle.
Args:
max_retries: Maximum number of connection retry attempts
Yields: Yields:
ClientSession: Active MCP session for making requests ClientSession: Active MCP session for making requests
Raises:
RuntimeError: If all connection attempts fail
""" """
last_error = None
delay = RETRY_DELAY
for attempt in range(max_retries):
try: try:
# Merge env vars with current environment # Merge env vars with current environment
server_env = os.environ.copy() server_env = os.environ.copy()
@ -60,19 +76,51 @@ class MCPClient:
# Initialize the connection # Initialize the connection
await session.initialize() await session.initialize()
self.session = session self.session = session
if attempt > 0:
logger.info(
"Connected to MCP server on attempt %d: %s %s",
attempt + 1,
self.command,
" ".join(self.args),
)
else:
logger.info( logger.info(
"Connected to MCP server: %s %s", "Connected to MCP server: %s %s",
self.command, self.command,
" ".join(self.args), " ".join(self.args),
) )
yield session yield session
return # Success, exit retry loop
except Exception as e: except Exception as e:
logger.error("Failed to connect to MCP server: %s", e, exc_info=True) last_error = e
raise if attempt < max_retries - 1:
logger.warning(
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
attempt + 1,
max_retries,
e,
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF # Exponential backoff
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
max_retries,
e,
exc_info=True,
)
finally: finally:
self.session = None self.session = None
logger.info("Disconnected from MCP server: %s", self.command)
# All retries exhausted
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
logger.error(error_msg)
raise RuntimeError(error_msg) from last_error
async def list_tools(self) -> list[dict[str, Any]]: async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server. """List all tools available from the MCP server.

View file

@ -90,16 +90,22 @@ async def _create_mcp_tool_from_definition(
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
    """Execute the MCP tool call via the client and return the result as text.

    Failures are not raised to the caller; they are logged and reported
    as an ``"Error: ..."`` string instead.

    Args:
        **kwargs: Tool arguments, forwarded unchanged to
            ``mcp_client.call_tool``.

    Returns:
        ``str(result)`` on success, otherwise an ``"Error: ..."`` message
        describing whether the connection or the tool execution failed.
    """
    logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")

    # Flipped once the connection context has been entered, so that a
    # RuntimeError raised afterwards (by the tool call itself) is not
    # misreported as a connection failure.
    connected = False
    try:
        # Connect to server and call tool (connect has built-in retry logic)
        async with mcp_client.connect():
            connected = True
            result = await mcp_client.call_tool(tool_name, kwargs)
            return str(result)
    except RuntimeError as e:
        if connected:
            # The session was established; the failure came from the tool call.
            error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
            logger.exception(error_msg)
        else:
            # Connection failures after all retries
            error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
            logger.error(error_msg)
        return f"Error: {error_msg}"
    except Exception as e:
        # Tool execution or other errors
        error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
        logger.exception(error_msg)
        return f"Error: {error_msg}"
@ -146,21 +152,38 @@ async def load_mcp_tools(
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
for connector in result.scalars(): for connector in result.scalars():
try: try:
# Extract single server config # Early validation: Extract and validate connector config
config = connector.config or {} config = connector.config or {}
server_config = config.get("server_config", {}) server_config = config.get("server_config", {})
if not server_config: # Validate server_config exists and is a dict
logger.warning(f"MCP connector {connector.id} missing server_config, skipping") if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
)
continue continue
# Validate required command field
command = server_config.get("command") command = server_config.get("command")
args = server_config.get("args", []) if not command or not isinstance(command, str):
env = server_config.get("env", {})
if not command:
logger.warning( logger.warning(
f"MCP connector {connector.id} missing command, skipping" f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping"
)
continue
# Validate args field (must be list if present)
args = server_config.get("args", [])
if not isinstance(args, list):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping"
)
continue
# Validate env field (must be dict if present)
env = server_config.get("env", {})
if not isinstance(env, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping"
) )
continue continue