Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.

This commit is contained in:
CREDO23 2026-05-04 19:25:27 +02:00
parent 4ac3f0b304
commit 277bd50f37
6 changed files with 442 additions and 65 deletions

View file

@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
return _resolver return _resolver
def dedup_key_full_args(args: dict[str, Any]) -> str:
    """Resolver that collapses calls only when **every** argument is identical.

    Safe default for tools where no single field uniquely identifies a call
    (e.g. MCP tools whose first required field is a shared workspace id).

    Args:
        args: The tool-call argument mapping to derive a dedup key from.

    Returns:
        A deterministic string key; identical arg dicts yield identical keys.
    """
    try:
        # ``default=str`` stringifies values json cannot serialize natively.
        return json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # json.dumps can still fail on circular references (ValueError) or on
        # unorderable mixed-type keys under ``sort_keys`` (TypeError). Sort the
        # fallback by ``repr`` so unorderable keys cannot raise a second
        # TypeError inside the fallback itself.
        if isinstance(args, dict):
            return repr(sorted(args.items(), key=repr))
        return repr(args)
# Backwards-compatible alias for code that imported the original # Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. # private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name _wrap_string_key = wrap_dedup_key_by_arg_name

View file

@ -19,8 +19,9 @@ Operation:
the results: ``deny`` > ``ask`` > ``allow``. the results: ``deny`` > ``ask`` > ``allow``.
3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` 3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
containing a :class:`StreamingError`. containing a :class:`StreamingError`.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply 4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``. SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
replies are accepted via :func:`_normalize_permission_decision`.
- ``once``: proceed. - ``once``: proceed.
- ``always``: also persist allow rules for ``request.always`` patterns. - ``always``: also persist allow rules for ``request.always`` patterns.
- ``reject`` w/o feedback: raise :class:`RejectedError`. - ``reject`` w/o feedback: raise :class:`RejectedError`.
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
return _resolve return _resolve
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
# original tool args — tools needing argument edits should use
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
"approve": "once",
"reject": "reject",
"edit": "once",
}
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
middleware fails closed.
"""
if isinstance(decision, str):
return {"decision_type": decision}
if not isinstance(decision, dict):
logger.warning(
"Unrecognized permission resume value (%s); treating as reject",
type(decision).__name__,
)
return {"decision_type": "reject"}
if decision.get("decision_type"):
return decision
payload: dict[str, Any] = decision
decisions = decision.get("decisions")
if isinstance(decisions, list) and decisions:
first = decisions[0]
if isinstance(first, dict):
payload = first
raw_type = payload.get("type") or payload.get("decision_type")
if not raw_type:
logger.warning(
"Permission resume missing decision type (keys=%s); treating as reject",
list(payload.keys()),
)
return {"decision_type": "reject"}
raw_type = str(raw_type).lower()
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
if mapped is None:
# Tolerate legacy values arriving without ``decision_type`` wrapping.
if raw_type in {"once", "always", "reject"}:
mapped = raw_type
else:
logger.warning(
"Unknown permission decision type %r; treating as reject", raw_type
)
mapped = "reject"
if raw_type == "edit":
logger.warning(
"Permission middleware received an 'edit' decision; original args "
"kept (edits not merged here)."
)
out: dict[str, Any] = {"decision_type": mapped}
feedback = payload.get("feedback") or payload.get("message")
if isinstance(feedback, str) and feedback.strip():
out["feedback"] = feedback
return out
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls. """Allow/deny/ask layer over the agent's tool calls.
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
ot.interrupt_span(interrupt_type="permission_ask"), ot.interrupt_span(interrupt_type="permission_ask"),
): ):
decision = interrupt(payload) decision = interrupt(payload)
if isinstance(decision, dict): return _normalize_permission_decision(decision)
return decision
# Tolerate a plain string reply ("once", "always", "reject")
if isinstance(decision, str):
return {"decision_type": decision}
return {"decision_type": "reject"}
def _persist_always(self, tool_name: str, patterns: list[str]) -> None: def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
"""Promote ``always`` reply into runtime allow rules. """Promote ``always`` reply into runtime allow rules.
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
__all__ = [ __all__ = [
"PatternResolver", "PatternResolver",
"PermissionMiddleware", "PermissionMiddleware",
"_normalize_permission_decision",
] ]

View file

@ -33,6 +33,7 @@ from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector from app.db import SearchSourceConnector
@ -45,7 +46,10 @@ _MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30 _MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_TOOL_CALL_MAX_RETRIES = 3 _TOOL_CALL_MAX_RETRIES = 3
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt _TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} # Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and
# multi-agent paths cannot share tool closures with different HITL wiring.
_MCPCacheKey = tuple[int, bool]
_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {}
def _evict_expired_mcp_cache() -> None: def _evict_expired_mcp_cache() -> None:
@ -137,12 +141,13 @@ async def _create_mcp_tool_from_definition_stdio(
connector_name: str = "", connector_name: str = "",
connector_id: int | None = None, connector_id: int | None = None,
trusted_tools: list[str] | None = None, trusted_tools: list[str] | None = None,
bypass_internal_hitl: bool = False,
) -> StructuredTool: ) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (stdio transport). """Create a LangChain tool from an MCP tool definition (stdio transport).
All MCP tools are unconditionally wrapped with HITL approval. Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
``request_approval()`` is called OUTSIDE the try/except so that already gates the tool, otherwise the body's ``request_approval()`` is the
``GraphInterrupt`` propagates cleanly to LangGraph. sole HITL gate (single-agent path).
""" """
tool_name = tool_def.get("name", "unnamed_tool") tool_name = tool_def.get("name", "unnamed_tool")
raw_description = tool_def.get("description", "No description provided") raw_description = tool_def.get("description", "No description provided")
@ -161,24 +166,29 @@ async def _create_mcp_tool_from_definition_stdio(
"""Execute the MCP tool call via the client with retry support.""" """Execute the MCP tool call via the client with retry support."""
logger.debug("MCP tool '%s' called", tool_name) logger.debug("MCP tool '%s' called", tool_name)
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph if bypass_internal_hitl:
hitl_result = request_approval( call_kwargs = _unpack_synthetic_input_data(
action_type="mcp_tool_call", {k: v for k, v in kwargs.items() if v is not None}
tool_name=tool_name, )
params=kwargs, else:
context={ # Outside try/except so ``GraphInterrupt`` propagates to LangGraph.
"mcp_server": connector_name, hitl_result = request_approval(
"tool_description": raw_description, action_type="mcp_tool_call",
"mcp_transport": "stdio", tool_name=tool_name,
"mcp_connector_id": connector_id, params=kwargs,
}, context={
trusted_tools=trusted_tools, "mcp_server": connector_name,
) "tool_description": raw_description,
if hitl_result.rejected: "mcp_transport": "stdio",
return "Tool call rejected by user." "mcp_connector_id": connector_id,
call_kwargs = _unpack_synthetic_input_data( },
{k: v for k, v in hitl_result.params.items() if v is not None} trusted_tools=trusted_tools,
) )
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
last_error: Exception | None = None last_error: Exception | None = None
for attempt in range(_TOOL_CALL_MAX_RETRIES): for attempt in range(_TOOL_CALL_MAX_RETRIES):
@ -221,7 +231,9 @@ async def _create_mcp_tool_from_definition_stdio(
"mcp_connector_name": connector_name or None, "mcp_connector_name": connector_name or None,
"mcp_is_generic": True, "mcp_is_generic": True,
"hitl": True, "hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None), # Full-args hash: shared identifiers (cloudId, workspaceId, …)
# would otherwise collapse legitimate batches.
"dedup_key": dedup_key_full_args,
}, },
) )
@ -240,11 +252,14 @@ async def _create_mcp_tool_from_definition_http(
readonly_tools: frozenset[str] | None = None, readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None, tool_name_prefix: str | None = None,
is_generic_mcp: bool = False, is_generic_mcp: bool = False,
bypass_internal_hitl: bool = False,
) -> StructuredTool: ) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport). """Create a LangChain tool from an MCP tool definition (HTTP transport).
Write tools are wrapped with HITL approval; read-only tools (listed in Write tools are wrapped with HITL approval; read-only tools (listed in
``readonly_tools``) execute immediately without user confirmation. ``readonly_tools``) execute immediately without user confirmation. Set
``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
already gates the tool.
When ``tool_name_prefix`` is set (multi-account disambiguation), the When ``tool_name_prefix`` is set (multi-account disambiguation), the
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``) tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
@ -302,7 +317,7 @@ async def _create_mcp_tool_from_definition_http(
"""Execute the MCP tool call via HTTP transport.""" """Execute the MCP tool call via HTTP transport."""
logger.debug("MCP HTTP tool '%s' called", exposed_name) logger.debug("MCP HTTP tool '%s' called", exposed_name)
if is_readonly: if is_readonly or bypass_internal_hitl:
call_kwargs = _unpack_synthetic_input_data( call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None} {k: v for k, v in kwargs.items() if v is not None}
) )
@ -385,7 +400,9 @@ async def _create_mcp_tool_from_definition_http(
"mcp_connector_name": connector_name or None, "mcp_connector_name": connector_name or None,
"mcp_is_generic": is_generic_mcp, "mcp_is_generic": is_generic_mcp,
"hitl": not is_readonly, "hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None), # Full-args hash: shared identifiers (cloudId, workspaceId, …)
# would otherwise collapse legitimate batches.
"dedup_key": dedup_key_full_args,
"mcp_original_tool_name": original_tool_name, "mcp_original_tool_name": original_tool_name,
"mcp_connector_id": connector_id, "mcp_connector_id": connector_id,
}, },
@ -400,6 +417,8 @@ async def _load_stdio_mcp_tools(
connector_name: str, connector_name: str,
server_config: dict[str, Any], server_config: dict[str, Any],
trusted_tools: list[str] | None = None, trusted_tools: list[str] | None = None,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]: ) -> list[StructuredTool]:
"""Load tools from a stdio-based MCP server.""" """Load tools from a stdio-based MCP server."""
tools: list[StructuredTool] = [] tools: list[StructuredTool] = []
@ -451,6 +470,7 @@ async def _load_stdio_mcp_tools(
connector_name=connector_name, connector_name=connector_name,
connector_id=connector_id, connector_id=connector_id,
trusted_tools=trusted_tools, trusted_tools=trusted_tools,
bypass_internal_hitl=bypass_internal_hitl,
) )
tools.append(tool) tools.append(tool)
except Exception as e: except Exception as e:
@ -473,6 +493,8 @@ async def _load_http_mcp_tools(
readonly_tools: frozenset[str] | None = None, readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None, tool_name_prefix: str | None = None,
is_generic_mcp: bool = False, is_generic_mcp: bool = False,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]: ) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server. """Load tools from an HTTP-based MCP server.
@ -598,6 +620,7 @@ async def _load_http_mcp_tools(
readonly_tools=readonly_tools, readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix, tool_name_prefix=tool_name_prefix,
is_generic_mcp=is_generic_mcp, is_generic_mcp=is_generic_mcp,
bypass_internal_hitl=bypass_internal_hitl,
) )
tools.append(tool) tools.append(tool)
except Exception as e: except Exception as e:
@ -905,14 +928,10 @@ async def _mark_connector_auth_expired(connector_id: int) -> None:
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools. """Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together)."""
Args:
search_space_id: If provided, only invalidate for this search space.
If None, invalidate all cached MCP tools.
"""
if search_space_id is not None: if search_space_id is not None:
_mcp_tools_cache.pop(search_space_id, None) for key in [k for k in _mcp_tools_cache if k[0] == search_space_id]:
_mcp_tools_cache.pop(key, None)
else: else:
_mcp_tools_cache.clear() _mcp_tools_cache.clear()
@ -920,27 +939,29 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
async def load_mcp_tools( async def load_mcp_tools(
session: AsyncSession, session: AsyncSession,
search_space_id: int, search_space_id: int,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]: ) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors. """Load all MCP tools from the user's active MCP server connectors.
This discovers tools dynamically from MCP servers using the protocol. Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up
Supports both stdio (local process) and HTTP (remote server) transports. to 5 minutes; bypass is keyed because each variant builds a different tool
closure (with vs. without the in-wrapper ``request_approval`` gate).
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
""" """
_evict_expired_mcp_cache() _evict_expired_mcp_cache()
now = time.monotonic() now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id) cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl)
cached = _mcp_tools_cache.get(cache_key)
if cached is not None: if cached is not None:
cached_at, cached_tools = cached cached_at, cached_tools = cached
if now - cached_at < _MCP_CACHE_TTL_SECONDS: if now - cached_at < _MCP_CACHE_TTL_SECONDS:
logger.info( logger.info(
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)", "Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)",
search_space_id, search_space_id,
len(cached_tools), len(cached_tools),
now - cached_at, now - cached_at,
bypass_internal_hitl,
) )
return list(cached_tools) return list(cached_tools)
@ -1064,6 +1085,7 @@ async def load_mcp_tools(
readonly_tools=task["readonly_tools"], readonly_tools=task["readonly_tools"],
tool_name_prefix=task["tool_name_prefix"], tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False), is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl,
), ),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
) )
@ -1074,6 +1096,7 @@ async def load_mcp_tools(
task["connector_name"], task["connector_name"],
task["server_config"], task["server_config"],
trusted_tools=task["trusted_tools"], trusted_tools=task["trusted_tools"],
bypass_internal_hitl=bypass_internal_hitl,
), ),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
) )
@ -1095,14 +1118,17 @@ async def load_mcp_tools(
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist] tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
_mcp_tools_cache[search_space_id] = (now, tools) _mcp_tools_cache[cache_key] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE: if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key] del _mcp_tools_cache[oldest_key]
logger.info( logger.info(
"Loaded %d MCP tools for search space %d", len(tools), search_space_id "Loaded %d MCP tools for search space %d (bypass_hitl=%s)",
len(tools),
search_space_id,
bypass_internal_hitl,
) )
return tools return tools

View file

@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
assert sample == "plan" assert sample == "plan"
def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None:
    """Regression: MCP tools (e.g. ``createJiraIssue``) used to dedup on
    the schema's first required field, which is often the workspace /
    cloudId so 3 distinct issues in the same workspace collapsed to 1.

    With :func:`dedup_key_full_args` only fully identical arg dicts dedup.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])

    # Three calls identical except for ``summary`` — none may be dropped.
    summaries = ["Fix login bug", "Add dark mode", "Improve perf"]
    calls = [
        {
            "name": "createJiraIssue",
            "args": {
                "cloudId": "ws.atlassian.net",
                "projectKey": "PROJ",
                "summary": summary,
            },
            "id": str(i),
        }
        for i, summary in enumerate(summaries, start=1)
    ]
    state = {"messages": [_msg(*calls)]}

    out = mw.after_model(state, _Runtime())
    assert out is None  # nothing dropped — all three differ in summary
def test_full_args_dedup_drops_only_exact_duplicates() -> None:
    """Only a byte-identical arg dict counts as a duplicate under full-args dedup."""
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])

    base_args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"}
    state = {
        "messages": [
            _msg(
                {"name": "createJiraIssue", "args": base_args, "id": "1"},
                # Exact copy of call "1" — this one must be dropped.
                {"name": "createJiraIssue", "args": dict(base_args), "id": "2"},
                {
                    "name": "createJiraIssue",
                    "args": {**base_args, "summary": "Different"},
                    "id": "3",
                },
            )
        ]
    }

    out = mw.after_model(state, _Runtime())
    assert out is not None
    surviving = out["messages"][0].tool_calls
    assert {call["id"] for call in surviving} == {"1", "3"}
def test_unknown_tool_passes_through() -> None: def test_unknown_tool_passes_through() -> None:
mw = DedupHITLToolCallsMiddleware(agent_tools=None) mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = { state = {

View file

@ -6,7 +6,10 @@ import pytest
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.errors import CorrectedError, RejectedError from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.middleware.permission import PermissionMiddleware from app.agents.new_chat.middleware.permission import (
PermissionMiddleware,
_normalize_permission_decision,
)
from app.agents.new_chat.permissions import Rule, Ruleset from app.agents.new_chat.permissions import Rule, Ruleset
pytestmark = pytest.mark.unit pytestmark = pytest.mark.unit
@ -112,3 +115,151 @@ class TestAsk:
# Runtime ruleset got the always-allow rule # Runtime ruleset got the always-allow rule
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
assert any(r.permission == "send_email" for r in new_rules) assert any(r.permission == "send_email" for r in new_rules)
class TestNormalizeDecision:
    """Resume shapes ``_normalize_permission_decision`` must accept."""

    def test_legacy_decision_type_dict_passes_through(self) -> None:
        out = _normalize_permission_decision({"decision_type": "once"})
        assert out == {"decision_type": "once"}

    def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
        legacy = {"decision_type": "reject", "feedback": "no thanks"}
        assert _normalize_permission_decision(legacy) == legacy

    def test_plain_string_wrapped(self) -> None:
        for value in ("once", "reject"):
            assert _normalize_permission_decision(value) == {"decision_type": value}

    def test_lc_envelope_approve_maps_to_once(self) -> None:
        envelope = {"decisions": [{"type": "approve"}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "once"}

    def test_lc_envelope_reject_maps_to_reject(self) -> None:
        envelope = {"decisions": [{"type": "reject"}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "reject"}

    def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
        envelope = {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        out = _normalize_permission_decision(envelope)
        assert out == {"decision_type": "reject", "feedback": "wrong recipient"}

    def test_lc_envelope_reject_with_feedback_field(self) -> None:
        envelope = {"decisions": [{"type": "reject", "feedback": "tighten the subject"}]}
        out = _normalize_permission_decision(envelope)
        assert out == {"decision_type": "reject", "feedback": "tighten the subject"}

    def test_lc_envelope_edit_maps_to_once(self) -> None:
        # Pins the contract: edited args are NOT merged by permission.
        edited = {"name": "send_email", "args": {"subject": "edited"}}
        envelope = {"decisions": [{"type": "edit", "edited_action": edited}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "once"}

    def test_lc_single_decision_without_envelope(self) -> None:
        out = _normalize_permission_decision({"type": "approve"})
        assert out == {"decision_type": "once"}

    def test_unknown_type_falls_back_to_reject(self) -> None:
        envelope = {"decisions": [{"type": "totally_unknown"}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "reject"}

    def test_missing_type_falls_back_to_reject(self) -> None:
        out = _normalize_permission_decision({"decisions": [{}]})
        assert out == {"decision_type": "reject"}

    def test_non_dict_non_string_falls_back_to_reject(self) -> None:
        for bad in (None, 42):
            assert _normalize_permission_decision(bad) == {"decision_type": "reject"}

    def test_empty_decisions_list_falls_back_to_reject(self) -> None:
        # Fail-closed on a malformed reply rather than treat it as approve.
        out = _normalize_permission_decision({"decisions": []})
        assert out == {"decision_type": "reject"}
class TestResumeShapesEndToEnd:
    """LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``."""

    @staticmethod
    def _stub_interrupt(mw: PermissionMiddleware, reply: dict) -> None:
        # Replace the interrupt with a canned LC-envelope reply, normalized
        # exactly as the production code path does.
        mw._raise_interrupt = lambda **kw: _normalize_permission_decision(  # type: ignore[assignment]
            reply
        )

    def test_lc_approve_envelope_keeps_call(self) -> None:
        mw = PermissionMiddleware(rulesets=[])
        self._stub_interrupt(mw, {"decisions": [{"type": "approve"}]})
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        assert mw.after_model(state, _FakeRuntime()) is None

    def test_lc_reject_envelope_raises(self) -> None:
        mw = PermissionMiddleware(rulesets=[])
        self._stub_interrupt(mw, {"decisions": [{"type": "reject"}]})
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

    def test_lc_reject_with_message_raises_corrected(self) -> None:
        mw = PermissionMiddleware(rulesets=[])
        self._stub_interrupt(
            mw, {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(CorrectedError) as excinfo:
            mw.after_model(state, _FakeRuntime())
        assert excinfo.value.feedback == "wrong recipient"

    def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
        # Pins the "edit -> once, args unchanged" contract.
        mw = PermissionMiddleware(rulesets=[])
        self._stub_interrupt(
            mw,
            {
                "decisions": [
                    {
                        "type": "edit",
                        "edited_action": {
                            "name": "send_email",
                            "args": {"to": "edited@example.com"},
                        },
                    }
                ]
            },
        )
        state = {
            "messages": [
                _msg(
                    {
                        "name": "send_email",
                        "args": {"to": "original@example.com"},
                        "id": "1",
                    }
                )
            ]
        }
        assert mw.after_model(state, _FakeRuntime()) is None

View file

@ -146,6 +146,31 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un
} }
} }
/**
 * Most recent pending tool-call card with this name, so a new HITL interrupt
 * does not overwrite an already-approved card with the same tool name.
 */
function findHitlTargetToolCallId(
	toolCallIndices: Map<string, number>,
	contentParts: Array<{
		type: string;
		toolName?: string;
		result?: unknown;
	}>,
	toolName: string
): string | null {
	// Walk newest-first so the latest matching card wins.
	const newestFirst = [...toolCallIndices.entries()].reverse();
	for (const [tcId, idx] of newestFirst) {
		const part = contentParts[idx];
		if (!part || part.type !== "tool-call" || part.toolName !== toolName) {
			continue;
		}
		const result = part.result as Record<string, unknown> | undefined | null;
		// No result yet: the card is still pending — reuse it.
		if (result == null) return tcId;
		// An interrupt card the user has not decided on may also be re-targeted.
		if (result.__interrupt__ === true && !result.__decided__) return tcId;
	}
	return null;
}
/** /**
* Zod schema for mentioned document info (for type-safe parsing) * Zod schema for mentioned document info (for type-safe parsing)
*/ */
@ -949,12 +974,13 @@ export default function NewChatPage() {
args: Record<string, unknown>; args: Record<string, unknown>;
}>; }>;
for (const action of actionRequests) { for (const action of actionRequests) {
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { const targetTcId = findHitlTargetToolCallId(
const part = contentParts[idx]; toolCallIndices,
return part?.type === "tool-call" && part.toolName === action.name; contentParts,
}); action.name
if (existingIdx) { );
updateToolCall(contentPartsState, existingIdx[0], { if (targetTcId) {
updateToolCall(contentPartsState, targetTcId, {
result: { __interrupt__: true, ...interruptData }, result: { __interrupt__: true, ...interruptData },
}); });
} else { } else {
@ -1265,6 +1291,7 @@ export default function NewChatPage() {
body: JSON.stringify({ body: JSON.stringify({
search_space_id: searchSpaceId, search_space_id: searchSpaceId,
decisions, decisions,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
filesystem_mode: selection.filesystem_mode, filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform, client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts, local_filesystem_mounts: selection.local_filesystem_mounts,
@ -1388,12 +1415,13 @@ export default function NewChatPage() {
args: Record<string, unknown>; args: Record<string, unknown>;
}>; }>;
for (const action of actionRequests) { for (const action of actionRequests) {
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { const targetTcId = findHitlTargetToolCallId(
const part = contentParts[idx]; toolCallIndices,
return part?.type === "tool-call" && part.toolName === action.name; contentParts,
}); action.name
if (existingIdx) { );
updateToolCall(contentPartsState, existingIdx[0], { if (targetTcId) {
updateToolCall(contentPartsState, targetTcId, {
result: { result: {
__interrupt__: true, __interrupt__: true,
...interruptData, ...interruptData,
@ -1514,6 +1542,25 @@ export default function NewChatPage() {
const decision = detail.decisions[0]; const decision = detail.decisions[0];
const decisionType = decision?.type as "approve" | "reject" | "edit"; const decisionType = decision?.type as "approve" | "reject" | "edit";
// Fan a single click out to N decisions when the backend bundled
// N tool calls into one HITLRequest (one Approve/Reject covers
// the whole batch until per-card decisions land).
const interruptData = pendingInterrupt.interruptData as
| { action_requests?: unknown[] }
| undefined;
const expectedCount = Array.isArray(interruptData?.action_requests)
? interruptData.action_requests.length
: detail.decisions.length;
const submittedDecisions =
detail.decisions.length >= expectedCount || expectedCount <= 1
? detail.decisions
: [
...detail.decisions,
...Array.from({ length: expectedCount - detail.decisions.length }, () => ({
...detail.decisions[detail.decisions.length - 1],
})),
];
setMessages((prev) => setMessages((prev) =>
prev.map((m) => { prev.map((m) => {
if (m.id !== pendingInterrupt.assistantMsgId) return m; if (m.id !== pendingInterrupt.assistantMsgId) return m;
@ -1554,7 +1601,7 @@ export default function NewChatPage() {
return { ...m, content: newContent as unknown as ThreadMessageLike["content"] }; return { ...m, content: newContent as unknown as ThreadMessageLike["content"] };
}) })
); );
handleResume(detail.decisions); handleResume(submittedDecisions);
} }
}; };
window.addEventListener("hitl-decision", handler); window.addEventListener("hitl-decision", handler);