fix: robust generic MCP tool routing, retry, and empty-schema handling

This commit is contained in:
CREDO23 2026-04-23 11:30:58 +02:00
parent 1712f454f8
commit 45b72de481
4 changed files with 191 additions and 42 deletions

View file

@ -45,6 +45,18 @@ class MCPClient:
async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle.
Retries only apply to the **connection** phase (spawning the process,
initialising the session). Once the session is yielded to the caller,
any exception raised by the caller propagates normally -- the context
manager will NOT retry after ``yield``.
Previous implementation wrapped both connection AND yield inside the
retry loop. Because ``@asynccontextmanager`` only allows a single
``yield``, a failure after yield caused the generator to attempt a
second yield on retry, triggering
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
the stdio subprocess.
Args:
max_retries: Maximum number of connection retry attempts
@ -57,26 +69,22 @@ class MCPClient:
"""
last_error = None
delay = RETRY_DELAY
connected = False
for attempt in range(max_retries):
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
connected = True
if attempt > 0:
logger.info(
@ -91,10 +99,16 @@ class MCPClient:
self.command,
" ".join(self.args),
)
yield session
return # Success, exit retry loop
try:
yield session
finally:
self.session = None
return
except Exception as e:
self.session = None
if connected:
raise
last_error = e
if attempt < max_retries - 1:
logger.warning(
@ -105,7 +119,7 @@ class MCPClient:
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF # Exponential backoff
delay *= RETRY_BACKOFF
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
@ -113,10 +127,7 @@ class MCPClient:
e,
exc_info=True,
)
finally:
self.session = None
# All retries exhausted
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
@ -161,12 +172,18 @@ class MCPClient:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
async def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
timeout: float = 60.0,
) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
timeout: Maximum seconds to wait for the tool to respond
Returns:
Tool execution result
@ -185,10 +202,11 @@ class MCPClient:
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
response = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments=arguments),
timeout=timeout,
)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
@ -202,15 +220,17 @@ class MCPClient:
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except asyncio.TimeoutError:
logger.error(
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
)
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:

View file

@ -28,7 +28,7 @@ if TYPE_CHECKING:
from langchain_core.tools import StructuredTool
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_TOOL_CALL_MAX_RETRIES = 3
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
@ -64,7 +66,18 @@ def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema."""
"""Create a Pydantic model from MCP tool's JSON schema.
Models always allow extra fields (``extra="allow"``) so that parameters
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
producing an empty ``$schema``-only object) can still be forwarded to the
MCP server.
When the schema declares **no** properties, a synthetic ``input_data``
field of type ``dict`` is injected so the LLM has a visible parameter to
populate. The caller should unpack ``input_data`` before forwarding to
the MCP server (see ``_unpack_synthetic_input_data``).
"""
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
@ -84,8 +97,35 @@ def _create_dynamic_input_model_from_schema(
Field(None, description=param_description),
)
if not properties:
field_definitions["input_data"] = (
dict[str, Any] | None,
Field(
None,
description=(
"Arguments to pass to this tool as a JSON object. "
"Infer sensible key names from the tool name and description "
"(e.g. {\"search\": \"my query\"} for a search tool)."
),
),
)
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
return create_model(model_name, **field_definitions)
model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions)
return model
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
adds a catch-all ``input_data: dict`` field. This helper merges that dict
back into the top-level kwargs so the MCP server receives flat arguments.
"""
input_data = kwargs.pop("input_data", None)
if isinstance(input_data, dict):
kwargs.update(input_data)
return kwargs
async def _create_mcp_tool_from_definition_stdio(
@ -103,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio(
``GraphInterrupt`` propagates cleanly to LangGraph.
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
raw_description = tool_def.get("description", "No description provided")
tool_description = (
f"[MCP server: {connector_name}] {raw_description}"
if connector_name
else raw_description
)
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
@ -121,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio(
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
@ -129,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except RuntimeError as e:
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
except Exception as e:
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
last_error: Exception | None = None
for attempt in range(_TOOL_CALL_MAX_RETRIES):
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except Exception as e:
last_error = e
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt)
logger.warning(
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay,
)
await asyncio.sleep(delay)
else:
logger.error(
"MCP tool '%s' failed after %d attempts: %s",
tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True,
)
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
tool = StructuredTool(
name=tool_name,
@ -150,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio(
metadata={
"mcp_input_schema": input_schema,
"mcp_transport": "stdio",
"mcp_connector_name": connector_name or None,
"mcp_is_generic": True,
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
},
@ -169,6 +230,7 @@ async def _create_mcp_tool_from_definition_http(
trusted_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
@ -180,7 +242,7 @@ async def _create_mcp_tool_from_definition_http(
but the actual MCP ``call_tool`` still uses the original name.
"""
original_tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
raw_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
@ -190,7 +252,11 @@ async def _create_mcp_tool_from_definition_http(
else original_tool_name
)
if tool_name_prefix:
tool_description = f"[Account: {connector_name}] {tool_description}"
tool_description = f"[Account: {connector_name}] {raw_description}"
elif is_generic_mcp and connector_name:
tool_description = f"[MCP server: {connector_name}] {raw_description}"
else:
tool_description = raw_description
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
@ -199,6 +265,7 @@ async def _create_mcp_tool_from_definition_http(
async def _do_mcp_call(
call_headers: dict[str, str],
call_kwargs: dict[str, Any],
timeout: float = 60.0,
) -> str:
"""Execute a single MCP HTTP call with the given headers."""
async with (
@ -206,8 +273,9 @@ async def _create_mcp_tool_from_definition_http(
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
response = await asyncio.wait_for(
session.call_tool(original_tool_name, arguments=call_kwargs),
timeout=timeout,
)
result = []
@ -226,7 +294,9 @@ async def _create_mcp_tool_from_definition_http(
logger.debug("MCP HTTP tool '%s' called", exposed_name)
if is_readonly:
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
else:
hitl_result = request_approval(
action_type="mcp_tool_call",
@ -234,7 +304,7 @@ async def _create_mcp_tool_from_definition_http(
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"tool_description": raw_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
@ -242,7 +312,9 @@ async def _create_mcp_tool_from_definition_http(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
try:
result_str = await _do_mcp_call(headers, call_kwargs)
@ -295,6 +367,8 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema,
"mcp_transport": "http",
"mcp_url": url,
"mcp_connector_name": connector_name or None,
"mcp_is_generic": is_generic_mcp,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
"mcp_original_tool_name": original_tool_name,
@ -376,6 +450,7 @@ async def _load_http_mcp_tools(
allowed_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -492,6 +567,7 @@ async def _load_http_mcp_tools(
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
is_generic_mcp=is_generic_mcp,
)
tools.append(tool)
except Exception as e:
@ -928,6 +1004,7 @@ async def load_mcp_tools(
"readonly_tools": readonly_tools,
"tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None,
})
except Exception as e:
@ -948,6 +1025,7 @@ async def load_mcp_tools(
allowed_tools=task["allowed_tools"],
readonly_tools=task["readonly_tools"],
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)