Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.

This commit is contained in:
CREDO23 2026-05-04 19:25:27 +02:00
parent 4ac3f0b304
commit 277bd50f37
6 changed files with 442 additions and 65 deletions

View file

@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import Any
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
return _resolver
def dedup_key_full_args(args: dict[str, Any]) -> str:
    """Resolver that collapses calls only when **every** argument is identical.
    Safe default for tools where no single field uniquely identifies a call
    (e.g. MCP tools whose first required field is a shared workspace id).
    """
    try:
        # Canonical JSON: key order normalized, unserializable values stringified.
        return json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # ``json`` can still reject the payload (non-str/int keys, mixed-type
        # keys under ``sort_keys``, circular refs). Sorting items by the key's
        # ``repr`` keeps this fallback deterministic without requiring the keys
        # to be mutually orderable — plain ``sorted(args.items())`` would raise
        # TypeError for e.g. ``{1: ..., "b": ...}``.
        if isinstance(args, dict):
            return repr(sorted(args.items(), key=lambda item: repr(item[0])))
        return repr(args)
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name

View file

@ -19,8 +19,9 @@ Operation:
the results: ``deny`` > ``ask`` > ``allow``.
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
containing a :class:`StreamingError`.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
replies are accepted via :func:`_normalize_permission_decision`.
- ``once``: proceed.
- ``always``: also persist allow rules for ``request.always`` patterns.
- ``reject`` w/o feedback: raise :class:`RejectedError`.
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
return _resolve
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
# original tool args — tools needing argument edits should use
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
"approve": "once",
"reject": "reject",
"edit": "once",
}
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
middleware fails closed.
"""
if isinstance(decision, str):
return {"decision_type": decision}
if not isinstance(decision, dict):
logger.warning(
"Unrecognized permission resume value (%s); treating as reject",
type(decision).__name__,
)
return {"decision_type": "reject"}
if decision.get("decision_type"):
return decision
payload: dict[str, Any] = decision
decisions = decision.get("decisions")
if isinstance(decisions, list) and decisions:
first = decisions[0]
if isinstance(first, dict):
payload = first
raw_type = payload.get("type") or payload.get("decision_type")
if not raw_type:
logger.warning(
"Permission resume missing decision type (keys=%s); treating as reject",
list(payload.keys()),
)
return {"decision_type": "reject"}
raw_type = str(raw_type).lower()
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
if mapped is None:
# Tolerate legacy values arriving without ``decision_type`` wrapping.
if raw_type in {"once", "always", "reject"}:
mapped = raw_type
else:
logger.warning(
"Unknown permission decision type %r; treating as reject", raw_type
)
mapped = "reject"
if raw_type == "edit":
logger.warning(
"Permission middleware received an 'edit' decision; original args "
"kept (edits not merged here)."
)
out: dict[str, Any] = {"decision_type": mapped}
feedback = payload.get("feedback") or payload.get("message")
if isinstance(feedback, str) and feedback.strip():
out["feedback"] = feedback
return out
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload)
if isinstance(decision, dict):
return decision
# Tolerate a plain string reply ("once", "always", "reject")
if isinstance(decision, str):
return {"decision_type": decision}
return {"decision_type": "reject"}
return _normalize_permission_decision(decision)
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
"""Promote ``always`` reply into runtime allow rules.
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
__all__ = [
"PatternResolver",
"PermissionMiddleware",
"_normalize_permission_decision",
]

View file

@ -33,6 +33,7 @@ from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector
@ -45,7 +46,10 @@ _MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_TOOL_CALL_MAX_RETRIES = 3
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
# Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and
# multi-agent paths cannot share tool closures with different HITL wiring.
_MCPCacheKey = tuple[int, bool]
_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {}
def _evict_expired_mcp_cache() -> None:
@ -137,12 +141,13 @@ async def _create_mcp_tool_from_definition_stdio(
connector_name: str = "",
connector_id: int | None = None,
trusted_tools: list[str] | None = None,
bypass_internal_hitl: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (stdio transport).
All MCP tools are unconditionally wrapped with HITL approval.
``request_approval()`` is called OUTSIDE the try/except so that
``GraphInterrupt`` propagates cleanly to LangGraph.
Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
already gates the tool, otherwise the body's ``request_approval()`` is the
sole HITL gate (single-agent path).
"""
tool_name = tool_def.get("name", "unnamed_tool")
raw_description = tool_def.get("description", "No description provided")
@ -161,24 +166,29 @@ async def _create_mcp_tool_from_definition_stdio(
"""Execute the MCP tool call via the client with retry support."""
logger.debug("MCP tool '%s' called", tool_name)
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
if bypass_internal_hitl:
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
else:
# Outside try/except so ``GraphInterrupt`` propagates to LangGraph.
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
last_error: Exception | None = None
for attempt in range(_TOOL_CALL_MAX_RETRIES):
@ -221,7 +231,9 @@ async def _create_mcp_tool_from_definition_stdio(
"mcp_connector_name": connector_name or None,
"mcp_is_generic": True,
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
# would otherwise collapse legitimate batches.
"dedup_key": dedup_key_full_args,
},
)
@ -240,11 +252,14 @@ async def _create_mcp_tool_from_definition_http(
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
bypass_internal_hitl: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
Write tools are wrapped with HITL approval; read-only tools (listed in
``readonly_tools``) execute immediately without user confirmation.
``readonly_tools``) execute immediately without user confirmation. Set
``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
already gates the tool.
When ``tool_name_prefix`` is set (multi-account disambiguation), the
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
@ -302,7 +317,7 @@ async def _create_mcp_tool_from_definition_http(
"""Execute the MCP tool call via HTTP transport."""
logger.debug("MCP HTTP tool '%s' called", exposed_name)
if is_readonly:
if is_readonly or bypass_internal_hitl:
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
@ -385,7 +400,9 @@ async def _create_mcp_tool_from_definition_http(
"mcp_connector_name": connector_name or None,
"mcp_is_generic": is_generic_mcp,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
# would otherwise collapse legitimate batches.
"dedup_key": dedup_key_full_args,
"mcp_original_tool_name": original_tool_name,
"mcp_connector_id": connector_id,
},
@ -400,6 +417,8 @@ async def _load_stdio_mcp_tools(
connector_name: str,
server_config: dict[str, Any],
trusted_tools: list[str] | None = None,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]:
"""Load tools from a stdio-based MCP server."""
tools: list[StructuredTool] = []
@ -451,6 +470,7 @@ async def _load_stdio_mcp_tools(
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
bypass_internal_hitl=bypass_internal_hitl,
)
tools.append(tool)
except Exception as e:
@ -473,6 +493,8 @@ async def _load_http_mcp_tools(
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -598,6 +620,7 @@ async def _load_http_mcp_tools(
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
is_generic_mcp=is_generic_mcp,
bypass_internal_hitl=bypass_internal_hitl,
)
tools.append(tool)
except Exception as e:
@ -905,14 +928,10 @@ async def _mark_connector_auth_expired(connector_id: int) -> None:
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
Args:
search_space_id: If provided, only invalidate for this search space.
If None, invalidate all cached MCP tools.
"""
"""Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together)."""
if search_space_id is not None:
_mcp_tools_cache.pop(search_space_id, None)
for key in [k for k in _mcp_tools_cache if k[0] == search_space_id]:
_mcp_tools_cache.pop(key, None)
else:
_mcp_tools_cache.clear()
@ -920,27 +939,29 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors.
"""Load all MCP tools from the user's active MCP server connectors.
This discovers tools dynamically from MCP servers using the protocol.
Supports both stdio (local process) and HTTP (remote server) transports.
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up
to 5 minutes; bypass is keyed because each variant builds a different tool
closure (with vs. without the in-wrapper ``request_approval`` gate).
"""
_evict_expired_mcp_cache()
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl)
cached = _mcp_tools_cache.get(cache_key)
if cached is not None:
cached_at, cached_tools = cached
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
logger.info(
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
"Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)",
search_space_id,
len(cached_tools),
now - cached_at,
bypass_internal_hitl,
)
return list(cached_tools)
@ -1064,6 +1085,7 @@ async def load_mcp_tools(
readonly_tools=task["readonly_tools"],
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl,
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
@ -1074,6 +1096,7 @@ async def load_mcp_tools(
task["connector_name"],
task["server_config"],
trusted_tools=task["trusted_tools"],
bypass_internal_hitl=bypass_internal_hitl,
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
@ -1095,14 +1118,17 @@ async def load_mcp_tools(
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
_mcp_tools_cache[search_space_id] = (now, tools)
_mcp_tools_cache[cache_key] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key]
logger.info(
"Loaded %d MCP tools for search space %d", len(tools), search_space_id
"Loaded %d MCP tools for search space %d (bypass_hitl=%s)",
len(tools),
search_space_id,
bypass_internal_hitl,
)
return tools

View file

@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
assert sample == "plan"
def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None:
    """Regression guard for first-required-field dedup.
    MCP tools (e.g. ``createJiraIssue``) used to dedup on the schema's first
    required field — frequently a shared workspace / cloudId — so three
    distinct issues in one workspace collapsed into a single call. With
    :func:`dedup_key_full_args` only fully identical arg dicts dedup.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
    summaries = ["Fix login bug", "Add dark mode", "Improve perf"]
    # Same workspace + project, distinct summaries — three separate calls.
    calls = [
        {
            "name": "createJiraIssue",
            "args": {
                "cloudId": "ws.atlassian.net",
                "projectKey": "PROJ",
                "summary": summary,
            },
            "id": str(idx),
        }
        for idx, summary in enumerate(summaries, start=1)
    ]
    state = {"messages": [_msg(*calls)]}
    # Nothing dropped — every call differs in ``summary``.
    assert mw.after_model(state, _Runtime()) is None
def test_full_args_dedup_drops_only_exact_duplicates() -> None:
    """Byte-identical arg dicts collapse; any differing field survives."""
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
    base_args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"}
    state = {
        "messages": [
            _msg(
                {"name": "createJiraIssue", "args": base_args, "id": "1"},
                # Exact copy of call "1" — the only one that must be dropped.
                {"name": "createJiraIssue", "args": dict(base_args), "id": "2"},
                {
                    "name": "createJiraIssue",
                    "args": {**base_args, "summary": "Different"},
                    "id": "3",
                },
            )
        ]
    }
    out = mw.after_model(state, _Runtime())
    assert out is not None
    surviving = {call["id"] for call in out["messages"][0].tool_calls}
    assert surviving == {"1", "3"}
def test_unknown_tool_passes_through() -> None:
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = {

View file

@ -6,7 +6,10 @@ import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.middleware.permission import (
PermissionMiddleware,
_normalize_permission_decision,
)
from app.agents.new_chat.permissions import Rule, Ruleset
pytestmark = pytest.mark.unit
@ -112,3 +115,151 @@ class TestAsk:
# Runtime ruleset got the always-allow rule
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
assert any(r.permission == "send_email" for r in new_rules)
class TestNormalizeDecision:
    """Resume shapes ``_normalize_permission_decision`` must accept."""

    def test_legacy_decision_type_dict_passes_through(self) -> None:
        assert _normalize_permission_decision({"decision_type": "once"}) == {
            "decision_type": "once"
        }

    def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
        legacy = {"decision_type": "reject", "feedback": "no thanks"}
        assert _normalize_permission_decision(legacy) == legacy

    def test_plain_string_wrapped(self) -> None:
        # Bare strings are shorthand for the legacy decision values.
        for value in ("once", "reject"):
            assert _normalize_permission_decision(value) == {"decision_type": value}

    def test_lc_envelope_approve_maps_to_once(self) -> None:
        envelope = {"decisions": [{"type": "approve"}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "once"}

    def test_lc_envelope_reject_maps_to_reject(self) -> None:
        envelope = {"decisions": [{"type": "reject"}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "reject"}

    def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
        envelope = {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        assert _normalize_permission_decision(envelope) == {
            "decision_type": "reject",
            "feedback": "wrong recipient",
        }

    def test_lc_envelope_reject_with_feedback_field(self) -> None:
        envelope = {"decisions": [{"type": "reject", "feedback": "tighten the subject"}]}
        assert _normalize_permission_decision(envelope) == {
            "decision_type": "reject",
            "feedback": "tighten the subject",
        }

    def test_lc_envelope_edit_maps_to_once(self) -> None:
        # Pins the contract: edited args are NOT merged by permission.
        edited = {
            "type": "edit",
            "edited_action": {"name": "send_email", "args": {"subject": "edited"}},
        }
        assert _normalize_permission_decision({"decisions": [edited]}) == {
            "decision_type": "once"
        }

    def test_lc_single_decision_without_envelope(self) -> None:
        assert _normalize_permission_decision({"type": "approve"}) == {
            "decision_type": "once"
        }

    def test_unknown_type_falls_back_to_reject(self) -> None:
        envelope = {"decisions": [{"type": "totally_unknown"}]}
        assert _normalize_permission_decision(envelope) == {"decision_type": "reject"}

    def test_missing_type_falls_back_to_reject(self) -> None:
        assert _normalize_permission_decision({"decisions": [{}]}) == {
            "decision_type": "reject"
        }

    def test_non_dict_non_string_falls_back_to_reject(self) -> None:
        for value in (None, 42):
            assert _normalize_permission_decision(value) == {"decision_type": "reject"}

    def test_empty_decisions_list_falls_back_to_reject(self) -> None:
        # Fail-closed on a malformed reply rather than treat it as approve.
        assert _normalize_permission_decision({"decisions": []}) == {
            "decision_type": "reject"
        }
class TestResumeShapesEndToEnd:
    """LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``.

    Each test stubs the raw interrupt reply as ``original`` and wraps it in
    ``_normalize_permission_decision``, mirroring what the real
    ``_raise_interrupt`` does with the value returned by ``interrupt(...)``.
    """

    def test_lc_approve_envelope_keeps_call(self) -> None:
        # Consistency fix: previously this test assigned ``_raise_interrupt``
        # twice (stub first, then re-captured and rewrapped it); use the same
        # single-assignment pattern as the sibling tests below.
        mw = PermissionMiddleware(rulesets=[])
        original = lambda **kw: {"decisions": [{"type": "approve"}]}  # noqa: E731
        mw._raise_interrupt = lambda **kw: _normalize_permission_decision(  # type: ignore[assignment]
            original(**kw)
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        out = mw.after_model(state, _FakeRuntime())
        assert out is None

    def test_lc_reject_envelope_raises(self) -> None:
        mw = PermissionMiddleware(rulesets=[])
        original = lambda **kw: {"decisions": [{"type": "reject"}]}  # noqa: E731
        mw._raise_interrupt = lambda **kw: _normalize_permission_decision(  # type: ignore[assignment]
            original(**kw)
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

    def test_lc_reject_with_message_raises_corrected(self) -> None:
        # A reject carrying feedback surfaces as CorrectedError with that text.
        mw = PermissionMiddleware(rulesets=[])
        original = lambda **kw: {  # noqa: E731
            "decisions": [{"type": "reject", "message": "wrong recipient"}]
        }
        mw._raise_interrupt = lambda **kw: _normalize_permission_decision(  # type: ignore[assignment]
            original(**kw)
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
        with pytest.raises(CorrectedError) as excinfo:
            mw.after_model(state, _FakeRuntime())
        assert excinfo.value.feedback == "wrong recipient"

    def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
        # Pins the "edit -> once, args unchanged" contract.
        mw = PermissionMiddleware(rulesets=[])
        original = lambda **kw: {  # noqa: E731
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": {
                        "name": "send_email",
                        "args": {"to": "edited@example.com"},
                    },
                }
            ]
        }
        mw._raise_interrupt = lambda **kw: _normalize_permission_decision(  # type: ignore[assignment]
            original(**kw)
        )
        state = {
            "messages": [
                _msg(
                    {
                        "name": "send_email",
                        "args": {"to": "original@example.com"},
                        "id": "1",
                    }
                )
            ]
        }
        out = mw.after_model(state, _FakeRuntime())
        assert out is None

View file

@ -146,6 +146,31 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un
}
}
/**
 * Most recent pending tool-call card with this name, so a new HITL interrupt
 * does not overwrite an already-approved card with the same tool name.
 */
function findHitlTargetToolCallId(
	toolCallIndices: Map<string, number>,
	contentParts: Array<{
		type: string;
		toolName?: string;
		result?: unknown;
	}>,
	toolName: string
): string | null {
	// Walk the Map's insertion order backwards: the newest matching card wins.
	for (const [toolCallId, partIndex] of Array.from(toolCallIndices.entries()).reverse()) {
		const candidate = contentParts[partIndex];
		if (!candidate || candidate.type !== "tool-call" || candidate.toolName !== toolName) {
			continue;
		}
		const result = candidate.result as Record<string, unknown> | undefined | null;
		// No result yet means still pending; an undecided interrupt also counts.
		if (result == null || (result.__interrupt__ === true && !result.__decided__)) {
			return toolCallId;
		}
	}
	return null;
}
/**
* Zod schema for mentioned document info (for type-safe parsing)
*/
@ -949,12 +974,13 @@ export default function NewChatPage() {
args: Record<string, unknown>;
}>;
for (const action of actionRequests) {
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => {
const part = contentParts[idx];
return part?.type === "tool-call" && part.toolName === action.name;
});
if (existingIdx) {
updateToolCall(contentPartsState, existingIdx[0], {
const targetTcId = findHitlTargetToolCallId(
toolCallIndices,
contentParts,
action.name
);
if (targetTcId) {
updateToolCall(contentPartsState, targetTcId, {
result: { __interrupt__: true, ...interruptData },
});
} else {
@ -1265,6 +1291,7 @@ export default function NewChatPage() {
body: JSON.stringify({
search_space_id: searchSpaceId,
decisions,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
@ -1388,12 +1415,13 @@ export default function NewChatPage() {
args: Record<string, unknown>;
}>;
for (const action of actionRequests) {
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => {
const part = contentParts[idx];
return part?.type === "tool-call" && part.toolName === action.name;
});
if (existingIdx) {
updateToolCall(contentPartsState, existingIdx[0], {
const targetTcId = findHitlTargetToolCallId(
toolCallIndices,
contentParts,
action.name
);
if (targetTcId) {
updateToolCall(contentPartsState, targetTcId, {
result: {
__interrupt__: true,
...interruptData,
@ -1514,6 +1542,25 @@ export default function NewChatPage() {
const decision = detail.decisions[0];
const decisionType = decision?.type as "approve" | "reject" | "edit";
// Fan a single click out to N decisions when the backend bundled
// N tool calls into one HITLRequest (one Approve/Reject covers
// the whole batch until per-card decisions land).
const interruptData = pendingInterrupt.interruptData as
| { action_requests?: unknown[] }
| undefined;
const expectedCount = Array.isArray(interruptData?.action_requests)
? interruptData.action_requests.length
: detail.decisions.length;
const submittedDecisions =
detail.decisions.length >= expectedCount || expectedCount <= 1
? detail.decisions
: [
...detail.decisions,
...Array.from({ length: expectedCount - detail.decisions.length }, () => ({
...detail.decisions[detail.decisions.length - 1],
})),
];
setMessages((prev) =>
prev.map((m) => {
if (m.id !== pendingInterrupt.assistantMsgId) return m;
@ -1554,7 +1601,7 @@ export default function NewChatPage() {
return { ...m, content: newContent as unknown as ThreadMessageLike["content"] };
})
);
handleResume(detail.decisions);
handleResume(submittedDecisions);
}
};
window.addEventListener("hitl-decision", handler);