refactor: enhance deduplication logic for HITL tool calls

Updated the deduplication mechanism in the DedupHITLToolCallsMiddleware to utilize a comprehensive list of native HITL tools. The deduplication keys are now dynamically populated from both hardcoded values and metadata from StructuredTool instances. Additionally, integrated HITL approval into MCP tool creation, ensuring all tools are gated by user approval, with the ability to bypass for trusted tools.
This commit is contained in:
Anish Sarkar 2026-04-13 20:14:12 +05:30
parent 0c4fd30cce
commit 82c7d4a2ab
2 changed files with 137 additions and 99 deletions

View file

@ -7,7 +7,11 @@ Supports both transport types:
- stdio: Local process-based MCP servers (command, args, env)
- streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers)
This implements real MCP protocol support similar to Cursor's implementation.
All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval.
Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless
they come from trusted servers." Users can bypass HITL for specific tools by
clicking "Always Allow", which adds the tool name to the connector's
``config.trusted_tools`` allow-list.
"""
import logging
@ -21,6 +25,7 @@ from pydantic import BaseModel, create_model
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
@ -49,27 +54,15 @@ def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema.
Args:
tool_name: Name of the tool (used for model class name)
input_schema: JSON schema from MCP server
Returns:
Pydantic model class for tool input validation
"""
"""Create a Pydantic model from MCP tool's JSON schema."""
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
# Build Pydantic field definitions
field_definitions = {}
for param_name, param_schema in properties.items():
param_description = param_schema.get("description", "")
is_required = param_name in required_fields
# Use Any type for complex schemas to preserve structure
# This allows the MCP server to do its own validation
from typing import Any as AnyType
from pydantic import Field
@ -85,7 +78,6 @@ def _create_dynamic_input_model_from_schema(
Field(None, description=param_description),
)
# Create dynamic model
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
return create_model(model_name, **field_definitions)
@ -93,55 +85,70 @@ def _create_dynamic_input_model_from_schema(
async def _create_mcp_tool_from_definition_stdio(
tool_def: dict[str, Any],
mcp_client: MCPClient,
*,
connector_name: str = "",
connector_id: int | None = None,
trusted_tools: list[str] | None = None,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (stdio transport).
Args:
tool_def: Tool definition from MCP server with name, description, input_schema
mcp_client: MCP client instance for calling the tool
Returns:
LangChain StructuredTool instance
All MCP tools are unconditionally wrapped with HITL approval.
``request_approval()`` is called OUTSIDE the try/except so that
``GraphInterrupt`` propagates cleanly to LangGraph.
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
# Log the actual schema for debugging
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
# Create dynamic input model from schema
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via the client with retry support."""
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = hitl_result.params
try:
# Connect to server and call tool (connect has built-in retry logic)
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, kwargs)
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except RuntimeError as e:
# Connection failures after all retries
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
logger.error(error_msg)
return f"Error: {error_msg}"
except Exception as e:
# Tool execution or other errors
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
logger.exception(error_msg)
return f"Error: {error_msg}"
# Create StructuredTool with response_format to preserve exact schema
tool = StructuredTool(
name=tool_name,
description=tool_description,
coroutine=mcp_tool_call,
args_schema=input_model,
# Store the original MCP schema as metadata so we can access it later
metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"},
metadata={
"mcp_input_schema": input_schema,
"mcp_transport": "stdio",
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
},
)
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
@ -152,43 +159,54 @@ async def _create_mcp_tool_from_definition_http(
tool_def: dict[str, Any],
url: str,
headers: dict[str, str],
*,
connector_name: str = "",
connector_id: int | None = None,
trusted_tools: list[str] | None = None,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
Args:
tool_def: Tool definition from MCP server with name, description, input_schema
url: URL of the MCP server
headers: HTTP headers for authentication
Returns:
LangChain StructuredTool instance
All MCP tools are unconditionally wrapped with HITL approval.
``request_approval()`` is called OUTSIDE the try/except so that
``GraphInterrupt`` propagates cleanly to LangGraph.
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
# Log the actual schema for debugging
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
# Create dynamic input model from schema
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport."""
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = hitl_result.params
try:
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(tool_name, arguments=call_kwargs)
# Call the tool
response = await session.call_tool(tool_name, arguments=kwargs)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
@ -209,7 +227,6 @@ async def _create_mcp_tool_from_definition_http(
logger.exception(error_msg)
return f"Error: {error_msg}"
# Create StructuredTool
tool = StructuredTool(
name=tool_name,
description=tool_description,
@ -219,6 +236,8 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema,
"mcp_transport": "http",
"mcp_url": url,
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
},
)
@ -230,20 +249,11 @@ async def _load_stdio_mcp_tools(
connector_id: int,
connector_name: str,
server_config: dict[str, Any],
trusted_tools: list[str] | None = None,
) -> list[StructuredTool]:
"""Load tools from a stdio-based MCP server.
Args:
connector_id: Connector ID for logging
connector_name: Connector name for logging
server_config: Server configuration with command, args, env
Returns:
List of tools from the MCP server
"""
"""Load tools from a stdio-based MCP server."""
tools: list[StructuredTool] = []
# Validate required command field
command = server_config.get("command")
if not command or not isinstance(command, str):
logger.warning(
@ -251,7 +261,6 @@ async def _load_stdio_mcp_tools(
)
return tools
# Validate args field (must be list if present)
args = server_config.get("args", [])
if not isinstance(args, list):
logger.warning(
@ -259,7 +268,6 @@ async def _load_stdio_mcp_tools(
)
return tools
# Validate env field (must be dict if present)
env = server_config.get("env", {})
if not isinstance(env, dict):
logger.warning(
@ -267,10 +275,8 @@ async def _load_stdio_mcp_tools(
)
return tools
# Create MCP client
mcp_client = MCPClient(command, args, env)
# Connect and discover tools
async with mcp_client.connect():
tool_definitions = await mcp_client.list_tools()
@ -279,10 +285,15 @@ async def _load_stdio_mcp_tools(
f"'{command}' (connector {connector_id})"
)
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client)
tool = await _create_mcp_tool_from_definition_stdio(
tool_def,
mcp_client,
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
)
tools.append(tool)
except Exception as e:
logger.exception(
@ -297,20 +308,11 @@ async def _load_http_mcp_tools(
connector_id: int,
connector_name: str,
server_config: dict[str, Any],
trusted_tools: list[str] | None = None,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
Args:
connector_id: Connector ID for logging
connector_name: Connector name for logging
server_config: Server configuration with url, headers
Returns:
List of tools from the MCP server
"""
"""Load tools from an HTTP-based MCP server."""
tools: list[StructuredTool] = []
# Validate required url field
url = server_config.get("url")
if not url or not isinstance(url, str):
logger.warning(
@ -318,7 +320,6 @@ async def _load_http_mcp_tools(
)
return tools
# Validate headers field (must be dict if present)
headers = server_config.get("headers", {})
if not isinstance(headers, dict):
logger.warning(
@ -326,7 +327,6 @@ async def _load_http_mcp_tools(
)
return tools
# Connect and discover tools via HTTP
try:
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
@ -334,7 +334,6 @@ async def _load_http_mcp_tools(
):
await session.initialize()
# List available tools
response = await session.list_tools()
tool_definitions = []
for tool in response.tools:
@ -353,11 +352,15 @@ async def _load_http_mcp_tools(
f"'{url}' (connector {connector_id})"
)
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition_http(
tool_def, url, headers
tool_def,
url,
headers,
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
)
tools.append(tool)
except Exception as e:
@ -398,14 +401,6 @@ async def load_mcp_tools(
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
Args:
session: Database session
search_space_id: User's search space ID
Returns:
List of LangChain StructuredTool instances
"""
_evict_expired_mcp_cache()
@ -436,6 +431,7 @@ async def load_mcp_tools(
try:
config = connector.config or {}
server_config = config.get("server_config", {})
trusted_tools = config.get("trusted_tools", [])
if not server_config or not isinstance(server_config, dict):
logger.warning(
@ -447,11 +443,17 @@ async def load_mcp_tools(
if transport in ("streamable-http", "http", "sse"):
connector_tools = await _load_http_mcp_tools(
connector.id, connector.name, server_config
connector.id,
connector.name,
server_config,
trusted_tools=trusted_tools,
)
else:
connector_tools = await _load_stdio_mcp_tools(
connector.id, connector.name, server_config
connector.id,
connector.name,
server_config,
trusted_tools=trusted_tools,
)
tools.extend(connector_tools)