mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.
This commit is contained in:
parent
4ac3f0b304
commit
277bd50f37
6 changed files with 442 additions and 65 deletions
|
|
@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
|
|||
assert sample == "plan"
|
||||
|
||||
|
||||
def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None:
    """Calls that share some args but differ in others must all survive.

    Regression guard: MCP tools such as ``createJiraIssue`` used to be
    deduplicated on the schema's first required field, which is frequently
    the workspace / cloudId, so three different issues filed in the same
    workspace were collapsed into one call.

    Keyed with :func:`dedup_key_full_args`, only calls whose whole arg dict
    is identical count as duplicates.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    jira_tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    middleware = DedupHITLToolCallsMiddleware(agent_tools=[jira_tool])

    # Same workspace and project for every call; only the summary varies.
    summaries = ("Fix login bug", "Add dark mode", "Improve perf")
    tool_calls = [
        {
            "name": "createJiraIssue",
            "args": {
                "cloudId": "ws.atlassian.net",
                "projectKey": "PROJ",
                "summary": summary,
            },
            "id": str(call_no),
        }
        for call_no, summary in enumerate(summaries, start=1)
    ]
    state = {"messages": [_msg(*tool_calls)]}

    result = middleware.after_model(state, _Runtime())

    # Nothing dropped: all three calls differ in ``summary``.
    assert result is None
|
||||
|
||||
|
||||
def test_full_args_dedup_drops_only_exact_duplicates() -> None:
    """Only the call whose args equal an earlier call's args is removed.

    Three ``createJiraIssue`` calls are issued: an original, an equal copy of
    its args, and a variant with a different ``summary``. The copy (id "2")
    is expected to disappear while ids "1" and "3" remain.
    """
    from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args

    jira_tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
    middleware = DedupHITLToolCallsMiddleware(agent_tools=[jira_tool])

    base_args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"}
    first_call = {"name": "createJiraIssue", "args": base_args, "id": "1"}
    # Equal content but a separate dict object, so equality (not identity)
    # is what has to trigger the dedup.
    exact_copy = {"name": "createJiraIssue", "args": base_args.copy(), "id": "2"}
    changed_call = {
        "name": "createJiraIssue",
        "args": dict(base_args, summary="Different"),
        "id": "3",
    }
    state = {"messages": [_msg(first_call, exact_copy, changed_call)]}

    result = middleware.after_model(state, _Runtime())

    assert result is not None
    surviving_ids = {call["id"] for call in result["messages"][0].tool_calls}
    assert surviving_ids == {"1", "3"}
|
||||
|
||||
|
||||
def test_unknown_tool_passes_through() -> None:
|
||||
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
|
||||
state = {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ import pytest
|
|||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.middleware.permission import (
|
||||
PermissionMiddleware,
|
||||
_normalize_permission_decision,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
|
@ -112,3 +115,151 @@ class TestAsk:
|
|||
# Runtime ruleset got the always-allow rule
|
||||
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
||||
assert any(r.permission == "send_email" for r in new_rules)
|
||||
|
||||
|
||||
class TestNormalizeDecision:
    """Resume payload shapes that ``_normalize_permission_decision`` accepts.

    Each test feeds one payload shape to the normalizer and pins the exact
    ``{"decision_type": ...}`` dict (plus optional ``feedback``) it returns.
    """

    def test_legacy_decision_type_dict_passes_through(self) -> None:
        """A dict that already carries ``decision_type`` comes back equal."""
        legacy = {"decision_type": "once"}
        normalized = _normalize_permission_decision(legacy)
        assert normalized == {"decision_type": "once"}

    def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
        """Legacy dicts keep their ``feedback`` entry."""
        legacy = {"decision_type": "reject", "feedback": "no thanks"}
        normalized = _normalize_permission_decision(legacy)
        assert normalized == legacy

    def test_plain_string_wrapped(self) -> None:
        """A bare string becomes the ``decision_type`` value."""
        approved = _normalize_permission_decision("once")
        assert approved == {"decision_type": "once"}
        refused = _normalize_permission_decision("reject")
        assert refused == {"decision_type": "reject"}

    def test_lc_envelope_approve_maps_to_once(self) -> None:
        """``approve`` inside a ``decisions`` envelope maps to ``once``."""
        envelope = {"decisions": [{"type": "approve"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "once"}

    def test_lc_envelope_reject_maps_to_reject(self) -> None:
        """``reject`` inside a ``decisions`` envelope stays ``reject``."""
        envelope = {"decisions": [{"type": "reject"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject"}

    def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
        """A reject's ``message`` is surfaced under the ``feedback`` key."""
        envelope = {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject", "feedback": "wrong recipient"}

    def test_lc_envelope_reject_with_feedback_field(self) -> None:
        """A reject's own ``feedback`` field is carried over unchanged."""
        envelope = {
            "decisions": [{"type": "reject", "feedback": "tighten the subject"}]
        }
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {
            "decision_type": "reject",
            "feedback": "tighten the subject",
        }

    def test_lc_envelope_edit_maps_to_once(self) -> None:
        """``edit`` maps to ``once`` and the edited action is not returned."""
        # Pins the contract: edited args are NOT merged by permission.
        edited_action = {"name": "send_email", "args": {"subject": "edited"}}
        envelope = {"decisions": [{"type": "edit", "edited_action": edited_action}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "once"}

    def test_lc_single_decision_without_envelope(self) -> None:
        """A lone ``{"type": ...}`` dict works without the ``decisions`` list."""
        normalized = _normalize_permission_decision({"type": "approve"})
        assert normalized == {"decision_type": "once"}

    def test_unknown_type_falls_back_to_reject(self) -> None:
        """An unrecognised ``type`` value is treated as a reject."""
        envelope = {"decisions": [{"type": "totally_unknown"}]}
        normalized = _normalize_permission_decision(envelope)
        assert normalized == {"decision_type": "reject"}

    def test_missing_type_falls_back_to_reject(self) -> None:
        """A decision entry with no ``type`` key is treated as a reject."""
        normalized = _normalize_permission_decision({"decisions": [{}]})
        assert normalized == {"decision_type": "reject"}

    def test_non_dict_non_string_falls_back_to_reject(self) -> None:
        """Payloads that are neither dict nor str are treated as a reject."""
        from_none = _normalize_permission_decision(None)
        assert from_none == {"decision_type": "reject"}
        from_int = _normalize_permission_decision(42)
        assert from_int == {"decision_type": "reject"}

    def test_empty_decisions_list_falls_back_to_reject(self) -> None:
        """An empty ``decisions`` list is treated as a reject."""
        # Fail-closed on a malformed reply rather than treat it as approve.
        normalized = _normalize_permission_decision({"decisions": []})
        assert normalized == {"decision_type": "reject"}
|
||||
|
||||
|
||||
class TestResumeShapesEndToEnd:
    """Drive ``after_model`` with LangChain-style HITL resume envelopes.

    Every test replaces ``_raise_interrupt`` on a rule-less
    ``PermissionMiddleware`` with a stub that returns the envelope after
    passing it through ``_normalize_permission_decision``.

    NOTE(review): the stub performs the normalization itself, so these tests
    pin how ``after_model`` reacts to a normalized decision; they do not
    exercise normalization inside the real ``_raise_interrupt``.
    """

    def test_lc_approve_envelope_keeps_call(self) -> None:
        """An ``approve`` envelope leaves the state untouched (``None``)."""

        def resume_with_approve(**_kwargs):
            return _normalize_permission_decision({"decisions": [{"type": "approve"}]})

        middleware = PermissionMiddleware(rulesets=[])
        middleware._raise_interrupt = resume_with_approve  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}

        result = middleware.after_model(state, _FakeRuntime())

        assert result is None

    def test_lc_reject_envelope_raises(self) -> None:
        """A bare ``reject`` envelope raises ``RejectedError``."""

        def resume_with_reject(**_kwargs):
            return _normalize_permission_decision({"decisions": [{"type": "reject"}]})

        middleware = PermissionMiddleware(rulesets=[])
        middleware._raise_interrupt = resume_with_reject  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}

        with pytest.raises(RejectedError):
            middleware.after_model(state, _FakeRuntime())

    def test_lc_reject_with_message_raises_corrected(self) -> None:
        """A ``reject`` with a message raises ``CorrectedError`` carrying it."""

        def resume_with_reject_message(**_kwargs):
            return _normalize_permission_decision(
                {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
            )

        middleware = PermissionMiddleware(rulesets=[])
        middleware._raise_interrupt = resume_with_reject_message  # type: ignore[assignment]
        state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}

        with pytest.raises(CorrectedError) as excinfo:
            middleware.after_model(state, _FakeRuntime())

        assert excinfo.value.feedback == "wrong recipient"

    def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
        """An ``edit`` envelope is handled like approve: no state update."""

        # Pins the "edit -> once, args unchanged" contract.
        def resume_with_edit(**_kwargs):
            return _normalize_permission_decision(
                {
                    "decisions": [
                        {
                            "type": "edit",
                            "edited_action": {
                                "name": "send_email",
                                "args": {"to": "edited@example.com"},
                            },
                        }
                    ]
                }
            )

        middleware = PermissionMiddleware(rulesets=[])
        middleware._raise_interrupt = resume_with_edit  # type: ignore[assignment]
        pending_call = {
            "name": "send_email",
            "args": {"to": "original@example.com"},
            "id": "1",
        }
        state = {"messages": [_msg(pending_call)]}

        result = middleware.after_model(state, _FakeRuntime())

        assert result is None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue