mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
fix: robust generic MCP tool routing, retry, and empty-schema handling
This commit is contained in:
parent
1712f454f8
commit
45b72de481
4 changed files with 191 additions and 42 deletions
|
|
@ -314,6 +314,20 @@ async def create_surfsense_deep_agent(
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
_enabled_tool_names = {t.name for t in tools}
|
_enabled_tool_names = {t.name for t in tools}
|
||||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||||
|
|
||||||
|
# Collect generic MCP connector info so the system prompt can route queries
|
||||||
|
# to their tools instead of falling back to "not in knowledge base".
|
||||||
|
_mcp_connector_tools: dict[str, list[str]] = {}
|
||||||
|
for t in tools:
|
||||||
|
meta = getattr(t, "metadata", None) or {}
|
||||||
|
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
|
||||||
|
_mcp_connector_tools.setdefault(
|
||||||
|
meta["mcp_connector_name"], [],
|
||||||
|
).append(t.name)
|
||||||
|
|
||||||
|
if _mcp_connector_tools:
|
||||||
|
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
|
||||||
|
|
||||||
if agent_config is not None:
|
if agent_config is not None:
|
||||||
system_prompt = build_configurable_system_prompt(
|
system_prompt = build_configurable_system_prompt(
|
||||||
custom_system_instructions=agent_config.system_instructions,
|
custom_system_instructions=agent_config.system_instructions,
|
||||||
|
|
@ -322,12 +336,14 @@ async def create_surfsense_deep_agent(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = build_surfsense_system_prompt(
|
system_prompt = build_surfsense_system_prompt(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
|
||||||
|
|
@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mcp_routing_block(
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build an additional tool routing block for generic MCP connectors.
|
||||||
|
|
||||||
|
When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know
|
||||||
|
those tools exist and should be called directly — not searched in the
|
||||||
|
knowledge base.
|
||||||
|
"""
|
||||||
|
if not mcp_connector_tools:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"\n<mcp_tool_routing>",
|
||||||
|
"You also have direct tools from these user-connected MCP servers.",
|
||||||
|
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
for server_name, tool_names in mcp_connector_tools.items():
|
||||||
|
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||||
|
lines.append("</mcp_tool_routing>\n")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_surfsense_system_prompt(
|
def build_surfsense_system_prompt(
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build the SurfSense system prompt with default settings.
|
Build the SurfSense system prompt with default settings.
|
||||||
|
|
@ -834,6 +859,9 @@ def build_surfsense_system_prompt(
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||||
|
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||||
|
knows to call these tools directly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -841,6 +869,7 @@ def build_surfsense_system_prompt(
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
system_instructions = _get_system_instructions(visibility, today)
|
system_instructions = _get_system_instructions(visibility, today)
|
||||||
|
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||||
tools_instructions = _get_tools_instructions(
|
tools_instructions = _get_tools_instructions(
|
||||||
visibility, enabled_tool_names, disabled_tool_names
|
visibility, enabled_tool_names, disabled_tool_names
|
||||||
)
|
)
|
||||||
|
|
@ -856,6 +885,7 @@ def build_configurable_system_prompt(
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||||
|
|
@ -877,6 +907,9 @@ def build_configurable_system_prompt(
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
mcp_connector_tools: Mapping of MCP server display name → list of tool names
|
||||||
|
for generic MCP connectors. Injected into the system prompt so the LLM
|
||||||
|
knows to call these tools directly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Complete system prompt string
|
Complete system prompt string
|
||||||
|
|
@ -894,6 +927,8 @@ def build_configurable_system_prompt(
|
||||||
else:
|
else:
|
||||||
system_instructions = ""
|
system_instructions = ""
|
||||||
|
|
||||||
|
system_instructions += _build_mcp_routing_block(mcp_connector_tools)
|
||||||
|
|
||||||
# Tools instructions: only include enabled tools, note disabled ones
|
# Tools instructions: only include enabled tools, note disabled ones
|
||||||
tools_instructions = _get_tools_instructions(
|
tools_instructions = _get_tools_instructions(
|
||||||
thread_visibility, enabled_tool_names, disabled_tool_names
|
thread_visibility, enabled_tool_names, disabled_tool_names
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,18 @@ class MCPClient:
|
||||||
async def connect(self, max_retries: int = MAX_RETRIES):
|
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||||
"""Connect to the MCP server and manage its lifecycle.
|
"""Connect to the MCP server and manage its lifecycle.
|
||||||
|
|
||||||
|
Retries only apply to the **connection** phase (spawning the process,
|
||||||
|
initialising the session). Once the session is yielded to the caller,
|
||||||
|
any exception raised by the caller propagates normally -- the context
|
||||||
|
manager will NOT retry after ``yield``.
|
||||||
|
|
||||||
|
Previous implementation wrapped both connection AND yield inside the
|
||||||
|
retry loop. Because ``@asynccontextmanager`` only allows a single
|
||||||
|
``yield``, a failure after yield caused the generator to attempt a
|
||||||
|
second yield on retry, triggering
|
||||||
|
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
|
||||||
|
the stdio subprocess.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_retries: Maximum number of connection retry attempts
|
max_retries: Maximum number of connection retry attempts
|
||||||
|
|
||||||
|
|
@ -57,26 +69,22 @@ class MCPClient:
|
||||||
"""
|
"""
|
||||||
last_error = None
|
last_error = None
|
||||||
delay = RETRY_DELAY
|
delay = RETRY_DELAY
|
||||||
|
connected = False
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
# Merge env vars with current environment
|
|
||||||
server_env = os.environ.copy()
|
server_env = os.environ.copy()
|
||||||
server_env.update(self.env)
|
server_env.update(self.env)
|
||||||
|
|
||||||
# Create server parameters with env
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command=self.command, args=self.args, env=server_env
|
command=self.command, args=self.args, env=server_env
|
||||||
)
|
)
|
||||||
|
|
||||||
# Spawn server process and create session
|
|
||||||
# Note: Cannot combine these context managers because ClientSession
|
|
||||||
# needs the read/write streams from stdio_client
|
|
||||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
self.session = session
|
self.session = session
|
||||||
|
connected = True
|
||||||
|
|
||||||
if attempt > 0:
|
if attempt > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -91,10 +99,16 @@ class MCPClient:
|
||||||
self.command,
|
self.command,
|
||||||
" ".join(self.args),
|
" ".join(self.args),
|
||||||
)
|
)
|
||||||
yield session
|
try:
|
||||||
return # Success, exit retry loop
|
yield session
|
||||||
|
finally:
|
||||||
|
self.session = None
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self.session = None
|
||||||
|
if connected:
|
||||||
|
raise
|
||||||
last_error = e
|
last_error = e
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -105,7 +119,7 @@ class MCPClient:
|
||||||
delay,
|
delay,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
delay *= RETRY_BACKOFF # Exponential backoff
|
delay *= RETRY_BACKOFF
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to connect to MCP server after %d attempts: %s",
|
"Failed to connect to MCP server after %d attempts: %s",
|
||||||
|
|
@ -113,10 +127,7 @@ class MCPClient:
|
||||||
e,
|
e,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
# All retries exhausted
|
|
||||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||||
if last_error:
|
if last_error:
|
||||||
error_msg += f": {last_error}"
|
error_msg += f": {last_error}"
|
||||||
|
|
@ -161,12 +172,18 @@ class MCPClient:
|
||||||
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float = 60.0,
|
||||||
|
) -> Any:
|
||||||
"""Call a tool on the MCP server.
|
"""Call a tool on the MCP server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: Name of the tool to call
|
tool_name: Name of the tool to call
|
||||||
arguments: Arguments to pass to the tool
|
arguments: Arguments to pass to the tool
|
||||||
|
timeout: Maximum seconds to wait for the tool to respond
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tool execution result
|
Tool execution result
|
||||||
|
|
@ -185,10 +202,11 @@ class MCPClient:
|
||||||
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call tools/call RPC method
|
response = await asyncio.wait_for(
|
||||||
response = await self.session.call_tool(tool_name, arguments=arguments)
|
self.session.call_tool(tool_name, arguments=arguments),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
# Extract content from response
|
|
||||||
result = []
|
result = []
|
||||||
for content in response.content:
|
for content in response.content:
|
||||||
if hasattr(content, "text"):
|
if hasattr(content, "text"):
|
||||||
|
|
@ -202,15 +220,17 @@ class MCPClient:
|
||||||
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||||
return result_str
|
return result_str
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"MCP tool '%s' timed out after %.0fs", tool_name, timeout
|
||||||
|
)
|
||||||
|
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Handle validation errors from MCP server responses
|
|
||||||
# Some MCP servers (like server-memory) return extra fields not in their schema
|
|
||||||
if "Invalid structured content" in str(e):
|
if "Invalid structured content" in str(e):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP server returned data not matching its schema, but continuing: %s",
|
"MCP server returned data not matching its schema, but continuing: %s",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
# Try to extract result from error message or return a success message
|
|
||||||
return "Operation completed (server returned unexpected format)"
|
return "Operation completed (server returned unexpected format)"
|
||||||
raise
|
raise
|
||||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||||
from sqlalchemy import cast, select
|
from sqlalchemy import cast, select
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
|
||||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||||
_MCP_CACHE_MAX_SIZE = 50
|
_MCP_CACHE_MAX_SIZE = 50
|
||||||
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
||||||
|
_TOOL_CALL_MAX_RETRIES = 3
|
||||||
|
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
||||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -64,7 +66,18 @@ def _create_dynamic_input_model_from_schema(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
input_schema: dict[str, Any],
|
input_schema: dict[str, Any],
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a Pydantic model from MCP tool's JSON schema."""
|
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||||
|
|
||||||
|
Models always allow extra fields (``extra="allow"``) so that parameters
|
||||||
|
missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema``
|
||||||
|
producing an empty ``$schema``-only object) can still be forwarded to the
|
||||||
|
MCP server.
|
||||||
|
|
||||||
|
When the schema declares **no** properties, a synthetic ``input_data``
|
||||||
|
field of type ``dict`` is injected so the LLM has a visible parameter to
|
||||||
|
populate. The caller should unpack ``input_data`` before forwarding to
|
||||||
|
the MCP server (see ``_unpack_synthetic_input_data``).
|
||||||
|
"""
|
||||||
properties = input_schema.get("properties", {})
|
properties = input_schema.get("properties", {})
|
||||||
required_fields = input_schema.get("required", [])
|
required_fields = input_schema.get("required", [])
|
||||||
|
|
||||||
|
|
@ -84,8 +97,35 @@ def _create_dynamic_input_model_from_schema(
|
||||||
Field(None, description=param_description),
|
Field(None, description=param_description),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not properties:
|
||||||
|
field_definitions["input_data"] = (
|
||||||
|
dict[str, Any] | None,
|
||||||
|
Field(
|
||||||
|
None,
|
||||||
|
description=(
|
||||||
|
"Arguments to pass to this tool as a JSON object. "
|
||||||
|
"Infer sensible key names from the tool name and description "
|
||||||
|
"(e.g. {\"search\": \"my query\"} for a search tool)."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
||||||
return create_model(model_name, **field_definitions)
|
model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Unpack the synthetic ``input_data`` field into top-level kwargs.
|
||||||
|
|
||||||
|
When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema``
|
||||||
|
adds a catch-all ``input_data: dict`` field. This helper merges that dict
|
||||||
|
back into the top-level kwargs so the MCP server receives flat arguments.
|
||||||
|
"""
|
||||||
|
input_data = kwargs.pop("input_data", None)
|
||||||
|
if isinstance(input_data, dict):
|
||||||
|
kwargs.update(input_data)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
async def _create_mcp_tool_from_definition_stdio(
|
async def _create_mcp_tool_from_definition_stdio(
|
||||||
|
|
@ -103,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||||
"""
|
"""
|
||||||
tool_name = tool_def.get("name", "unnamed_tool")
|
tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
tool_description = tool_def.get("description", "No description provided")
|
raw_description = tool_def.get("description", "No description provided")
|
||||||
|
tool_description = (
|
||||||
|
f"[MCP server: {connector_name}] {raw_description}"
|
||||||
|
if connector_name
|
||||||
|
else raw_description
|
||||||
|
)
|
||||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||||
|
|
||||||
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
||||||
|
|
@ -121,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
params=kwargs,
|
params=kwargs,
|
||||||
context={
|
context={
|
||||||
"mcp_server": connector_name,
|
"mcp_server": connector_name,
|
||||||
"tool_description": tool_description,
|
"tool_description": raw_description,
|
||||||
"mcp_transport": "stdio",
|
"mcp_transport": "stdio",
|
||||||
"mcp_connector_id": connector_id,
|
"mcp_connector_id": connector_id,
|
||||||
},
|
},
|
||||||
|
|
@ -129,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
)
|
)
|
||||||
if hitl_result.rejected:
|
if hitl_result.rejected:
|
||||||
return "Tool call rejected by user."
|
return "Tool call rejected by user."
|
||||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
last_error: Exception | None = None
|
||||||
async with mcp_client.connect():
|
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
||||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
try:
|
||||||
return str(result)
|
async with mcp_client.connect():
|
||||||
except RuntimeError as e:
|
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||||
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
|
return str(result)
|
||||||
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
except Exception as e:
|
||||||
except Exception as e:
|
last_error = e
|
||||||
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
|
if attempt < _TOOL_CALL_MAX_RETRIES - 1:
|
||||||
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
|
delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt)
|
||||||
|
logger.warning(
|
||||||
|
"MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||||
|
tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"MCP tool '%s' failed after %d attempts: %s",
|
||||||
|
tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
|
|
@ -150,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
metadata={
|
metadata={
|
||||||
"mcp_input_schema": input_schema,
|
"mcp_input_schema": input_schema,
|
||||||
"mcp_transport": "stdio",
|
"mcp_transport": "stdio",
|
||||||
|
"mcp_connector_name": connector_name or None,
|
||||||
|
"mcp_is_generic": True,
|
||||||
"hitl": True,
|
"hitl": True,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||||
},
|
},
|
||||||
|
|
@ -169,6 +230,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
trusted_tools: list[str] | None = None,
|
trusted_tools: list[str] | None = None,
|
||||||
readonly_tools: frozenset[str] | None = None,
|
readonly_tools: frozenset[str] | None = None,
|
||||||
tool_name_prefix: str | None = None,
|
tool_name_prefix: str | None = None,
|
||||||
|
is_generic_mcp: bool = False,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||||
|
|
||||||
|
|
@ -180,7 +242,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
but the actual MCP ``call_tool`` still uses the original name.
|
but the actual MCP ``call_tool`` still uses the original name.
|
||||||
"""
|
"""
|
||||||
original_tool_name = tool_def.get("name", "unnamed_tool")
|
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
tool_description = tool_def.get("description", "No description provided")
|
raw_description = tool_def.get("description", "No description provided")
|
||||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||||
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||||
|
|
||||||
|
|
@ -190,7 +252,11 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
else original_tool_name
|
else original_tool_name
|
||||||
)
|
)
|
||||||
if tool_name_prefix:
|
if tool_name_prefix:
|
||||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
tool_description = f"[Account: {connector_name}] {raw_description}"
|
||||||
|
elif is_generic_mcp and connector_name:
|
||||||
|
tool_description = f"[MCP server: {connector_name}] {raw_description}"
|
||||||
|
else:
|
||||||
|
tool_description = raw_description
|
||||||
|
|
||||||
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
||||||
|
|
||||||
|
|
@ -199,6 +265,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
async def _do_mcp_call(
|
async def _do_mcp_call(
|
||||||
call_headers: dict[str, str],
|
call_headers: dict[str, str],
|
||||||
call_kwargs: dict[str, Any],
|
call_kwargs: dict[str, Any],
|
||||||
|
timeout: float = 60.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute a single MCP HTTP call with the given headers."""
|
"""Execute a single MCP HTTP call with the given headers."""
|
||||||
async with (
|
async with (
|
||||||
|
|
@ -206,8 +273,9 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
response = await session.call_tool(
|
response = await asyncio.wait_for(
|
||||||
original_tool_name, arguments=call_kwargs,
|
session.call_tool(original_tool_name, arguments=call_kwargs),
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
|
@ -226,7 +294,9 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||||
|
|
||||||
if is_readonly:
|
if is_readonly:
|
||||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hitl_result = request_approval(
|
hitl_result = request_approval(
|
||||||
action_type="mcp_tool_call",
|
action_type="mcp_tool_call",
|
||||||
|
|
@ -234,7 +304,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
params=kwargs,
|
params=kwargs,
|
||||||
context={
|
context={
|
||||||
"mcp_server": connector_name,
|
"mcp_server": connector_name,
|
||||||
"tool_description": tool_description,
|
"tool_description": raw_description,
|
||||||
"mcp_transport": "http",
|
"mcp_transport": "http",
|
||||||
"mcp_connector_id": connector_id,
|
"mcp_connector_id": connector_id,
|
||||||
},
|
},
|
||||||
|
|
@ -242,7 +312,9 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
)
|
)
|
||||||
if hitl_result.rejected:
|
if hitl_result.rejected:
|
||||||
return "Tool call rejected by user."
|
return "Tool call rejected by user."
|
||||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result_str = await _do_mcp_call(headers, call_kwargs)
|
result_str = await _do_mcp_call(headers, call_kwargs)
|
||||||
|
|
@ -295,6 +367,8 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
"mcp_input_schema": input_schema,
|
"mcp_input_schema": input_schema,
|
||||||
"mcp_transport": "http",
|
"mcp_transport": "http",
|
||||||
"mcp_url": url,
|
"mcp_url": url,
|
||||||
|
"mcp_connector_name": connector_name or None,
|
||||||
|
"mcp_is_generic": is_generic_mcp,
|
||||||
"hitl": not is_readonly,
|
"hitl": not is_readonly,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||||
"mcp_original_tool_name": original_tool_name,
|
"mcp_original_tool_name": original_tool_name,
|
||||||
|
|
@ -376,6 +450,7 @@ async def _load_http_mcp_tools(
|
||||||
allowed_tools: list[str] | None = None,
|
allowed_tools: list[str] | None = None,
|
||||||
readonly_tools: frozenset[str] | None = None,
|
readonly_tools: frozenset[str] | None = None,
|
||||||
tool_name_prefix: str | None = None,
|
tool_name_prefix: str | None = None,
|
||||||
|
is_generic_mcp: bool = False,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server.
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
|
@ -492,6 +567,7 @@ async def _load_http_mcp_tools(
|
||||||
trusted_tools=trusted_tools,
|
trusted_tools=trusted_tools,
|
||||||
readonly_tools=readonly_tools,
|
readonly_tools=readonly_tools,
|
||||||
tool_name_prefix=tool_name_prefix,
|
tool_name_prefix=tool_name_prefix,
|
||||||
|
is_generic_mcp=is_generic_mcp,
|
||||||
)
|
)
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -928,6 +1004,7 @@ async def load_mcp_tools(
|
||||||
"readonly_tools": readonly_tools,
|
"readonly_tools": readonly_tools,
|
||||||
"tool_name_prefix": tool_name_prefix,
|
"tool_name_prefix": tool_name_prefix,
|
||||||
"transport": server_config.get("transport", "stdio"),
|
"transport": server_config.get("transport", "stdio"),
|
||||||
|
"is_generic_mcp": svc_cfg is None,
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -948,6 +1025,7 @@ async def load_mcp_tools(
|
||||||
allowed_tools=task["allowed_tools"],
|
allowed_tools=task["allowed_tools"],
|
||||||
readonly_tools=task["readonly_tools"],
|
readonly_tools=task["readonly_tools"],
|
||||||
tool_name_prefix=task["tool_name_prefix"],
|
tool_name_prefix=task["tool_name_prefix"],
|
||||||
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue