fix: robust generic MCP tool routing, retry, and empty-schema handling

This commit is contained in:
CREDO23 2026-04-23 11:30:58 +02:00
parent 1712f454f8
commit 45b72de481
4 changed files with 191 additions and 42 deletions

View file

@ -314,6 +314,20 @@ async def create_surfsense_deep_agent(
_t0 = time.perf_counter()
_enabled_tool_names = {t.name for t in tools}
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
# Collect generic MCP connector info so the system prompt can route queries
# to their tools instead of falling back to "not in knowledge base".
_mcp_connector_tools: dict[str, list[str]] = {}
for t in tools:
meta = getattr(t, "metadata", None) or {}
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
_mcp_connector_tools.setdefault(
meta["mcp_connector_name"], [],
).append(t.name)
if _mcp_connector_tools:
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
if agent_config is not None:
system_prompt = build_configurable_system_prompt(
custom_system_instructions=agent_config.system_instructions,
@ -322,12 +336,14 @@ async def create_surfsense_deep_agent(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0

View file

@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
"""
def _build_mcp_routing_block(
mcp_connector_tools: dict[str, list[str]] | None,
) -> str:
"""Build an additional tool routing block for generic MCP connectors.
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
those tools exist and should be called directly not searched in the
knowledge base.
"""
if not mcp_connector_tools:
return ""
lines = [
"\n<mcp_tool_routing>",
"You also have direct tools from these user-connected MCP servers.",
"Their data is NEVER in the knowledge base — call their tools directly.",
"",
]
for server_name, tool_names in mcp_connector_tools.items():
lines.append(f"- {server_name}{', '.join(tool_names)}")
lines.append("</mcp_tool_routing>\n")
return "\n".join(lines)
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
mcp_connector_tools: dict[str, list[str]] | None = None,
) -> str:
"""
Build the SurfSense system prompt with default settings.
@ -834,6 +859,9 @@ def build_surfsense_system_prompt(
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
mcp_connector_tools: Mapping of MCP server display name to list of tool names
for generic MCP connectors. Injected into the system prompt so the LLM
knows to call these tools directly.
Returns:
Complete system prompt string
@ -841,6 +869,7 @@ def build_surfsense_system_prompt(
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
tools_instructions = _get_tools_instructions(
visibility, enabled_tool_names, disabled_tool_names
)
@ -856,6 +885,7 @@ def build_configurable_system_prompt(
thread_visibility: ChatVisibility | None = None,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
mcp_connector_tools: dict[str, list[str]] | None = None,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
@ -877,6 +907,9 @@ def build_configurable_system_prompt(
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
mcp_connector_tools: Mapping of MCP server display name to list of tool names
for generic MCP connectors. Injected into the system prompt so the LLM
knows to call these tools directly.
Returns:
Complete system prompt string
@ -894,6 +927,8 @@ def build_configurable_system_prompt(
else:
system_instructions = ""
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
# Tools instructions: only include enabled tools, note disabled ones
tools_instructions = _get_tools_instructions(
thread_visibility, enabled_tool_names, disabled_tool_names

View file

@ -45,6 +45,18 @@ class MCPClient:
async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle.
Retries only apply to the **connection** phase (spawning the process,
initialising the session). Once the session is yielded to the caller,
any exception raised by the caller propagates normally -- the context
manager will NOT retry after ``yield``.
Previous implementation wrapped both connection AND yield inside the
retry loop. Because ``@asynccontextmanager`` only allows a single
``yield``, a failure after yield caused the generator to attempt a
second yield on retry, triggering
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
the stdio subprocess.
Args:
max_retries: Maximum number of connection retry attempts
@ -57,26 +69,22 @@ class MCPClient:
"""
last_error = None
delay = RETRY_DELAY
connected = False
for attempt in range(max_retries):
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
connected = True
if attempt > 0:
logger.info(
@ -91,10 +99,16 @@ class MCPClient:
self.command,
" ".join(self.args),
)
yield session
return # Success, exit retry loop
try:
yield session
finally:
self.session = None
return
except Exception as e:
self.session = None
if connected:
raise
last_error = e
if attempt < max_retries - 1:
logger.warning(
@ -105,7 +119,7 @@ class MCPClient:
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF # Exponential backoff
delay *= RETRY_BACKOFF
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
@ -113,10 +127,7 @@ class MCPClient:
e,
exc_info=True,
)
finally:
self.session = None
# All retries exhausted
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
@ -161,12 +172,18 @@ class MCPClient:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
async def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
timeout: float = 60.0,
) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
timeout: Maximum seconds to wait for the tool to respond
Returns:
Tool execution result
@ -185,10 +202,11 @@ class MCPClient:
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
response = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments=arguments),
timeout=timeout,
)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
@ -202,15 +220,17 @@ class MCPClient:
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except asyncio.TimeoutError:
logger.error(
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
)
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:

View file

@ -28,7 +28,7 @@ if TYPE_CHECKING:
from langchain_core.tools import StructuredTool
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_TOOL_CALL_MAX_RETRIES = 3
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
@ -64,7 +66,18 @@ def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema."""
"""Create a Pydantic model from MCP tool's JSON schema.
Models always allow extra fields (``extra="allow"``) so that parameters
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
producing an empty ``$schema``-only object) can still be forwarded to the
MCP server.
When the schema declares **no** properties, a synthetic ``input_data``
field of type ``dict`` is injected so the LLM has a visible parameter to
populate. The caller should unpack ``input_data`` before forwarding to
the MCP server (see ``_unpack_synthetic_input_data``).
"""
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
@ -84,8 +97,35 @@ def _create_dynamic_input_model_from_schema(
Field(None, description=param_description),
)
if not properties:
field_definitions["input_data"] = (
dict[str, Any] | None,
Field(
None,
description=(
"Arguments to pass to this tool as a JSON object. "
"Infer sensible key names from the tool name and description "
"(e.g. {\"search\": \"my query\"} for a search tool)."
),
),
)
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
return create_model(model_name, **field_definitions)
model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions)
return model
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
adds a catch-all ``input_data: dict`` field. This helper merges that dict
back into the top-level kwargs so the MCP server receives flat arguments.
"""
input_data = kwargs.pop("input_data", None)
if isinstance(input_data, dict):
kwargs.update(input_data)
return kwargs
async def _create_mcp_tool_from_definition_stdio(
@ -103,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio(
``GraphInterrupt`` propagates cleanly to LangGraph.
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
raw_description = tool_def.get("description", "No description provided")
tool_description = (
f"[MCP server: {connector_name}] {raw_description}"
if connector_name
else raw_description
)
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
@ -121,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio(
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
@ -129,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except RuntimeError as e:
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
except Exception as e:
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
last_error: Exception | None = None
for attempt in range(_TOOL_CALL_MAX_RETRIES):
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except Exception as e:
last_error = e
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt)
logger.warning(
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay,
)
await asyncio.sleep(delay)
else:
logger.error(
"MCP tool '%s' failed after %d attempts: %s",
tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True,
)
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
tool = StructuredTool(
name=tool_name,
@ -150,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio(
metadata={
"mcp_input_schema": input_schema,
"mcp_transport": "stdio",
"mcp_connector_name": connector_name or None,
"mcp_is_generic": True,
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
},
@ -169,6 +230,7 @@ async def _create_mcp_tool_from_definition_http(
trusted_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
@ -180,7 +242,7 @@ async def _create_mcp_tool_from_definition_http(
but the actual MCP ``call_tool`` still uses the original name.
"""
original_tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
raw_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
@ -190,7 +252,11 @@ async def _create_mcp_tool_from_definition_http(
else original_tool_name
)
if tool_name_prefix:
tool_description = f"[Account: {connector_name}] {tool_description}"
tool_description = f"[Account: {connector_name}] {raw_description}"
elif is_generic_mcp and connector_name:
tool_description = f"[MCP server: {connector_name}] {raw_description}"
else:
tool_description = raw_description
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
@ -199,6 +265,7 @@ async def _create_mcp_tool_from_definition_http(
async def _do_mcp_call(
call_headers: dict[str, str],
call_kwargs: dict[str, Any],
timeout: float = 60.0,
) -> str:
"""Execute a single MCP HTTP call with the given headers."""
async with (
@ -206,8 +273,9 @@ async def _create_mcp_tool_from_definition_http(
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
response = await asyncio.wait_for(
session.call_tool(original_tool_name, arguments=call_kwargs),
timeout=timeout,
)
result = []
@ -226,7 +294,9 @@ async def _create_mcp_tool_from_definition_http(
logger.debug("MCP HTTP tool '%s' called", exposed_name)
if is_readonly:
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
else:
hitl_result = request_approval(
action_type="mcp_tool_call",
@ -234,7 +304,7 @@ async def _create_mcp_tool_from_definition_http(
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"tool_description": raw_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
@ -242,7 +312,9 @@ async def _create_mcp_tool_from_definition_http(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
try:
result_str = await _do_mcp_call(headers, call_kwargs)
@ -295,6 +367,8 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema,
"mcp_transport": "http",
"mcp_url": url,
"mcp_connector_name": connector_name or None,
"mcp_is_generic": is_generic_mcp,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
"mcp_original_tool_name": original_tool_name,
@ -376,6 +450,7 @@ async def _load_http_mcp_tools(
allowed_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -492,6 +567,7 @@ async def _load_http_mcp_tools(
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
is_generic_mcp=is_generic_mcp,
)
tools.append(tool)
except Exception as e:
@ -928,6 +1004,7 @@ async def load_mcp_tools(
"readonly_tools": readonly_tools,
"tool_name_prefix": tool_name_prefix,
"transport": server_config.get("transport", "stdio"),
"is_generic_mcp": svc_cfg is None,
})
except Exception as e:
@ -948,6 +1025,7 @@ async def load_mcp_tools(
allowed_tools=task["allowed_tools"],
readonly_tools=task["readonly_tools"],
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)