diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index c55347284..a6d2ce310 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup. from __future__ import annotations +import json import logging from collections.abc import Callable from typing import Any @@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver: return _resolver +def dedup_key_full_args(args: dict[str, Any]) -> str: + """Resolver that collapses calls only when **every** argument is identical. + + Safe default for tools where no single field uniquely identifies a call + (e.g. MCP tools whose first required field is a shared workspace id). + """ + + try: + return json.dumps(args, sort_keys=True, default=str) + except (TypeError, ValueError): + return repr(sorted(args.items())) if isinstance(args, dict) else repr(args) + + # Backwards-compatible alias for code that imported the original # private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. _wrap_string_key = wrap_dedup_key_by_arg_name diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py index 37719e96a..5ea7f1740 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/permission.py +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -19,8 +19,9 @@ Operation: the results: ``deny`` > ``ask`` > ``allow``. 3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` containing a :class:`StreamingError`. -4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply - shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``. +4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy + SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}`` + replies are accepted via :func:`_normalize_permission_decision`. - ``once``: proceed. - ``always``: also persist allow rules for ``request.always`` patterns. - ``reject`` w/o feedback: raise :class:`RejectedError`. @@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver: return _resolve +# Translation from the LangChain HITL envelope (what ``stream_resume_chat`` +# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the +# original tool args — tools needing argument edits should use +# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``. +_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = { + "approve": "once", + "reject": "reject", + "edit": "once", +} + + +def _normalize_permission_decision(decision: Any) -> dict[str, Any]: + """Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``. + + Falls back to ``reject`` (with a warning) on unrecognized payloads so the + middleware fails closed. + """ + if isinstance(decision, str): + return {"decision_type": decision} + if not isinstance(decision, dict): + logger.warning( + "Unrecognized permission resume value (%s); treating as reject", + type(decision).__name__, + ) + return {"decision_type": "reject"} + + if decision.get("decision_type"): + return decision + + payload: dict[str, Any] = decision + decisions = decision.get("decisions") + if isinstance(decisions, list) and decisions: + first = decisions[0] + if isinstance(first, dict): + payload = first + + raw_type = payload.get("type") or payload.get("decision_type") + if not raw_type: + logger.warning( + "Permission resume missing decision type (keys=%s); treating as reject", + list(payload.keys()), + ) + return {"decision_type": "reject"} + + raw_type = str(raw_type).lower() + mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type) + if mapped is None: + # Tolerate legacy values arriving without ``decision_type`` wrapping. + if raw_type in {"once", "always", "reject"}: + mapped = raw_type + else: + logger.warning( + "Unknown permission decision type %r; treating as reject", raw_type + ) + mapped = "reject" + + if raw_type == "edit": + logger.warning( + "Permission middleware received an 'edit' decision; original args " + "kept (edits not merged here)." + ) + + out: dict[str, Any] = {"decision_type": mapped} + feedback = payload.get("feedback") or payload.get("message") + if isinstance(feedback, str) and feedback.strip(): + out["feedback"] = feedback + return out + + class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] """Allow/deny/ask layer over the agent's tool calls. @@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] ot.interrupt_span(interrupt_type="permission_ask"), ): decision = interrupt(payload) - if isinstance(decision, dict): - return decision - # Tolerate a plain string reply ("once", "always", "reject") - if isinstance(decision, str): - return {"decision_type": decision} - return {"decision_type": "reject"} + return _normalize_permission_decision(decision) def _persist_always(self, tool_name: str, patterns: list[str]) -> None: """Promote ``always`` reply into runtime allow rules. @@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] __all__ = [ "PatternResolver", "PermissionMiddleware", + "_normalize_permission_decision", ] diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 5b96ab374..92a808a5e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -33,6 +33,7 @@ from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient from app.db import SearchSourceConnector @@ -45,7 +46,10 @@ _MCP_CACHE_MAX_SIZE = 50 _MCP_DISCOVERY_TIMEOUT_SECONDS = 30 _TOOL_CALL_MAX_RETRIES = 3 _TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt -_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} +# Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and +# multi-agent paths cannot share tool closures with different HITL wiring. +_MCPCacheKey = tuple[int, bool] +_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {} def _evict_expired_mcp_cache() -> None: @@ -137,12 +141,13 @@ async def _create_mcp_tool_from_definition_stdio( connector_name: str = "", connector_id: int | None = None, trusted_tools: list[str] | None = None, + bypass_internal_hitl: bool = False, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (stdio transport). - All MCP tools are unconditionally wrapped with HITL approval. - ``request_approval()`` is called OUTSIDE the try/except so that - ``GraphInterrupt`` propagates cleanly to LangGraph. + Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware`` + already gates the tool, otherwise the body's ``request_approval()`` is the + sole HITL gate (single-agent path). """ tool_name = tool_def.get("name", "unnamed_tool") raw_description = tool_def.get("description", "No description provided") @@ -161,24 +166,29 @@ async def _create_mcp_tool_from_definition_stdio( """Execute the MCP tool call via the client with retry support.""" logger.debug("MCP tool '%s' called", tool_name) - # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph - hitl_result = request_approval( - action_type="mcp_tool_call", - tool_name=tool_name, - params=kwargs, - context={ - "mcp_server": connector_name, - "tool_description": raw_description, - "mcp_transport": "stdio", - "mcp_connector_id": connector_id, - }, - trusted_tools=trusted_tools, - ) - if hitl_result.rejected: - return "Tool call rejected by user." - call_kwargs = _unpack_synthetic_input_data( - {k: v for k, v in hitl_result.params.items() if v is not None} - ) + if bypass_internal_hitl: + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in kwargs.items() if v is not None} + ) + else: + # Outside try/except so ``GraphInterrupt`` propagates to LangGraph. + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=tool_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": raw_description, + "mcp_transport": "stdio", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) last_error: Exception | None = None for attempt in range(_TOOL_CALL_MAX_RETRIES): @@ -221,7 +231,9 @@ async def _create_mcp_tool_from_definition_stdio( "mcp_connector_name": connector_name or None, "mcp_is_generic": True, "hitl": True, - "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + # Full-args hash: shared identifiers (cloudId, workspaceId, …) + # would otherwise collapse legitimate batches. + "dedup_key": dedup_key_full_args, }, ) @@ -240,11 +252,14 @@ async def _create_mcp_tool_from_definition_http( readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, is_generic_mcp: bool = False, + bypass_internal_hitl: bool = False, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). Write tools are wrapped with HITL approval; read-only tools (listed in - ``readonly_tools``) execute immediately without user confirmation. + ``readonly_tools``) execute immediately without user confirmation. Set + ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware`` + already gates the tool. When ``tool_name_prefix`` is set (multi-account disambiguation), the tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``) @@ -302,7 +317,7 @@ async def _create_mcp_tool_from_definition_http( """Execute the MCP tool call via HTTP transport.""" logger.debug("MCP HTTP tool '%s' called", exposed_name) - if is_readonly: + if is_readonly or bypass_internal_hitl: call_kwargs = _unpack_synthetic_input_data( {k: v for k, v in kwargs.items() if v is not None} ) @@ -385,7 +400,9 @@ async def _create_mcp_tool_from_definition_http( "mcp_connector_name": connector_name or None, "mcp_is_generic": is_generic_mcp, "hitl": not is_readonly, - "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + # Full-args hash: shared identifiers (cloudId, workspaceId, …) + # would otherwise collapse legitimate batches. + "dedup_key": dedup_key_full_args, "mcp_original_tool_name": original_tool_name, "mcp_connector_id": connector_id, }, @@ -400,6 +417,8 @@ async def _load_stdio_mcp_tools( connector_name: str, server_config: dict[str, Any], trusted_tools: list[str] | None = None, + *, + bypass_internal_hitl: bool = False, ) -> list[StructuredTool]: """Load tools from a stdio-based MCP server.""" tools: list[StructuredTool] = [] @@ -451,6 +470,7 @@ async def _load_stdio_mcp_tools( connector_name=connector_name, connector_id=connector_id, trusted_tools=trusted_tools, + bypass_internal_hitl=bypass_internal_hitl, ) tools.append(tool) except Exception as e: @@ -473,6 +493,8 @@ async def _load_http_mcp_tools( readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, is_generic_mcp: bool = False, + *, + bypass_internal_hitl: bool = False, ) -> list[StructuredTool]: """Load tools from an HTTP-based MCP server. @@ -598,6 +620,7 @@ async def _load_http_mcp_tools( readonly_tools=readonly_tools, tool_name_prefix=tool_name_prefix, is_generic_mcp=is_generic_mcp, + bypass_internal_hitl=bypass_internal_hitl, ) tools.append(tool) except Exception as e: @@ -905,14 +928,10 @@ async def _mark_connector_auth_expired(connector_id: int) -> None: def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: - """Invalidate cached MCP tools. - - Args: - search_space_id: If provided, only invalidate for this search space. - If None, invalidate all cached MCP tools. - """ + """Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together).""" if search_space_id is not None: - _mcp_tools_cache.pop(search_space_id, None) + for key in [k for k in _mcp_tools_cache if k[0] == search_space_id]: + _mcp_tools_cache.pop(key, None) else: _mcp_tools_cache.clear() @@ -920,27 +939,29 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: async def load_mcp_tools( session: AsyncSession, search_space_id: int, + *, + bypass_internal_hitl: bool = False, ) -> list[StructuredTool]: - """Load all MCP tools from user's active MCP server connectors. + """Load all MCP tools from the user's active MCP server connectors. - This discovers tools dynamically from MCP servers using the protocol. - Supports both stdio (local process) and HTTP (remote server) transports. - - Results are cached per search space for up to 5 minutes to avoid - re-spawning MCP server processes on every chat message. + Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up + to 5 minutes; bypass is keyed because each variant builds a different tool + closure (with vs. without the in-wrapper ``request_approval`` gate). """ _evict_expired_mcp_cache() now = time.monotonic() - cached = _mcp_tools_cache.get(search_space_id) + cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl) + cached = _mcp_tools_cache.get(cache_key) if cached is not None: cached_at, cached_tools = cached if now - cached_at < _MCP_CACHE_TTL_SECONDS: logger.info( - "Using cached MCP tools for search space %s (%d tools, age=%.0fs)", + "Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)", search_space_id, len(cached_tools), now - cached_at, + bypass_internal_hitl, ) return list(cached_tools) @@ -1064,6 +1085,7 @@ async def load_mcp_tools( readonly_tools=task["readonly_tools"], tool_name_prefix=task["tool_name_prefix"], is_generic_mcp=task.get("is_generic_mcp", False), + bypass_internal_hitl=bypass_internal_hitl, ), timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, ) @@ -1074,6 +1096,7 @@ async def load_mcp_tools( task["connector_name"], task["server_config"], trusted_tools=task["trusted_tools"], + bypass_internal_hitl=bypass_internal_hitl, ), timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, ) @@ -1095,14 +1118,17 @@ async def load_mcp_tools( results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) tools: list[StructuredTool] = [tool for sublist in results for tool in sublist] - _mcp_tools_cache[search_space_id] = (now, tools) + _mcp_tools_cache[cache_key] = (now, tools) if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE: oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) del _mcp_tools_cache[oldest_key] logger.info( - "Loaded %d MCP tools for search space %d", len(tools), search_space_id + "Loaded %d MCP tools for search space %d (bypass_hitl=%s)", + len(tools), + search_space_id, + bypass_internal_hitl, ) return tools diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py index e04f50815..61d9b499f 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py @@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None: assert sample == "plan" +def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None: + """Regression: MCP tools (e.g. ``createJiraIssue``) used to dedup on + the schema's first required field, which is often the workspace / + cloudId — so 3 distinct issues in the same workspace collapsed to 1. + + With :func:`dedup_key_full_args` only fully identical arg dicts dedup. + """ + from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args + + tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args) + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) + state = { + "messages": [ + _msg( + { + "name": "createJiraIssue", + "args": { + "cloudId": "ws.atlassian.net", + "projectKey": "PROJ", + "summary": "Fix login bug", + }, + "id": "1", + }, + { + "name": "createJiraIssue", + "args": { + "cloudId": "ws.atlassian.net", + "projectKey": "PROJ", + "summary": "Add dark mode", + }, + "id": "2", + }, + { + "name": "createJiraIssue", + "args": { + "cloudId": "ws.atlassian.net", + "projectKey": "PROJ", + "summary": "Improve perf", + }, + "id": "3", + }, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is None # nothing dropped — all three differ in summary + + +def test_full_args_dedup_drops_only_exact_duplicates() -> None: + from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args + + tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args) + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) + args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"} + state = { + "messages": [ + _msg( + {"name": "createJiraIssue", "args": args, "id": "1"}, + {"name": "createJiraIssue", "args": dict(args), "id": "2"}, + { + "name": "createJiraIssue", + "args": {**args, "summary": "Different"}, + "id": "3", + }, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is not None + new_calls = out["messages"][0].tool_calls + assert {c["id"] for c in new_calls} == {"1", "3"} + + def test_unknown_tool_passes_through() -> None: mw = DedupHITLToolCallsMiddleware(agent_tools=None) state = { diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py index a997c8d61..eda5be150 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py @@ -6,7 +6,10 @@ import pytest from langchain_core.messages import AIMessage, ToolMessage from app.agents.new_chat.errors import CorrectedError, RejectedError -from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.middleware.permission import ( + PermissionMiddleware, + _normalize_permission_decision, +) from app.agents.new_chat.permissions import Rule, Ruleset pytestmark = pytest.mark.unit @@ -112,3 +115,151 @@ class TestAsk: # Runtime ruleset got the always-allow rule new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] assert any(r.permission == "send_email" for r in new_rules) + + +class TestNormalizeDecision: + """Resume shapes ``_normalize_permission_decision`` must accept.""" + + def test_legacy_decision_type_dict_passes_through(self) -> None: + decision = {"decision_type": "once"} + assert _normalize_permission_decision(decision) == {"decision_type": "once"} + + def test_legacy_decision_type_with_feedback_passes_through(self) -> None: + decision = {"decision_type": "reject", "feedback": "no thanks"} + assert _normalize_permission_decision(decision) == decision + + def test_plain_string_wrapped(self) -> None: + assert _normalize_permission_decision("once") == {"decision_type": "once"} + assert _normalize_permission_decision("reject") == {"decision_type": "reject"} + + def test_lc_envelope_approve_maps_to_once(self) -> None: + decision = {"decisions": [{"type": "approve"}]} + assert _normalize_permission_decision(decision) == {"decision_type": "once"} + + def test_lc_envelope_reject_maps_to_reject(self) -> None: + decision = {"decisions": [{"type": "reject"}]} + assert _normalize_permission_decision(decision) == {"decision_type": "reject"} + + def test_lc_envelope_reject_with_message_carries_feedback(self) -> None: + decision = { + "decisions": [{"type": "reject", "message": "wrong recipient"}] + } + out = _normalize_permission_decision(decision) + assert out == {"decision_type": "reject", "feedback": "wrong recipient"} + + def test_lc_envelope_reject_with_feedback_field(self) -> None: + decision = { + "decisions": [{"type": "reject", "feedback": "tighten the subject"}] + } + out = _normalize_permission_decision(decision) + assert out == {"decision_type": "reject", "feedback": "tighten the subject"} + + def test_lc_envelope_edit_maps_to_once(self) -> None: + # Pins the contract: edited args are NOT merged by permission. + decision = { + "decisions": [ + { + "type": "edit", + "edited_action": { + "name": "send_email", + "args": {"subject": "edited"}, + }, + } + ] + } + assert _normalize_permission_decision(decision) == {"decision_type": "once"} + + def test_lc_single_decision_without_envelope(self) -> None: + assert _normalize_permission_decision({"type": "approve"}) == { + "decision_type": "once" + } + + def test_unknown_type_falls_back_to_reject(self) -> None: + decision = {"decisions": [{"type": "totally_unknown"}]} + assert _normalize_permission_decision(decision) == {"decision_type": "reject"} + + def test_missing_type_falls_back_to_reject(self) -> None: + assert _normalize_permission_decision({"decisions": [{}]}) == { + "decision_type": "reject" + } + + def test_non_dict_non_string_falls_back_to_reject(self) -> None: + assert _normalize_permission_decision(None) == {"decision_type": "reject"} + assert _normalize_permission_decision(42) == {"decision_type": "reject"} + + def test_empty_decisions_list_falls_back_to_reject(self) -> None: + # Fail-closed on a malformed reply rather than treat it as approve. + assert _normalize_permission_decision({"decisions": []}) == { + "decision_type": "reject" + } + + +class TestResumeShapesEndToEnd: + """LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``.""" + + def test_lc_approve_envelope_keeps_call(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] + "decisions": [{"type": "approve"}] + } + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + original = mw._raise_interrupt + mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] + original(**kw) + ) + out = mw.after_model(state, _FakeRuntime()) + assert out is None + + def test_lc_reject_envelope_raises(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + original = lambda **kw: {"decisions": [{"type": "reject"}]} # noqa: E731 + mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] + original(**kw) + ) + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + with pytest.raises(RejectedError): + mw.after_model(state, _FakeRuntime()) + + def test_lc_reject_with_message_raises_corrected(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + original = lambda **kw: { # noqa: E731 + "decisions": [{"type": "reject", "message": "wrong recipient"}] + } + mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] + original(**kw) + ) + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + with pytest.raises(CorrectedError) as excinfo: + mw.after_model(state, _FakeRuntime()) + assert excinfo.value.feedback == "wrong recipient" + + def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None: + # Pins the "edit -> once, args unchanged" contract. + mw = PermissionMiddleware(rulesets=[]) + original = lambda **kw: { # noqa: E731 + "decisions": [ + { + "type": "edit", + "edited_action": { + "name": "send_email", + "args": {"to": "edited@example.com"}, + }, + } + ] + } + mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] + original(**kw) + ) + state = { + "messages": [ + _msg( + { + "name": "send_email", + "args": {"to": "original@example.com"}, + "id": "1", + } + ) + ] + } + out = mw.after_model(state, _FakeRuntime()) + assert out is None diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index e5ac61cd9..21fc4cf1a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -146,6 +146,31 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un } } +/** + * Most recent pending tool-call card with this name, so a new HITL interrupt + * does not overwrite an already-approved card with the same tool name. + */ +function findHitlTargetToolCallId( + toolCallIndices: Map, + contentParts: Array<{ + type: string; + toolName?: string; + result?: unknown; + }>, + toolName: string +): string | null { + const entries = Array.from(toolCallIndices.entries()); + for (let i = entries.length - 1; i >= 0; i--) { + const [tcId, idx] = entries[i]; + const part = contentParts[idx]; + if (!part || part.type !== "tool-call" || part.toolName !== toolName) continue; + const result = part.result as Record | undefined | null; + if (result == null) return tcId; + if (result.__interrupt__ === true && !result.__decided__) return tcId; + } + return null; +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -949,12 +974,13 @@ export default function NewChatPage() { args: Record; }>; for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { + const targetTcId = findHitlTargetToolCallId( + toolCallIndices, + contentParts, + action.name + ); + if (targetTcId) { + updateToolCall(contentPartsState, targetTcId, { result: { __interrupt__: true, ...interruptData }, }); } else { @@ -1265,6 +1291,7 @@ export default function NewChatPage() { body: JSON.stringify({ search_space_id: searchSpaceId, decisions, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, local_filesystem_mounts: selection.local_filesystem_mounts, @@ -1388,12 +1415,13 @@ export default function NewChatPage() { args: Record; }>; for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { + const targetTcId = findHitlTargetToolCallId( + toolCallIndices, + contentParts, + action.name + ); + if (targetTcId) { + updateToolCall(contentPartsState, targetTcId, { result: { __interrupt__: true, ...interruptData, @@ -1514,6 +1542,25 @@ export default function NewChatPage() { const decision = detail.decisions[0]; const decisionType = decision?.type as "approve" | "reject" | "edit"; + // Fan a single click out to N decisions when the backend bundled + // N tool calls into one HITLRequest (one Approve/Reject covers + // the whole batch until per-card decisions land). + const interruptData = pendingInterrupt.interruptData as + | { action_requests?: unknown[] } + | undefined; + const expectedCount = Array.isArray(interruptData?.action_requests) + ? interruptData.action_requests.length + : detail.decisions.length; + const submittedDecisions = + detail.decisions.length >= expectedCount || expectedCount <= 1 + ? detail.decisions + : [ + ...detail.decisions, + ...Array.from({ length: expectedCount - detail.decisions.length }, () => ({ + ...detail.decisions[detail.decisions.length - 1], + })), + ]; + setMessages((prev) => prev.map((m) => { if (m.id !== pendingInterrupt.assistantMsgId) return m; @@ -1554,7 +1601,7 @@ export default function NewChatPage() { return { ...m, content: newContent as unknown as ThreadMessageLike["content"] }; }) ); - handleResume(detail.decisions); + handleResume(submittedDecisions); } }; window.addEventListener("hitl-decision", handler);