mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 21:32:39 +02:00
refactor: enhance deduplication logic for HITL tool calls
Updated the deduplication mechanism in the DedupHITLToolCallsMiddleware to utilize a comprehensive list of native HITL tools. The deduplication keys are now dynamically populated from both hardcoded values and metadata from StructuredTool instances. Additionally, integrated HITL approval into MCP tool creation, ensuring all tools are gated by user approval, with the ability to bypass for trusted tools.
This commit is contained in:
parent
0c4fd30cce
commit
82c7d4a2ab
2 changed files with 137 additions and 99 deletions
|
|
@ -7,7 +7,11 @@ Supports both transport types:
|
|||
- stdio: Local process-based MCP servers (command, args, env)
|
||||
- streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers)
|
||||
|
||||
This implements real MCP protocol support similar to Cursor's implementation.
|
||||
All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval.
|
||||
Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless
|
||||
they come from trusted servers." Users can bypass HITL for specific tools by
|
||||
clicking "Always Allow", which adds the tool name to the connector's
|
||||
``config.trusted_tools`` allow-list.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -21,6 +25,7 @@ from pydantic import BaseModel, create_model
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
|
|
@ -49,27 +54,15 @@ def _create_dynamic_input_model_from_schema(
|
|||
tool_name: str,
|
||||
input_schema: dict[str, Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (used for model class name)
|
||||
input_schema: JSON schema from MCP server
|
||||
|
||||
Returns:
|
||||
Pydantic model class for tool input validation
|
||||
|
||||
"""
|
||||
"""Create a Pydantic model from MCP tool's JSON schema."""
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
# Build Pydantic field definitions
|
||||
field_definitions = {}
|
||||
for param_name, param_schema in properties.items():
|
||||
param_description = param_schema.get("description", "")
|
||||
is_required = param_name in required_fields
|
||||
|
||||
# Use Any type for complex schemas to preserve structure
|
||||
# This allows the MCP server to do its own validation
|
||||
from typing import Any as AnyType
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -85,7 +78,6 @@ def _create_dynamic_input_model_from_schema(
|
|||
Field(None, description=param_description),
|
||||
)
|
||||
|
||||
# Create dynamic model
|
||||
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
||||
return create_model(model_name, **field_definitions)
|
||||
|
||||
|
|
@ -93,55 +85,70 @@ def _create_dynamic_input_model_from_schema(
|
|||
async def _create_mcp_tool_from_definition_stdio(
|
||||
tool_def: dict[str, Any],
|
||||
mcp_client: MCPClient,
|
||||
*,
|
||||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
mcp_client: MCP client instance for calling the tool
|
||||
|
||||
Returns:
|
||||
LangChain StructuredTool instance
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
# Log the actual schema for debugging
|
||||
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
|
||||
|
||||
# Create dynamic input model from schema
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
|
||||
try:
|
||||
# Connect to server and call tool (connect has built-in retry logic)
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, kwargs)
|
||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
# Connection failures after all retries
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
except Exception as e:
|
||||
# Tool execution or other errors
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
# Create StructuredTool with response_format to preserve exact schema
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_tool_call,
|
||||
args_schema=input_model,
|
||||
# Store the original MCP schema as metadata so we can access it later
|
||||
metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"},
|
||||
metadata={
|
||||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "stdio",
|
||||
"hitl": True,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
|
|
@ -152,43 +159,54 @@ async def _create_mcp_tool_from_definition_http(
|
|||
tool_def: dict[str, Any],
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
*,
|
||||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
url: URL of the MCP server
|
||||
headers: HTTP headers for authentication
|
||||
|
||||
Returns:
|
||||
LangChain StructuredTool instance
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
# Log the actual schema for debugging
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
|
||||
# Create dynamic input model from schema
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(tool_name, arguments=call_kwargs)
|
||||
|
||||
# Call the tool
|
||||
response = await session.call_tool(tool_name, arguments=kwargs)
|
||||
|
||||
# Extract content from response
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
|
|
@ -209,7 +227,6 @@ async def _create_mcp_tool_from_definition_http(
|
|||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
# Create StructuredTool
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
|
|
@ -219,6 +236,8 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
"hitl": True,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -230,20 +249,11 @@ async def _load_stdio_mcp_tools(
|
|||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from a stdio-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with command, args, env
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
"""Load tools from a stdio-based MCP server."""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required command field
|
||||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
|
|
@ -251,7 +261,6 @@ async def _load_stdio_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Validate args field (must be list if present)
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
|
|
@ -259,7 +268,6 @@ async def _load_stdio_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Validate env field (must be dict if present)
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
|
|
@ -267,10 +275,8 @@ async def _load_stdio_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
|
||||
# Connect and discover tools
|
||||
async with mcp_client.connect():
|
||||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
|
|
@ -279,10 +285,15 @@ async def _load_stdio_mcp_tools(
|
|||
f"'{command}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client)
|
||||
tool = await _create_mcp_tool_from_definition_stdio(
|
||||
tool_def,
|
||||
mcp_client,
|
||||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
|
@ -297,20 +308,11 @@ async def _load_http_mcp_tools(
|
|||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with url, headers
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
"""Load tools from an HTTP-based MCP server."""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required url field
|
||||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
|
|
@ -318,7 +320,6 @@ async def _load_http_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Validate headers field (must be dict if present)
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
|
|
@ -326,7 +327,6 @@ async def _load_http_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Connect and discover tools via HTTP
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
|
|
@ -334,7 +334,6 @@ async def _load_http_mcp_tools(
|
|||
):
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tool_definitions = []
|
||||
for tool in response.tools:
|
||||
|
|
@ -353,11 +352,15 @@ async def _load_http_mcp_tools(
|
|||
f"'{url}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_http(
|
||||
tool_def, url, headers
|
||||
tool_def,
|
||||
url,
|
||||
headers,
|
||||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
|
|
@ -398,14 +401,6 @@ async def load_mcp_tools(
|
|||
|
||||
Results are cached per search space for up to 5 minutes to avoid
|
||||
re-spawning MCP server processes on every chat message.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
search_space_id: User's search space ID
|
||||
|
||||
Returns:
|
||||
List of LangChain StructuredTool instances
|
||||
|
||||
"""
|
||||
_evict_expired_mcp_cache()
|
||||
|
||||
|
|
@ -436,6 +431,7 @@ async def load_mcp_tools(
|
|||
try:
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
trusted_tools = config.get("trusted_tools", [])
|
||||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
|
|
@ -447,11 +443,17 @@ async def load_mcp_tools(
|
|||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
|
||||
tools.extend(connector_tools)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue