Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.

This commit is contained in:
CREDO23 2026-05-04 19:25:27 +02:00
parent 4ac3f0b304
commit 277bd50f37
6 changed files with 442 additions and 65 deletions

View file

@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
assert sample == "plan"
def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None:
    """Calls that share some fields but differ in one must all survive.

    Regression guard: MCP tools such as ``createJiraIssue`` were once
    deduplicated on the schema's first required field (frequently the
    workspace / ``cloudId``), which collapsed three distinct issues in
    one workspace into a single call. Keyed with
    :func:`dedup_key_full_args`, only fully identical arg dicts dedup.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    # Every call targets the same workspace and project; only the summary varies.
    shared_args = {"cloudId": "ws.atlassian.net", "projectKey": "PROJ"}
    summaries = ("Fix login bug", "Add dark mode", "Improve perf")
    calls = [
        {
            "name": "createJiraIssue",
            "args": {**shared_args, "summary": summary},
            "id": str(position),
        }
        for position, summary in enumerate(summaries, start=1)
    ]

    jira_tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    middleware = DedupHITLToolCallsMiddleware(agent_tools=[jira_tool])

    result = middleware.after_model({"messages": [_msg(*calls)]}, _Runtime())

    # ``None`` means nothing was dropped: all three differ in summary.
    assert result is None
def test_full_args_dedup_drops_only_exact_duplicates() -> None:
    """Only a call whose args equal an earlier call's args is removed.

    Calls ``1`` and ``2`` carry equal arg dicts (the second is a copy),
    call ``3`` differs in ``summary``; the test expects ``1`` and ``3``
    to remain.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    base_args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"}
    first_call = {"name": "createJiraIssue", "args": base_args, "id": "1"}
    # Equal content but a distinct dict object, so dedup cannot rely on identity.
    copied_call = {"name": "createJiraIssue", "args": dict(base_args), "id": "2"}
    changed_call = {
        "name": "createJiraIssue",
        "args": dict(base_args, summary="Different"),
        "id": "3",
    }

    jira_tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    middleware = DedupHITLToolCallsMiddleware(agent_tools=[jira_tool])
    state = {"messages": [_msg(first_call, copied_call, changed_call)]}

    result = middleware.after_model(state, _Runtime())

    assert result is not None
    surviving_calls = result["messages"][0].tool_calls
    surviving_ids = {call["id"] for call in surviving_calls}
    assert surviving_ids == {"1", "3"}
def test_unknown_tool_passes_through() -> None:
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
state = {

View file

@ -6,7 +6,10 @@ import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.middleware.permission import (
PermissionMiddleware,
_normalize_permission_decision,
)
from app.agents.new_chat.permissions import Rule, Ruleset
pytestmark = pytest.mark.unit
@ -112,3 +115,151 @@ class TestAsk:
# Runtime ruleset got the always-allow rule
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
assert any(r.permission == "send_email" for r in new_rules)
class TestNormalizeDecision:
    """Pin every resume payload shape ``_normalize_permission_decision`` accepts."""

    def test_legacy_decision_type_dict_passes_through(self) -> None:
        """A bare ``decision_type`` dict comes back unchanged."""
        legacy = {"decision_type": "once"}
        normalized = _normalize_permission_decision(legacy)
        assert normalized == {"decision_type": "once"}

    def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
        """A legacy dict that already carries feedback is returned equal to itself."""
        legacy = {"decision_type": "reject", "feedback": "no thanks"}
        normalized = _normalize_permission_decision(legacy)
        assert normalized == legacy

    def test_plain_string_wrapped(self) -> None:
        """A bare string is wrapped as the ``decision_type`` value."""
        for word in ("once", "reject"):
            assert _normalize_permission_decision(word) == {"decision_type": word}

    def test_lc_envelope_approve_maps_to_once(self) -> None:
        """``approve`` inside a ``decisions`` envelope becomes ``once``."""
        envelope = {"decisions": [{"type": "approve"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "once"}

    def test_lc_envelope_reject_maps_to_reject(self) -> None:
        """``reject`` inside a ``decisions`` envelope stays ``reject``."""
        envelope = {"decisions": [{"type": "reject"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject"}

    def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
        """The envelope's ``message`` field surfaces as ``feedback``."""
        envelope = {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        expected = {"decision_type": "reject", "feedback": "wrong recipient"}
        assert _normalize_permission_decision(envelope) == expected

    def test_lc_envelope_reject_with_feedback_field(self) -> None:
        """An explicit ``feedback`` field on the decision is carried through."""
        envelope = {
            "decisions": [{"type": "reject", "feedback": "tighten the subject"}]
        }
        expected = {"decision_type": "reject", "feedback": "tighten the subject"}
        assert _normalize_permission_decision(envelope) == expected

    def test_lc_envelope_edit_maps_to_once(self) -> None:
        """``edit`` normalises to ``once`` with no edited args in the result."""
        # Pins the contract: edited args are NOT merged by permission.
        edited_action = {"name": "send_email", "args": {"subject": "edited"}}
        envelope = {"decisions": [{"type": "edit", "edited_action": edited_action}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "once"}

    def test_lc_single_decision_without_envelope(self) -> None:
        """A lone decision dict (no ``decisions`` wrapper) is accepted too."""
        normalized = _normalize_permission_decision({"type": "approve"})
        assert normalized == {"decision_type": "once"}

    def test_unknown_type_falls_back_to_reject(self) -> None:
        """An unrecognised ``type`` value is treated as a rejection."""
        envelope = {"decisions": [{"type": "totally_unknown"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject"}

    def test_missing_type_falls_back_to_reject(self) -> None:
        """A decision entry with no ``type`` key is treated as a rejection."""
        normalized = _normalize_permission_decision({"decisions": [{}]})
        assert normalized == {"decision_type": "reject"}

    def test_non_dict_non_string_falls_back_to_reject(self) -> None:
        """Payloads that are neither dict nor str are treated as a rejection."""
        for payload in (None, 42):
            normalized = _normalize_permission_decision(payload)
            assert normalized == {"decision_type": "reject"}

    def test_empty_decisions_list_falls_back_to_reject(self) -> None:
        """An empty ``decisions`` list is treated as a rejection."""
        # Fail-closed on a malformed reply rather than treat it as approve.
        normalized = _normalize_permission_decision({"decisions": []})
        assert normalized == {"decision_type": "reject"}
class TestResumeShapesEndToEnd:
    """Drive ``after_model`` with LangChain HITL resume envelopes.

    Every test replaces ``_raise_interrupt`` with a stub that passes a
    canned reply through ``_normalize_permission_decision``. What is
    covered is therefore normalisation plus ``after_model``'s handling of
    the normalised decision; the real interrupt call is never exercised.

    NOTE(review): presumably the production ``_raise_interrupt`` applies
    the same normalisation -- confirm against the permission middleware.
    """

    @staticmethod
    def _middleware_replying(reply: object) -> PermissionMiddleware:
        """Build a middleware whose interrupt always resumes with *reply*.

        Args:
            reply: Raw resume payload, in whatever shape the caller wants
                to exercise (here: LangChain ``decisions`` envelopes).

        Returns:
            A ``PermissionMiddleware`` created with ``rulesets=[]`` whose
            ``_raise_interrupt`` returns the normalised form of *reply*.
        """
        mw = PermissionMiddleware(rulesets=[])

        def _stub_interrupt(**_kwargs: object) -> object:
            # Accepts keyword arguments only and ignores them; the canned
            # reply is all that matters to these tests.
            return _normalize_permission_decision(reply)

        mw._raise_interrupt = _stub_interrupt  # type: ignore[assignment]
        return mw

    def test_lc_approve_envelope_keeps_call(self) -> None:
        """An ``approve`` envelope makes ``after_model`` return ``None``."""
        mw = self._middleware_replying({"decisions": [{"type": "approve"}]})
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}

        out = mw.after_model(state, _FakeRuntime())

        assert out is None

    def test_lc_reject_envelope_raises(self) -> None:
        """A ``reject`` envelope without feedback raises ``RejectedError``."""
        mw = self._middleware_replying({"decisions": [{"type": "reject"}]})
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}

        with pytest.raises(RejectedError):
            mw.after_model(state, _FakeRuntime())

    def test_lc_reject_with_message_raises_corrected(self) -> None:
        """A ``reject`` carrying a message raises ``CorrectedError`` with it."""
        mw = self._middleware_replying(
            {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        )
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}

        with pytest.raises(CorrectedError) as excinfo:
            mw.after_model(state, _FakeRuntime())

        assert excinfo.value.feedback == "wrong recipient"

    def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
        """An ``edit`` envelope makes ``after_model`` return ``None``."""
        # Pins the "edit -> once, args unchanged" contract.
        mw = self._middleware_replying(
            {
                "decisions": [
                    {
                        "type": "edit",
                        "edited_action": {
                            "name": "send_email",
                            "args": {"to": "edited@example.com"},
                        },
                    }
                ]
            }
        )
        state = {
            "messages": [
                _msg(
                    {
                        "name": "send_email",
                        "args": {"to": "original@example.com"},
                        "id": "1",
                    }
                )
            ]
        }

        out = mw.after_model(state, _FakeRuntime())

        assert out is None