Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.

This commit is contained in:
CREDO23 2026-05-04 19:25:27 +02:00
parent 4ac3f0b304
commit 277bd50f37
6 changed files with 442 additions and 65 deletions

View file

@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import Any
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
return _resolver
def dedup_key_full_args(args: dict[str, Any]) -> str:
"""Resolver that collapses calls only when **every** argument is identical.
Safe default for tools where no single field uniquely identifies a call
(e.g. MCP tools whose first required field is a shared workspace id).
"""
try:
return json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
return repr(sorted(args.items())) if isinstance(args, dict) else repr(args)
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name

View file

@ -19,8 +19,9 @@ Operation:
the results: ``deny`` > ``ask`` > ``allow``.
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
containing a :class:`StreamingError`.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
replies are accepted via :func:`_normalize_permission_decision`.
- ``once``: proceed.
- ``always``: also persist allow rules for ``request.always`` patterns.
- ``reject`` w/o feedback: raise :class:`RejectedError`.
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
return _resolve
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
# original tool args — tools needing argument edits should use
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
"approve": "once",
"reject": "reject",
"edit": "once",
}
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
middleware fails closed.
"""
if isinstance(decision, str):
return {"decision_type": decision}
if not isinstance(decision, dict):
logger.warning(
"Unrecognized permission resume value (%s); treating as reject",
type(decision).__name__,
)
return {"decision_type": "reject"}
if decision.get("decision_type"):
return decision
payload: dict[str, Any] = decision
decisions = decision.get("decisions")
if isinstance(decisions, list) and decisions:
first = decisions[0]
if isinstance(first, dict):
payload = first
raw_type = payload.get("type") or payload.get("decision_type")
if not raw_type:
logger.warning(
"Permission resume missing decision type (keys=%s); treating as reject",
list(payload.keys()),
)
return {"decision_type": "reject"}
raw_type = str(raw_type).lower()
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
if mapped is None:
# Tolerate legacy values arriving without ``decision_type`` wrapping.
if raw_type in {"once", "always", "reject"}:
mapped = raw_type
else:
logger.warning(
"Unknown permission decision type %r; treating as reject", raw_type
)
mapped = "reject"
if raw_type == "edit":
logger.warning(
"Permission middleware received an 'edit' decision; original args "
"kept (edits not merged here)."
)
out: dict[str, Any] = {"decision_type": mapped}
feedback = payload.get("feedback") or payload.get("message")
if isinstance(feedback, str) and feedback.strip():
out["feedback"] = feedback
return out
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload)
if isinstance(decision, dict):
return decision
# Tolerate a plain string reply ("once", "always", "reject")
if isinstance(decision, str):
return {"decision_type": decision}
return {"decision_type": "reject"}
return _normalize_permission_decision(decision)
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
"""Promote ``always`` reply into runtime allow rules.
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
__all__ = [
"PatternResolver",
"PermissionMiddleware",
"_normalize_permission_decision",
]