From 82c7d4a2ab2589f0fa0ef48dde42c92c80721f35 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 13 Apr 2026 20:14:12 +0530 Subject: [PATCH] refactor: enhance deduplication logic for HITL tool calls Updated the deduplication mechanism in the DedupHITLToolCallsMiddleware to utilize a comprehensive list of native HITL tools. The deduplication keys are now dynamically populated from both hardcoded values and metadata from StructuredTool instances. Additionally, integrated HITL approval into MCP tool creation, ensuring all tools are gated by user approval, with the ability to bypass for trusted tools. --- .../new_chat/middleware/dedup_tool_calls.py | 56 +++++- .../app/agents/new_chat/tools/mcp_tool.py | 180 +++++++++--------- 2 files changed, 137 insertions(+), 99 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index f5e8f1235..bc6f7fd9e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -20,19 +20,39 @@ from langgraph.runtime import Runtime logger = logging.getLogger(__name__) -_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { - "delete_calendar_event": "event_title_or_id", - "update_calendar_event": "event_title_or_id", - "trash_gmail_email": "email_subject_or_id", +_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { + # Gmail + "send_gmail_email": "subject", + "create_gmail_draft": "subject", "update_gmail_draft": "draft_subject_or_id", + "trash_gmail_email": "email_subject_or_id", + # Google Calendar + "create_calendar_event": "title", + "update_calendar_event": "event_title_or_id", + "delete_calendar_event": "event_title_or_id", + # Google Drive + "create_google_drive_file": "file_name", "delete_google_drive_file": "file_name", + # OneDrive + "create_onedrive_file": "file_name", "delete_onedrive_file": "file_name", - "delete_notion_page": "page_title", + # Dropbox + "create_dropbox_file": "file_name", + "delete_dropbox_file": "file_name", + # Notion + "create_notion_page": "title", "update_notion_page": "page_title", - "delete_linear_issue": "issue_ref", + "delete_notion_page": "page_title", + # Linear + "create_linear_issue": "title", "update_linear_issue": "issue_ref", + "delete_linear_issue": "issue_ref", + # Jira + "create_jira_issue": "summary", "update_jira_issue": "issue_title_or_key", "delete_jira_issue": "issue_title_or_key", + # Confluence + "create_confluence_page": "title", "update_confluence_page": "page_title_or_id", "delete_confluence_page": "page_title_or_id", } @@ -43,22 +63,38 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] Only the **first** occurrence of each (tool-name, primary-arg-value) pair is kept; subsequent duplicates are silently dropped. + + The dedup map is built from two sources: + + 1. A comprehensive list of native HITL tools (hardcoded above). + 2. Any ``StructuredTool`` instances passed via *agent_tools* whose + ``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``. + This is how MCP tools automatically get dedup support. """ tools = () + def __init__(self, *, agent_tools: list[Any] | None = None) -> None: + self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS) + for t in agent_tools or []: + meta = getattr(t, "metadata", None) or {} + if meta.get("hitl") and meta.get("hitl_dedup_key"): + self._dedup_keys[t.name] = meta["hitl_dedup_key"] + def after_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state) + return self._dedup(state, self._dedup_keys) async def aafter_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state) + return self._dedup(state, self._dedup_keys) @staticmethod - def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg] + def _dedup( + state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg] + ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: return None @@ -73,7 +109,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] for tc in tool_calls: name = tc.get("name", "") - dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name) + dedup_key_arg = dedup_keys.get(name) if dedup_key_arg is not None: arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower() key = (name, arg_val) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 2fb7ffb06..9743d049d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -7,7 +7,11 @@ Supports both transport types: - stdio: Local process-based MCP servers (command, args, env) - streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers) -This implements real MCP protocol support similar to Cursor's implementation. +All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval. +Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless +they come from trusted servers." Users can bypass HITL for specific tools by +clicking "Always Allow", which adds the tool name to the connector's +``config.trusted_tools`` allow-list. """ import logging @@ -21,6 +25,7 @@ from pydantic import BaseModel, create_model from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient from app.db import SearchSourceConnector, SearchSourceConnectorType @@ -49,27 +54,15 @@ def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], ) -> type[BaseModel]: - """Create a Pydantic model from MCP tool's JSON schema. - - Args: - tool_name: Name of the tool (used for model class name) - input_schema: JSON schema from MCP server - - Returns: - Pydantic model class for tool input validation - - """ + """Create a Pydantic model from MCP tool's JSON schema.""" properties = input_schema.get("properties", {}) required_fields = input_schema.get("required", []) - # Build Pydantic field definitions field_definitions = {} for param_name, param_schema in properties.items(): param_description = param_schema.get("description", "") is_required = param_name in required_fields - # Use Any type for complex schemas to preserve structure - # This allows the MCP server to do its own validation from typing import Any as AnyType from pydantic import Field @@ -85,7 +78,6 @@ def _create_dynamic_input_model_from_schema( Field(None, description=param_description), ) - # Create dynamic model model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" return create_model(model_name, **field_definitions) @@ -93,55 +85,70 @@ def _create_dynamic_input_model_from_schema( async def _create_mcp_tool_from_definition_stdio( tool_def: dict[str, Any], mcp_client: MCPClient, + *, + connector_name: str = "", + connector_id: int | None = None, + trusted_tools: list[str] | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (stdio transport). - Args: - tool_def: Tool definition from MCP server with name, description, input_schema - mcp_client: MCP client instance for calling the tool - - Returns: - LangChain StructuredTool instance - + All MCP tools are unconditionally wrapped with HITL approval. + ``request_approval()`` is called OUTSIDE the try/except so that + ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - # Log the actual schema for debugging logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") - # Create dynamic input model from schema input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_tool_call(**kwargs) -> str: """Execute the MCP tool call via the client with retry support.""" logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "stdio", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = hitl_result.params + try: - # Connect to server and call tool (connect has built-in retry logic) async with mcp_client.connect(): - result = await mcp_client.call_tool(tool_name, kwargs) + result = await mcp_client.call_tool(tool_name, call_kwargs) return str(result) except RuntimeError as e: - # Connection failures after all retries error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}" logger.error(error_msg) return f"Error: {error_msg}" except Exception as e: - # Tool execution or other errors error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}" logger.exception(error_msg) return f"Error: {error_msg}" - # Create StructuredTool with response_format to preserve exact schema tool = StructuredTool( name=tool_name, description=tool_description, coroutine=mcp_tool_call, args_schema=input_model, - # Store the original MCP schema as metadata so we can access it later - metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"}, + metadata={ + "mcp_input_schema": input_schema, + "mcp_transport": "stdio", + "hitl": True, + "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + }, ) logger.info(f"Created MCP tool (stdio): '{tool_name}'") @@ -152,43 +159,54 @@ async def _create_mcp_tool_from_definition_http( tool_def: dict[str, Any], url: str, headers: dict[str, str], + *, + connector_name: str = "", + connector_id: int | None = None, + trusted_tools: list[str] | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). - Args: - tool_def: Tool definition from MCP server with name, description, input_schema - url: URL of the MCP server - headers: HTTP headers for authentication - - Returns: - LangChain StructuredTool instance - + All MCP tools are unconditionally wrapped with HITL approval. + ``request_approval()`` is called OUTSIDE the try/except so that + ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - # Log the actual schema for debugging logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}") - # Create dynamic input model from schema input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}") + # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "http", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = hitl_result.params + try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), ClientSession(read, write) as session, ): await session.initialize() + response = await session.call_tool(tool_name, arguments=call_kwargs) - # Call the tool - response = await session.call_tool(tool_name, arguments=kwargs) - - # Extract content from response result = [] for content in response.content: if hasattr(content, "text"): @@ -209,7 +227,6 @@ async def _create_mcp_tool_from_definition_http( logger.exception(error_msg) return f"Error: {error_msg}" - # Create StructuredTool tool = StructuredTool( name=tool_name, description=tool_description, @@ -219,6 +236,8 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, + "hitl": True, + "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), }, ) @@ -230,20 +249,11 @@ async def _load_stdio_mcp_tools( connector_id: int, connector_name: str, server_config: dict[str, Any], + trusted_tools: list[str] | None = None, ) -> list[StructuredTool]: - """Load tools from a stdio-based MCP server. - - Args: - connector_id: Connector ID for logging - connector_name: Connector name for logging - server_config: Server configuration with command, args, env - - Returns: - List of tools from the MCP server - """ + """Load tools from a stdio-based MCP server.""" tools: list[StructuredTool] = [] - # Validate required command field command = server_config.get("command") if not command or not isinstance(command, str): logger.warning( @@ -251,7 +261,6 @@ async def _load_stdio_mcp_tools( ) return tools - # Validate args field (must be list if present) args = server_config.get("args", []) if not isinstance(args, list): logger.warning( @@ -259,7 +268,6 @@ async def _load_stdio_mcp_tools( ) return tools - # Validate env field (must be dict if present) env = server_config.get("env", {}) if not isinstance(env, dict): logger.warning( @@ -267,10 +275,8 @@ async def _load_stdio_mcp_tools( ) return tools - # Create MCP client mcp_client = MCPClient(command, args, env) - # Connect and discover tools async with mcp_client.connect(): tool_definitions = await mcp_client.list_tools() @@ -279,10 +285,15 @@ async def _load_stdio_mcp_tools( f"'{command}' (connector {connector_id})" ) - # Create LangChain tools from definitions for tool_def in tool_definitions: try: - tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client) + tool = await _create_mcp_tool_from_definition_stdio( + tool_def, + mcp_client, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, + ) tools.append(tool) except Exception as e: logger.exception( @@ -297,20 +308,11 @@ async def _load_http_mcp_tools( connector_id: int, connector_name: str, server_config: dict[str, Any], + trusted_tools: list[str] | None = None, ) -> list[StructuredTool]: - """Load tools from an HTTP-based MCP server. - - Args: - connector_id: Connector ID for logging - connector_name: Connector name for logging - server_config: Server configuration with url, headers - - Returns: - List of tools from the MCP server - """ + """Load tools from an HTTP-based MCP server.""" tools: list[StructuredTool] = [] - # Validate required url field url = server_config.get("url") if not url or not isinstance(url, str): logger.warning( @@ -318,7 +320,6 @@ async def _load_http_mcp_tools( ) return tools - # Validate headers field (must be dict if present) headers = server_config.get("headers", {}) if not isinstance(headers, dict): logger.warning( @@ -326,7 +327,6 @@ async def _load_http_mcp_tools( ) return tools - # Connect and discover tools via HTTP try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), @@ -334,7 +334,6 @@ async def _load_http_mcp_tools( ): await session.initialize() - # List available tools response = await session.list_tools() tool_definitions = [] for tool in response.tools: @@ -353,11 +352,15 @@ async def _load_http_mcp_tools( f"'{url}' (connector {connector_id})" ) - # Create LangChain tools from definitions for tool_def in tool_definitions: try: tool = await _create_mcp_tool_from_definition_http( - tool_def, url, headers + tool_def, + url, + headers, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, ) tools.append(tool) except Exception as e: @@ -398,14 +401,6 @@ async def load_mcp_tools( Results are cached per search space for up to 5 minutes to avoid re-spawning MCP server processes on every chat message. - - Args: - session: Database session - search_space_id: User's search space ID - - Returns: - List of LangChain StructuredTool instances - """ _evict_expired_mcp_cache() @@ -436,6 +431,7 @@ async def load_mcp_tools( try: config = connector.config or {} server_config = config.get("server_config", {}) + trusted_tools = config.get("trusted_tools", []) if not server_config or not isinstance(server_config, dict): logger.warning( @@ -447,11 +443,17 @@ async def load_mcp_tools( if transport in ("streamable-http", "http", "sse"): connector_tools = await _load_http_mcp_tools( - connector.id, connector.name, server_config + connector.id, + connector.name, + server_config, + trusted_tools=trusted_tools, ) else: connector_tools = await _load_stdio_mcp_tools( - connector.id, connector.name, server_config + connector.id, + connector.name, + server_config, + trusted_tools=trusted_tools, ) tools.extend(connector_tools)