diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index f5e8f1235..bc6f7fd9e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -20,19 +20,39 @@ from langgraph.runtime import Runtime logger = logging.getLogger(__name__) -_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { - "delete_calendar_event": "event_title_or_id", - "update_calendar_event": "event_title_or_id", - "trash_gmail_email": "email_subject_or_id", +_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { + # Gmail + "send_gmail_email": "subject", + "create_gmail_draft": "subject", "update_gmail_draft": "draft_subject_or_id", + "trash_gmail_email": "email_subject_or_id", + # Google Calendar + "create_calendar_event": "title", + "update_calendar_event": "event_title_or_id", + "delete_calendar_event": "event_title_or_id", + # Google Drive + "create_google_drive_file": "file_name", "delete_google_drive_file": "file_name", + # OneDrive + "create_onedrive_file": "file_name", "delete_onedrive_file": "file_name", - "delete_notion_page": "page_title", + # Dropbox + "create_dropbox_file": "file_name", + "delete_dropbox_file": "file_name", + # Notion + "create_notion_page": "title", "update_notion_page": "page_title", - "delete_linear_issue": "issue_ref", + "delete_notion_page": "page_title", + # Linear + "create_linear_issue": "title", "update_linear_issue": "issue_ref", + "delete_linear_issue": "issue_ref", + # Jira + "create_jira_issue": "summary", "update_jira_issue": "issue_title_or_key", "delete_jira_issue": "issue_title_or_key", + # Confluence + "create_confluence_page": "title", "update_confluence_page": "page_title_or_id", "delete_confluence_page": "page_title_or_id", } @@ -43,22 +63,38 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] Only the **first** occurrence of each (tool-name, primary-arg-value) pair is kept; subsequent duplicates are silently dropped. + + The dedup map is built from two sources: + + 1. A comprehensive list of native HITL tools (hardcoded above). + 2. Any ``StructuredTool`` instances passed via *agent_tools* whose + ``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``. + This is how MCP tools automatically get dedup support. """ tools = () + def __init__(self, *, agent_tools: list[Any] | None = None) -> None: + self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS) + for t in agent_tools or []: + meta = getattr(t, "metadata", None) or {} + if meta.get("hitl") and meta.get("hitl_dedup_key"): + self._dedup_keys[t.name] = meta["hitl_dedup_key"] + def after_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state) + return self._dedup(state, self._dedup_keys) async def aafter_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state) + return self._dedup(state, self._dedup_keys) @staticmethod - def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg] + def _dedup( + state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg] + ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: return None @@ -73,7 +109,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] for tc in tool_calls: name = tc.get("name", "") - dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name) + dedup_key_arg = dedup_keys.get(name) if dedup_key_arg is not None: arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower() key = (name, arg_val) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 2fb7ffb06..9743d049d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -7,7 +7,11 @@ Supports both transport types: - stdio: Local process-based MCP servers (command, args, env) - streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers) -This implements real MCP protocol support similar to Cursor's implementation. +All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval. +Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless +they come from trusted servers." Users can bypass HITL for specific tools by +clicking "Always Allow", which adds the tool name to the connector's +``config.trusted_tools`` allow-list. """ import logging @@ -21,6 +25,7 @@ from pydantic import BaseModel, create_model from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient from app.db import SearchSourceConnector, SearchSourceConnectorType @@ -49,27 +54,15 @@ def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], ) -> type[BaseModel]: - """Create a Pydantic model from MCP tool's JSON schema. - - Args: - tool_name: Name of the tool (used for model class name) - input_schema: JSON schema from MCP server - - Returns: - Pydantic model class for tool input validation - - """ + """Create a Pydantic model from MCP tool's JSON schema.""" properties = input_schema.get("properties", {}) required_fields = input_schema.get("required", []) - # Build Pydantic field definitions field_definitions = {} for param_name, param_schema in properties.items(): param_description = param_schema.get("description", "") is_required = param_name in required_fields - # Use Any type for complex schemas to preserve structure - # This allows the MCP server to do its own validation from typing import Any as AnyType from pydantic import Field @@ -85,7 +78,6 @@ def _create_dynamic_input_model_from_schema( Field(None, description=param_description), ) - # Create dynamic model model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" return create_model(model_name, **field_definitions) @@ -93,55 +85,70 @@ def _create_dynamic_input_model_from_schema( async def _create_mcp_tool_from_definition_stdio( tool_def: dict[str, Any], mcp_client: MCPClient, + *, + connector_name: str = "", + connector_id: int | None = None, + trusted_tools: list[str] | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (stdio transport). - Args: - tool_def: Tool definition from MCP server with name, description, input_schema - mcp_client: MCP client instance for calling the tool - - Returns: - LangChain StructuredTool instance - + All MCP tools are unconditionally wrapped with HITL approval. + ``request_approval()`` is called OUTSIDE the try/except so that + ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - # Log the actual schema for debugging logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") - # Create dynamic input model from schema input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_tool_call(**kwargs) -> str: """Execute the MCP tool call via the client with retry support.""" logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "stdio", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = hitl_result.params + try: - # Connect to server and call tool (connect has built-in retry logic) async with mcp_client.connect(): - result = await mcp_client.call_tool(tool_name, kwargs) + result = await mcp_client.call_tool(tool_name, call_kwargs) return str(result) except RuntimeError as e: - # Connection failures after all retries error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}" logger.error(error_msg) return f"Error: {error_msg}" except Exception as e: - # Tool execution or other errors error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}" logger.exception(error_msg) return f"Error: {error_msg}" - # Create StructuredTool with response_format to preserve exact schema tool = StructuredTool( name=tool_name, description=tool_description, coroutine=mcp_tool_call, args_schema=input_model, - # Store the original MCP schema as metadata so we can access it later - metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"}, + metadata={ + "mcp_input_schema": input_schema, + "mcp_transport": "stdio", + "hitl": True, + "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + }, ) logger.info(f"Created MCP tool (stdio): '{tool_name}'") @@ -152,43 +159,54 @@ async def _create_mcp_tool_from_definition_http( tool_def: dict[str, Any], url: str, headers: dict[str, str], + *, + connector_name: str = "", + connector_id: int | None = None, + trusted_tools: list[str] | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). - Args: - tool_def: Tool definition from MCP server with name, description, input_schema - url: URL of the MCP server - headers: HTTP headers for authentication - - Returns: - LangChain StructuredTool instance - + All MCP tools are unconditionally wrapped with HITL approval. + ``request_approval()`` is called OUTSIDE the try/except so that + ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - # Log the actual schema for debugging logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}") - # Create dynamic input model from schema input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}") + # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "http", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = hitl_result.params + try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), ClientSession(read, write) as session, ): await session.initialize() + response = await session.call_tool(tool_name, arguments=call_kwargs) - # Call the tool - response = await session.call_tool(tool_name, arguments=kwargs) - - # Extract content from response result = [] for content in response.content: if hasattr(content, "text"): @@ -209,7 +227,6 @@ async def _create_mcp_tool_from_definition_http( logger.exception(error_msg) return f"Error: {error_msg}" - # Create StructuredTool tool = StructuredTool( name=tool_name, description=tool_description, @@ -219,6 +236,8 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, + "hitl": True, + "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), }, ) @@ -230,20 +249,11 @@ async def _load_stdio_mcp_tools( connector_id: int, connector_name: str, server_config: dict[str, Any], + trusted_tools: list[str] | None = None, ) -> list[StructuredTool]: - """Load tools from a stdio-based MCP server. - - Args: - connector_id: Connector ID for logging - connector_name: Connector name for logging - server_config: Server configuration with command, args, env - - Returns: - List of tools from the MCP server - """ + """Load tools from a stdio-based MCP server.""" tools: list[StructuredTool] = [] - # Validate required command field command = server_config.get("command") if not command or not isinstance(command, str): logger.warning( @@ -251,7 +261,6 @@ async def _load_stdio_mcp_tools( ) return tools - # Validate args field (must be list if present) args = server_config.get("args", []) if not isinstance(args, list): logger.warning( @@ -259,7 +268,6 @@ async def _load_stdio_mcp_tools( ) return tools - # Validate env field (must be dict if present) env = server_config.get("env", {}) if not isinstance(env, dict): logger.warning( @@ -267,10 +275,8 @@ async def _load_stdio_mcp_tools( ) return tools - # Create MCP client mcp_client = MCPClient(command, args, env) - # Connect and discover tools async with mcp_client.connect(): tool_definitions = await mcp_client.list_tools() @@ -279,10 +285,15 @@ async def _load_stdio_mcp_tools( f"'{command}' (connector {connector_id})" ) - # Create LangChain tools from definitions for tool_def in tool_definitions: try: - tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client) + tool = await _create_mcp_tool_from_definition_stdio( + tool_def, + mcp_client, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, + ) tools.append(tool) except Exception as e: logger.exception( @@ -297,20 +308,11 @@ async def _load_http_mcp_tools( connector_id: int, connector_name: str, server_config: dict[str, Any], + trusted_tools: list[str] | None = None, ) -> list[StructuredTool]: - """Load tools from an HTTP-based MCP server. - - Args: - connector_id: Connector ID for logging - connector_name: Connector name for logging - server_config: Server configuration with url, headers - - Returns: - List of tools from the MCP server - """ + """Load tools from an HTTP-based MCP server.""" tools: list[StructuredTool] = [] - # Validate required url field url = server_config.get("url") if not url or not isinstance(url, str): logger.warning( @@ -318,7 +320,6 @@ async def _load_http_mcp_tools( ) return tools - # Validate headers field (must be dict if present) headers = server_config.get("headers", {}) if not isinstance(headers, dict): logger.warning( @@ -326,7 +327,6 @@ async def _load_http_mcp_tools( ) return tools - # Connect and discover tools via HTTP try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), @@ -334,7 +334,6 @@ async def _load_http_mcp_tools( ): await session.initialize() - # List available tools response = await session.list_tools() tool_definitions = [] for tool in response.tools: @@ -353,11 +352,15 @@ async def _load_http_mcp_tools( f"'{url}' (connector {connector_id})" ) - # Create LangChain tools from definitions for tool_def in tool_definitions: try: tool = await _create_mcp_tool_from_definition_http( - tool_def, url, headers + tool_def, + url, + headers, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, ) tools.append(tool) except Exception as e: @@ -398,14 +401,6 @@ async def load_mcp_tools( Results are cached per search space for up to 5 minutes to avoid re-spawning MCP server processes on every chat message. - - Args: - session: Database session - search_space_id: User's search space ID - - Returns: - List of LangChain StructuredTool instances - """ _evict_expired_mcp_cache() @@ -436,6 +431,7 @@ async def load_mcp_tools( try: config = connector.config or {} server_config = config.get("server_config", {}) + trusted_tools = config.get("trusted_tools", []) if not server_config or not isinstance(server_config, dict): logger.warning( @@ -447,11 +443,17 @@ async def load_mcp_tools( if transport in ("streamable-http", "http", "sse"): connector_tools = await _load_http_mcp_tools( - connector.id, connector.name, server_config + connector.id, + connector.name, + server_config, + trusted_tools=trusted_tools, ) else: connector_tools = await _load_stdio_mcp_tools( - connector.id, connector.name, server_config + connector.id, + connector.name, + server_config, + trusted_tools=trusted_tools, ) tools.extend(connector_tools)