mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 06:12:40 +02:00
Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.
This commit is contained in:
parent
4ac3f0b304
commit
277bd50f37
6 changed files with 442 additions and 65 deletions
|
|
@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
|
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
|||
return _resolver
|
||||
|
||||
|
||||
def dedup_key_full_args(args: dict[str, Any]) -> str:
    """Resolver that collapses calls only when **every** argument is identical.

    Safe default for tools where no single field uniquely identifies a call
    (e.g. MCP tools whose first required field is a shared workspace id).

    Args:
        args: The tool call's argument mapping.

    Returns:
        A deterministic string key. Canonical JSON with sorted keys when the
        arguments serialize; otherwise a ``repr``-based fallback that is
        still independent of the mapping's insertion order.
    """
    try:
        # ``default=str`` stringifies values JSON cannot encode natively.
        return json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # Reached e.g. for keys of mixed, non-comparable types (``sort_keys``
        # raises TypeError) or circular structures (ValueError).
        if not isinstance(args, dict):
            return repr(args)
        # Order by ``repr`` of the key: sorting the raw items would raise the
        # very TypeError that sent us here when key types are not comparable.
        return repr(sorted(args.items(), key=lambda item: repr(item[0])))
|
||||
|
||||
|
||||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ Operation:
|
|||
the results: ``deny`` > ``ask`` > ``allow``.
|
||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||
containing a :class:`StreamingError`.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
|
||||
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
|
||||
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
|
||||
replies are accepted via :func:`_normalize_permission_decision`.
|
||||
- ``once``: proceed.
|
||||
- ``always``: also persist allow rules for ``request.always`` patterns.
|
||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||
|
|
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
|
|||
return _resolve
|
||||
|
||||
|
||||
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
|
||||
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
|
||||
# original tool args — tools needing argument edits should use
|
||||
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
|
||||
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
|
||||
"approve": "once",
|
||||
"reject": "reject",
|
||||
"edit": "once",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
|
||||
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
|
||||
|
||||
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
|
||||
middleware fails closed.
|
||||
"""
|
||||
if isinstance(decision, str):
|
||||
return {"decision_type": decision}
|
||||
if not isinstance(decision, dict):
|
||||
logger.warning(
|
||||
"Unrecognized permission resume value (%s); treating as reject",
|
||||
type(decision).__name__,
|
||||
)
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
if decision.get("decision_type"):
|
||||
return decision
|
||||
|
||||
payload: dict[str, Any] = decision
|
||||
decisions = decision.get("decisions")
|
||||
if isinstance(decisions, list) and decisions:
|
||||
first = decisions[0]
|
||||
if isinstance(first, dict):
|
||||
payload = first
|
||||
|
||||
raw_type = payload.get("type") or payload.get("decision_type")
|
||||
if not raw_type:
|
||||
logger.warning(
|
||||
"Permission resume missing decision type (keys=%s); treating as reject",
|
||||
list(payload.keys()),
|
||||
)
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
raw_type = str(raw_type).lower()
|
||||
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
|
||||
if mapped is None:
|
||||
# Tolerate legacy values arriving without ``decision_type`` wrapping.
|
||||
if raw_type in {"once", "always", "reject"}:
|
||||
mapped = raw_type
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision type %r; treating as reject", raw_type
|
||||
)
|
||||
mapped = "reject"
|
||||
|
||||
if raw_type == "edit":
|
||||
logger.warning(
|
||||
"Permission middleware received an 'edit' decision; original args "
|
||||
"kept (edits not merged here)."
|
||||
)
|
||||
|
||||
out: dict[str, Any] = {"decision_type": mapped}
|
||||
feedback = payload.get("feedback") or payload.get("message")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
out["feedback"] = feedback
|
||||
return out
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
|
|
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||
):
|
||||
decision = interrupt(payload)
|
||||
if isinstance(decision, dict):
|
||||
return decision
|
||||
# Tolerate a plain string reply ("once", "always", "reject")
|
||||
if isinstance(decision, str):
|
||||
return {"decision_type": decision}
|
||||
return {"decision_type": "reject"}
|
||||
return _normalize_permission_decision(decision)
|
||||
|
||||
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||
"""Promote ``always`` reply into runtime allow rules.
|
||||
|
|
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
__all__ = [
|
||||
"PatternResolver",
|
||||
"PermissionMiddleware",
|
||||
"_normalize_permission_decision",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from sqlalchemy import cast, select
|
|||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector
|
||||
|
|
@ -45,7 +46,10 @@ _MCP_CACHE_MAX_SIZE = 50
|
|||
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
||||
_TOOL_CALL_MAX_RETRIES = 3
|
||||
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||
# Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and
|
||||
# multi-agent paths cannot share tool closures with different HITL wiring.
|
||||
_MCPCacheKey = tuple[int, bool]
|
||||
_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {}
|
||||
|
||||
|
||||
def _evict_expired_mcp_cache() -> None:
|
||||
|
|
@ -137,12 +141,13 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
bypass_internal_hitl: bool = False,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
|
||||
already gates the tool, otherwise the body's ``request_approval()`` is the
|
||||
sole HITL gate (single-agent path).
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
raw_description = tool_def.get("description", "No description provided")
|
||||
|
|
@ -161,24 +166,29 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.debug("MCP tool '%s' called", tool_name)
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": raw_description,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
)
|
||||
if bypass_internal_hitl:
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in kwargs.items() if v is not None}
|
||||
)
|
||||
else:
|
||||
# Outside try/except so ``GraphInterrupt`` propagates to LangGraph.
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": raw_description,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
)
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
||||
|
|
@ -221,7 +231,9 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
"mcp_connector_name": connector_name or None,
|
||||
"mcp_is_generic": True,
|
||||
"hitl": True,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
||||
# would otherwise collapse legitimate batches.
|
||||
"dedup_key": dedup_key_full_args,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -240,11 +252,14 @@ async def _create_mcp_tool_from_definition_http(
|
|||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
is_generic_mcp: bool = False,
|
||||
bypass_internal_hitl: bool = False,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
Write tools are wrapped with HITL approval; read-only tools (listed in
|
||||
``readonly_tools``) execute immediately without user confirmation.
|
||||
``readonly_tools``) execute immediately without user confirmation. Set
|
||||
``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
|
||||
already gates the tool.
|
||||
|
||||
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
||||
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
||||
|
|
@ -302,7 +317,7 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||
|
||||
if is_readonly:
|
||||
if is_readonly or bypass_internal_hitl:
|
||||
call_kwargs = _unpack_synthetic_input_data(
|
||||
{k: v for k, v in kwargs.items() if v is not None}
|
||||
)
|
||||
|
|
@ -385,7 +400,9 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_connector_name": connector_name or None,
|
||||
"mcp_is_generic": is_generic_mcp,
|
||||
"hitl": not is_readonly,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
||||
# would otherwise collapse legitimate batches.
|
||||
"dedup_key": dedup_key_full_args,
|
||||
"mcp_original_tool_name": original_tool_name,
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
|
|
@ -400,6 +417,8 @@ async def _load_stdio_mcp_tools(
|
|||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
*,
|
||||
bypass_internal_hitl: bool = False,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from a stdio-based MCP server."""
|
||||
tools: list[StructuredTool] = []
|
||||
|
|
@ -451,6 +470,7 @@ async def _load_stdio_mcp_tools(
|
|||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
bypass_internal_hitl=bypass_internal_hitl,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
|
|
@ -473,6 +493,8 @@ async def _load_http_mcp_tools(
|
|||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
is_generic_mcp: bool = False,
|
||||
*,
|
||||
bypass_internal_hitl: bool = False,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
|
|
@ -598,6 +620,7 @@ async def _load_http_mcp_tools(
|
|||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
is_generic_mcp=is_generic_mcp,
|
||||
bypass_internal_hitl=bypass_internal_hitl,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
|
|
@ -905,14 +928,10 @@ async def _mark_connector_auth_expired(connector_id: int) -> None:
|
|||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
    """Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together).

    Args:
        search_space_id: If provided, drop every cache entry whose key starts
            with this search space id. If ``None``, clear the whole cache.
    """
    if search_space_id is None:
        _mcp_tools_cache.clear()
        return

    # Cache keys are ``(search_space_id, bypass_internal_hitl)`` tuples, so a
    # bare-int ``pop`` would never match; collect the matching keys first so
    # the dict is not mutated while being iterated.
    stale_keys = [key for key in _mcp_tools_cache if key[0] == search_space_id]
    for key in stale_keys:
        _mcp_tools_cache.pop(key, None)
|
||||
|
||||
|
|
@ -920,27 +939,29 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
|||
async def load_mcp_tools(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
*,
|
||||
bypass_internal_hitl: bool = False,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load all MCP tools from user's active MCP server connectors.
|
||||
"""Load all MCP tools from the user's active MCP server connectors.
|
||||
|
||||
This discovers tools dynamically from MCP servers using the protocol.
|
||||
Supports both stdio (local process) and HTTP (remote server) transports.
|
||||
|
||||
Results are cached per search space for up to 5 minutes to avoid
|
||||
re-spawning MCP server processes on every chat message.
|
||||
Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up
|
||||
to 5 minutes; bypass is keyed because each variant builds a different tool
|
||||
closure (with vs. without the in-wrapper ``request_approval`` gate).
|
||||
"""
|
||||
_evict_expired_mcp_cache()
|
||||
|
||||
now = time.monotonic()
|
||||
cached = _mcp_tools_cache.get(search_space_id)
|
||||
cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl)
|
||||
cached = _mcp_tools_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
cached_at, cached_tools = cached
|
||||
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
|
||||
logger.info(
|
||||
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
|
||||
"Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)",
|
||||
search_space_id,
|
||||
len(cached_tools),
|
||||
now - cached_at,
|
||||
bypass_internal_hitl,
|
||||
)
|
||||
return list(cached_tools)
|
||||
|
||||
|
|
@ -1064,6 +1085,7 @@ async def load_mcp_tools(
|
|||
readonly_tools=task["readonly_tools"],
|
||||
tool_name_prefix=task["tool_name_prefix"],
|
||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||
bypass_internal_hitl=bypass_internal_hitl,
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
|
@ -1074,6 +1096,7 @@ async def load_mcp_tools(
|
|||
task["connector_name"],
|
||||
task["server_config"],
|
||||
trusted_tools=task["trusted_tools"],
|
||||
bypass_internal_hitl=bypass_internal_hitl,
|
||||
),
|
||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
|
@ -1095,14 +1118,17 @@ async def load_mcp_tools(
|
|||
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
_mcp_tools_cache[cache_key] = (now, tools)
|
||||
|
||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||
del _mcp_tools_cache[oldest_key]
|
||||
|
||||
logger.info(
|
||||
"Loaded %d MCP tools for search space %d", len(tools), search_space_id
|
||||
"Loaded %d MCP tools for search space %d (bypass_hitl=%s)",
|
||||
len(tools),
|
||||
search_space_id,
|
||||
bypass_internal_hitl,
|
||||
)
|
||||
return tools
|
||||
|
||||
|
|
|
|||
|
|
@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
|
|||
assert sample == "plan"
|
||||
|
||||
|
||||
def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None:
    """Regression: calls that share a field but differ elsewhere must survive.

    MCP tools (e.g. ``createJiraIssue``) used to dedup on the schema's first
    required field, which is often the workspace / cloudId, so three distinct
    issues in the same workspace collapsed into one.

    With :func:`dedup_key_full_args` only fully identical arg dicts dedup.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    shared_fields = {"cloudId": "ws.atlassian.net", "projectKey": "PROJ"}
    summaries = ("Fix login bug", "Add dark mode", "Improve perf")
    tool_calls = [
        {
            "name": "createJiraIssue",
            "args": {**shared_fields, "summary": summary},
            "id": str(position),
        }
        for position, summary in enumerate(summaries, start=1)
    ]

    jira_tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    middleware = DedupHITLToolCallsMiddleware(agent_tools=[jira_tool])

    result = middleware.after_model({"messages": [_msg(*tool_calls)]}, _Runtime())

    # Nothing dropped: all three calls differ in ``summary``.
    assert result is None
|
||||
|
||||
|
||||
def test_full_args_dedup_drops_only_exact_duplicates() -> None:
    """Only the call whose args equal an earlier call's args is removed."""
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    def _call(call_id: str, call_args: dict) -> dict:
        return {"name": "createJiraIssue", "args": call_args, "id": call_id}

    args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"}
    message = _msg(
        _call("1", args),
        _call("2", dict(args)),  # equal content, separate dict object
        _call("3", {**args, "summary": "Different"}),
    )

    jira_tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    middleware = DedupHITLToolCallsMiddleware(agent_tools=[jira_tool])

    result = middleware.after_model({"messages": [message]}, _Runtime())

    assert result is not None
    surviving_ids = {call["id"] for call in result["messages"][0].tool_calls}
    assert surviving_ids == {"1", "3"}
|
||||
|
||||
|
||||
def test_unknown_tool_passes_through() -> None:
|
||||
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
|
||||
state = {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ import pytest
|
|||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.middleware.permission import (
|
||||
PermissionMiddleware,
|
||||
_normalize_permission_decision,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
|
@ -112,3 +115,151 @@ class TestAsk:
|
|||
# Runtime ruleset got the always-allow rule
|
||||
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
||||
assert any(r.permission == "send_email" for r in new_rules)
|
||||
|
||||
|
||||
class TestNormalizeDecision:
    """Reply shapes that ``_normalize_permission_decision`` has to understand."""

    def test_legacy_decision_type_dict_passes_through(self) -> None:
        legacy = {"decision_type": "once"}
        normalized = _normalize_permission_decision(legacy)
        assert normalized == {"decision_type": "once"}

    def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
        legacy = {"decision_type": "reject", "feedback": "no thanks"}
        normalized = _normalize_permission_decision(legacy)
        assert normalized == legacy

    def test_plain_string_wrapped(self) -> None:
        for reply in ("once", "reject"):
            assert _normalize_permission_decision(reply) == {"decision_type": reply}

    def test_lc_envelope_approve_maps_to_once(self) -> None:
        envelope = {"decisions": [{"type": "approve"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "once"}

    def test_lc_envelope_reject_maps_to_reject(self) -> None:
        envelope = {"decisions": [{"type": "reject"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject"}

    def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
        envelope = {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        expected = {"decision_type": "reject", "feedback": "wrong recipient"}
        assert _normalize_permission_decision(envelope) == expected

    def test_lc_envelope_reject_with_feedback_field(self) -> None:
        envelope = {
            "decisions": [{"type": "reject", "feedback": "tighten the subject"}]
        }
        expected = {"decision_type": "reject", "feedback": "tighten the subject"}
        assert _normalize_permission_decision(envelope) == expected

    def test_lc_envelope_edit_maps_to_once(self) -> None:
        # Pins the contract: permission does NOT merge edited args.
        edited_action = {"name": "send_email", "args": {"subject": "edited"}}
        envelope = {"decisions": [{"type": "edit", "edited_action": edited_action}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "once"}

    def test_lc_single_decision_without_envelope(self) -> None:
        normalized = _normalize_permission_decision({"type": "approve"})
        assert normalized == {"decision_type": "once"}

    def test_unknown_type_falls_back_to_reject(self) -> None:
        envelope = {"decisions": [{"type": "totally_unknown"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject"}

    def test_missing_type_falls_back_to_reject(self) -> None:
        normalized = _normalize_permission_decision({"decisions": [{}]})
        assert normalized == {"decision_type": "reject"}

    def test_non_dict_non_string_falls_back_to_reject(self) -> None:
        for reply in (None, 42):
            assert _normalize_permission_decision(reply) == {"decision_type": "reject"}

    def test_empty_decisions_list_falls_back_to_reject(self) -> None:
        # A malformed reply must fail closed, never count as an approval.
        normalized = _normalize_permission_decision({"decisions": []})
        assert normalized == {"decision_type": "reject"}
|
||||
|
||||
|
||||
class TestResumeShapesEndToEnd:
    """LangChain HITL envelope replies driven through ``after_model``.

    Each test swaps ``_raise_interrupt`` for a stub that returns the
    normalized form of a canned reply, then checks what ``after_model`` does.
    """

    @staticmethod
    def _middleware_replying(reply: dict) -> PermissionMiddleware:
        """Build a rule-less middleware whose interrupt yields ``reply`` normalized."""
        mw = PermissionMiddleware(rulesets=[])

        def _stub_interrupt(**_kwargs):
            return _normalize_permission_decision(reply)

        mw._raise_interrupt = _stub_interrupt  # type: ignore[assignment]
        return mw

    @staticmethod
    def _state_with_send_email(args: dict) -> dict:
        """State holding one ``send_email`` tool call with the given args."""
        return {"messages": [_msg({"name": "send_email", "args": args, "id": "1"})]}

    def test_lc_approve_envelope_keeps_call(self) -> None:
        mw = self._middleware_replying({"decisions": [{"type": "approve"}]})
        result = mw.after_model(self._state_with_send_email({}), _FakeRuntime())
        assert result is None

    def test_lc_reject_envelope_raises(self) -> None:
        mw = self._middleware_replying({"decisions": [{"type": "reject"}]})
        state = self._state_with_send_email({})
        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

    def test_lc_reject_with_message_raises_corrected(self) -> None:
        mw = self._middleware_replying(
            {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        )
        state = self._state_with_send_email({})
        with pytest.raises(CorrectedError) as excinfo:
            mw.after_model(state, _FakeRuntime())
        assert excinfo.value.feedback == "wrong recipient"

    def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
        # Pins the "edit -> once, args unchanged" contract.
        edited_action = {"name": "send_email", "args": {"to": "edited@example.com"}}
        mw = self._middleware_replying(
            {"decisions": [{"type": "edit", "edited_action": edited_action}]}
        )
        state = self._state_with_send_email({"to": "original@example.com"})
        result = mw.after_model(state, _FakeRuntime())
        assert result is None
|
||||
|
|
|
|||
|
|
@ -146,6 +146,31 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Pick the tool-call card a new HITL interrupt should attach to: the most
 * recently registered card named `toolName` that has no result yet, or that
 * holds an interrupt with no decision recorded. Cards that already carry a
 * decided interrupt or a real result are never overwritten.
 *
 * Returns the tool-call id, or `null` when no such card exists.
 */
function findHitlTargetToolCallId(
	toolCallIndices: Map<string, number>,
	contentParts: Array<{
		type: string;
		toolName?: string;
		result?: unknown;
	}>,
	toolName: string
): string | null {
	// Map iteration follows insertion order; reverse it to visit newest first.
	const newestFirst = [...toolCallIndices].reverse();
	for (const [toolCallId, partIndex] of newestFirst) {
		const candidate = contentParts[partIndex];
		if (candidate?.type !== "tool-call" || candidate.toolName !== toolName) {
			continue;
		}
		const outcome = candidate.result as Record<string, unknown> | undefined | null;
		const isPending =
			outcome == null || (outcome.__interrupt__ === true && !outcome.__decided__);
		if (isPending) {
			return toolCallId;
		}
	}
	return null;
}
|
||||
|
||||
/**
|
||||
* Zod schema for mentioned document info (for type-safe parsing)
|
||||
*/
|
||||
|
|
@ -949,12 +974,13 @@ export default function NewChatPage() {
|
|||
args: Record<string, unknown>;
|
||||
}>;
|
||||
for (const action of actionRequests) {
|
||||
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => {
|
||||
const part = contentParts[idx];
|
||||
return part?.type === "tool-call" && part.toolName === action.name;
|
||||
});
|
||||
if (existingIdx) {
|
||||
updateToolCall(contentPartsState, existingIdx[0], {
|
||||
const targetTcId = findHitlTargetToolCallId(
|
||||
toolCallIndices,
|
||||
contentParts,
|
||||
action.name
|
||||
);
|
||||
if (targetTcId) {
|
||||
updateToolCall(contentPartsState, targetTcId, {
|
||||
result: { __interrupt__: true, ...interruptData },
|
||||
});
|
||||
} else {
|
||||
|
|
@ -1265,6 +1291,7 @@ export default function NewChatPage() {
|
|||
body: JSON.stringify({
|
||||
search_space_id: searchSpaceId,
|
||||
decisions,
|
||||
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
||||
filesystem_mode: selection.filesystem_mode,
|
||||
client_platform: selection.client_platform,
|
||||
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||
|
|
@ -1388,12 +1415,13 @@ export default function NewChatPage() {
|
|||
args: Record<string, unknown>;
|
||||
}>;
|
||||
for (const action of actionRequests) {
|
||||
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => {
|
||||
const part = contentParts[idx];
|
||||
return part?.type === "tool-call" && part.toolName === action.name;
|
||||
});
|
||||
if (existingIdx) {
|
||||
updateToolCall(contentPartsState, existingIdx[0], {
|
||||
const targetTcId = findHitlTargetToolCallId(
|
||||
toolCallIndices,
|
||||
contentParts,
|
||||
action.name
|
||||
);
|
||||
if (targetTcId) {
|
||||
updateToolCall(contentPartsState, targetTcId, {
|
||||
result: {
|
||||
__interrupt__: true,
|
||||
...interruptData,
|
||||
|
|
@ -1514,6 +1542,25 @@ export default function NewChatPage() {
|
|||
const decision = detail.decisions[0];
|
||||
const decisionType = decision?.type as "approve" | "reject" | "edit";
|
||||
|
||||
// Fan a single click out to N decisions when the backend bundled
|
||||
// N tool calls into one HITLRequest (one Approve/Reject covers
|
||||
// the whole batch until per-card decisions land).
|
||||
const interruptData = pendingInterrupt.interruptData as
|
||||
| { action_requests?: unknown[] }
|
||||
| undefined;
|
||||
const expectedCount = Array.isArray(interruptData?.action_requests)
|
||||
? interruptData.action_requests.length
|
||||
: detail.decisions.length;
|
||||
const submittedDecisions =
|
||||
detail.decisions.length >= expectedCount || expectedCount <= 1
|
||||
? detail.decisions
|
||||
: [
|
||||
...detail.decisions,
|
||||
...Array.from({ length: expectedCount - detail.decisions.length }, () => ({
|
||||
...detail.decisions[detail.decisions.length - 1],
|
||||
})),
|
||||
];
|
||||
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => {
|
||||
if (m.id !== pendingInterrupt.assistantMsgId) return m;
|
||||
|
|
@ -1554,7 +1601,7 @@ export default function NewChatPage() {
|
|||
return { ...m, content: newContent as unknown as ThreadMessageLike["content"] };
|
||||
})
|
||||
);
|
||||
handleResume(detail.decisions);
|
||||
handleResume(submittedDecisions);
|
||||
}
|
||||
};
|
||||
window.addEventListener("hitl-decision", handler);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue