mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
multi_agent_chat: real-graph regressions for unified HITL paths + format pass
This commit is contained in:
parent
adb52fb575
commit
0723702320
34 changed files with 920 additions and 90 deletions
|
|
@ -0,0 +1,272 @@
|
|||
"""Real-graph parallel HITL across both approval kinds — the keystone regression.
|
||||
|
||||
Pre-fix bug: the parallel-HITL routing layer (``collect_pending_tool_calls``
|
||||
+ ``slice_decisions_by_tool_call`` + ``build_lg_resume_map``) only
|
||||
recognized middleware-gated approvals (LC HITL shape from
|
||||
``HumanInTheLoopMiddleware``). Self-gated approvals from
|
||||
``request_approval`` and middleware-gated permission asks from
|
||||
``PermissionMiddleware`` both used the SurfSense-specific
|
||||
``{type, action, context}`` shape, so when the orchestrator dispatched
|
||||
two parallel ``task`` calls — one self-gated, one middleware-gated — only
|
||||
one interrupt was visible to the routing layer and resume crashed with
|
||||
``Decision count mismatch``.
|
||||
|
||||
This test fans out two real subagents via ``Send``: one calls
|
||||
``request_approval`` (self-gated), the other calls
|
||||
``request_permission_decision`` (middleware-gated). Both pause; the routing
|
||||
layer must see TWO LC HITL interrupts, slice the decisions by
|
||||
``tool_call_id``, key by ``Interrupt.id``, and resume both branches with
|
||||
their per-slice payload.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
|
||||
request_permission_decision,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
    """Parent-graph state; also the payload shape carried by each ``Send``."""

    # The ``add_messages`` reducer is required here. Both parallel ``Send``
    # branches append to ``messages`` within the same superstep, and langgraph
    # raises ``InvalidUpdateError`` when a channel without a reducer receives
    # more than one write.
    messages: Annotated[list, add_messages]
    # Per-branch routing fields, populated by the fan-out edge.
    tcid: str
    desc: str
    subtype: str
|
||||
|
||||
|
||||
def _build_self_gated_subagent(checkpointer: InMemorySaver):
    """Compile a one-node subagent that pauses via ``request_approval``.

    This is the self-gated path. After resume, the node serializes the
    approval result into a JSON ``AIMessage`` tagged ``kind="self_gated"`` so
    the calling test can identify which branch produced it.
    """

    def gate_node(_state):
        outcome = request_approval(
            action_type="gmail_email_send",
            tool_name="send_gmail_email",
            params={"to": "alice@example.com"},
        )
        summary = {
            "kind": "self_gated",
            "decision_type": outcome.decision_type,
            "params": outcome.params,
            "rejected": outcome.rejected,
        }
        # ``sort_keys`` keeps the serialized form deterministic.
        return {"messages": [AIMessage(content=json.dumps(summary, sort_keys=True))]}

    builder = StateGraph(_SubState)
    builder.add_node("gate", gate_node)
    builder.add_edge(START, "gate")
    builder.add_edge("gate", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_middleware_gated_subagent(checkpointer: InMemorySaver):
    """Compile a one-node subagent that pauses via ``request_permission_decision``.

    This is the middleware-gated path. After resume, the node wraps the
    returned decision in a JSON ``AIMessage`` tagged ``kind="middleware_gated"``.
    """

    def perm_node(_state):
        verdict = request_permission_decision(
            tool_name="rm",
            args={"path": "/tmp/file"},
            patterns=["rm/*"],
            rules=[Rule(permission="rm", pattern="*", action="ask")],
            emit_interrupt=True,
        )
        body = json.dumps(
            {"kind": "middleware_gated", "decision": verdict},
            sort_keys=True,
        )
        return {"messages": [AIMessage(content=body)]}

    builder = StateGraph(_SubState)
    builder.add_node("perm", perm_node)
    builder.add_edge(START, "perm")
    builder.add_edge("perm", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_mixed_task_tool(checkpointer: InMemorySaver):
    """Build the ``task`` tool over two subagents, one per approval kind.

    The subagents are registered under distinct names so the parent can
    address each one through ``subagent_type``.
    """
    self_gated_spec = {
        "name": "self-gated-agent",
        "description": "uses request_approval",
        "runnable": _build_self_gated_subagent(checkpointer),
    }
    middleware_gated_spec = {
        "name": "middleware-gated-agent",
        "description": "uses request_permission_decision",
        "runnable": _build_middleware_gated_subagent(checkpointer),
    }
    return build_task_tool_with_parent_config(
        [self_gated_spec, middleware_gated_spec]
    )
|
||||
|
||||
|
||||
def _parent_dispatching_one_of_each(
    task_tool, *, tcid_self: str, tcid_mw: str, checkpointer
):
    """Compile a parent graph that fans out one ``task`` call per approval kind.

    The conditional edge from ``START`` emits two ``Send`` objects that both
    target the ``call_task`` node. Each carries its own tool-call id,
    description and subagent name.
    """

    def fanout_edge(_state) -> list[Send]:
        self_gated_branch = Send(
            "call_task",
            {
                "tcid": tcid_self,
                "desc": "approve email",
                "subtype": "self-gated-agent",
            },
        )
        middleware_gated_branch = Send(
            "call_task",
            {
                "tcid": tcid_mw,
                "desc": "approve rm",
                "subtype": "middleware-gated-agent",
            },
        )
        return [self_gated_branch, middleware_gated_branch]

    async def call_task(state: _DispatchState, config: RunnableConfig):
        # Hand-build the runtime so the tool coroutine sees this branch's
        # tool-call id.
        runtime = ToolRuntime(
            state=state,
            config=config,
            context=None,
            stream_writer=None,
            tool_call_id=state["tcid"],
            store=None,
        )
        return await task_tool.coroutine(
            description=state["desc"],
            subagent_type=state["subtype"],
            runtime=runtime,
        )

    builder = StateGraph(_DispatchState)
    builder.add_node("call_task", call_task)
    builder.add_conditional_edges(START, fanout_edge, ["call_task"])
    builder.add_edge("call_task", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parallel_self_gated_and_middleware_gated_route_and_resume_cleanly():
    """Both interrupt kinds must reach the routing layer in LC HITL shape and resume independently.

    Flow: fan out two ``task`` calls, confirm both pause, route one decision
    to each via the resume-routing helpers, resume, then check that each
    subagent reported the decision intended for it.
    """
    checkpointer = InMemorySaver()
    task_tool = _build_mixed_task_tool(checkpointer)
    tcid_self = "tcid-self-gated"
    tcid_mw = "tcid-middleware-gated"
    parent = _parent_dispatching_one_of_each(
        task_tool,
        tcid_self=tcid_self,
        tcid_mw=tcid_mw,
        checkpointer=checkpointer,
    )
    config: dict = {
        "configurable": {"thread_id": "mixed-parallel"},
        "recursion_limit": 100,
    }
    await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    paused = await parent.aget_state(config)
    assert len(paused.interrupts) == 2, (
        "fixture broken: expected one paused interrupt per approval kind"
    )

    # Both interrupts must speak the same wire shape — the whole point of
    # the unification. If either one regresses to the legacy SurfSense shape
    # ``collect_pending_tool_calls`` would silently skip it and the count
    # below would be 1.
    pending = collect_pending_tool_calls(paused)
    assert dict(pending) == {tcid_self: 1, tcid_mw: 1}, (
        "REGRESSION: not all interrupt kinds reached the routing layer; "
        f"got {pending!r}"
    )

    # Verify the actual wire payloads carry the LC HITL standard fields
    # (extra defensive assertion against partial regressions where one
    # path stamps tool_call_id but reverts the body shape).
    interrupt_types = {i.value.get("interrupt_type") for i in paused.interrupts}
    assert interrupt_types == {"gmail_email_send", "permission_ask"}

    # Resume order: same order the SSE stream would emit (interrupts list).
    # Keying the decisions by tool-call id and projecting them through
    # ``pending`` keeps the flat list aligned with whatever order the routing
    # layer reported, without hard-coding a two-way conditional.
    decision_by_tcid = {
        tcid_self: {"type": "approve"},
        tcid_mw: {"type": "always"},
    }
    flat_decisions = [decision_by_tcid[entry[0]] for entry in pending]
    by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
    lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
    assert len(lg_resume_map) == 2

    config["configurable"]["surfsense_resume_value"] = by_tool_call_id
    await parent.ainvoke(Command(resume=lg_resume_map), config)

    final = await parent.aget_state(config)
    assert not final.interrupts, (
        "expected both branches resumed, but state still has interrupts: "
        f"{final.interrupts!r}"
    )

    # Each subagent must have received its own slice — verify by inspecting
    # the JSON-serialized result messages.
    payloads: list[dict] = []
    for msg in final.values.get("messages") or []:
        content = getattr(msg, "content", None)
        if not isinstance(content, str):
            continue
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Non-JSON content (e.g. the "seed" HumanMessage) is expected.
            continue
        # Only JSON objects can be subagent payloads. A bare number or string
        # that happens to parse as JSON would otherwise crash ``.get`` below.
        if isinstance(parsed, dict):
            payloads.append(parsed)

    self_payloads = [p for p in payloads if p.get("kind") == "self_gated"]
    mw_payloads = [p for p in payloads if p.get("kind") == "middleware_gated"]
    assert len(self_payloads) == 1, (
        f"self-gated subagent did not complete; payloads: {payloads!r}"
    )
    assert len(mw_payloads) == 1, (
        f"middleware-gated subagent did not complete; payloads: {payloads!r}"
    )

    # Self-gated approve → HITLResult(decision_type="approve", rejected=False).
    assert self_payloads[0]["decision_type"] == "approve"
    assert self_payloads[0]["rejected"] is False

    # Middleware-gated always → canonical permission shape with always.
    assert mw_payloads[0]["decision"] == {"decision_type": "always"}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
"""Regression: ``request_permission_decision`` must emit the unified LC HITL wire shape.
|
||||
|
||||
Same bug class as :mod:`test_lc_hitl_wire` for self-gated approvals: the
|
||||
permission middleware previously fired the SurfSense-specific
|
||||
``{type, action, context}`` shape, which the parallel-HITL routing layer
|
||||
does not recognize. Standardizing on LC HITL keeps every approval kind on
|
||||
one routing path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
|
||||
request_permission_decision,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: list
|
||||
final_decision: dict
|
||||
|
||||
|
||||
def _build_graph_calling_request_permission_decision(checkpointer: InMemorySaver):
    """Compile a real one-node graph that delegates to the permission ask primitive.

    The node stores the primitive's return value under ``final_decision`` so
    tests can read it back from the final state snapshot.
    """

    def perm_node(_state):
        ask_rule = Rule(permission="rm", pattern="*", action="ask")
        verdict = request_permission_decision(
            tool_name="rm",
            args={"path": "/tmp/file"},
            patterns=["rm/*"],
            rules=[ask_rule],
            emit_interrupt=True,
        )
        return {"final_decision": verdict}

    builder = StateGraph(_State)
    builder.add_node("perm", perm_node)
    builder.add_edge(START, "perm")
    builder.add_edge("perm", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_permission_ask_payload_uses_lc_hitl_shape():
    """The paused permission ask must carry the langchain HITL standard fields."""
    saver = InMemorySaver()
    graph = _build_graph_calling_request_permission_decision(saver)
    config = {"configurable": {"thread_id": "perm-wire"}}
    seed = {"messages": [HumanMessage(content="seed")]}
    await graph.ainvoke(seed, config)

    snap = await graph.aget_state(config)
    assert len(snap.interrupts) == 1
    value = snap.interrupts[0].value

    expected_requests = [{"name": "rm", "args": {"path": "/tmp/file"}}]
    assert value.get("action_requests") == expected_requests, (
        f"REGRESSION: permission ask reverted to legacy shape; got {value!r}"
    )

    review = value.get("review_configs")
    assert isinstance(review, list) and len(review) == 1
    # ``always`` must be in the palette so the FE can render the promote button.
    assert "always" in review[0]["allowed_decisions"]
    assert value.get("interrupt_type") == "permission_ask"

    # SurfSense context rides through verbatim for FE explainability.
    context = value["context"]
    assert context["patterns"] == ["rm/*"]
    assert context["always"] == ["rm/*"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_with_approve_envelope_returns_once_decision():
    """An LC-envelope ``approve`` must come back as the permission-domain ``once``."""
    saver = InMemorySaver()
    graph = _build_graph_calling_request_permission_decision(saver)
    config = {"configurable": {"thread_id": "perm-once"}}
    # The first invoke runs until the permission ask pauses the graph.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    approve_envelope = {"decisions": [{"type": "approve"}]}
    await graph.ainvoke(Command(resume=approve_envelope), config)

    final = await graph.aget_state(config)
    assert final.values.get("final_decision") == {"decision_type": "once"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_with_always_envelope_projects_to_always():
    """An ``always`` reply must project unchanged so the middleware can promote the rule."""
    saver = InMemorySaver()
    graph = _build_graph_calling_request_permission_decision(saver)
    config = {"configurable": {"thread_id": "perm-always"}}
    # The first invoke runs until the permission ask pauses the graph.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    always_envelope = {"decisions": [{"type": "always"}]}
    await graph.ainvoke(Command(resume=always_envelope), config)

    final = await graph.aget_state(config)
    assert final.values.get("final_decision") == {"decision_type": "always"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_with_reject_and_feedback_carries_feedback_through():
    """Reject feedback must survive normalization for ``CorrectedError`` to fire downstream."""
    saver = InMemorySaver()
    graph = _build_graph_calling_request_permission_decision(saver)
    config = {"configurable": {"thread_id": "perm-reject"}}
    # The first invoke runs until the permission ask pauses the graph.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    reject_envelope = {
        "decisions": [{"type": "reject", "feedback": "use the trash bin"}]
    }
    await graph.ainvoke(Command(resume=reject_envelope), config)

    final = await graph.aget_state(config)
    expected = {
        "decision_type": "reject",
        "feedback": "use the trash bin",
    }
    assert final.values.get("final_decision") == expected
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``.
|
||||
|
||||
The KB unification swap (legacy ``interrupt_on`` map → KB-owned ``Ruleset``
|
||||
threaded through ``build_permission_mw(extra_rulesets=...)``) must produce
|
||||
*exactly one* interrupt per destructive FS call, in LC HITL shape, even
|
||||
when ``enable_permission`` is False — destructive ops always ask.
|
||||
|
||||
We exercise the production factory and a real ``PermissionMiddleware`` on a
|
||||
real ``StateGraph`` so the test catches regressions in factory gating,
|
||||
ruleset layering, and interrupt emission together.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
def _kb_style_ruleset() -> Ruleset:
    """Mirror :data:`knowledge_base.agent.KB_RULESET` without importing it.

    Importing the agent module pulls in deepagents and prompts; this test
    is about the factory + middleware contract, not KB wiring.
    """
    # Every destructive filesystem permission gets the same blanket ``ask`` rule.
    destructive_permissions = ("rm", "rmdir", "move_file", "edit_file", "write_file")
    ask_rules = [
        Rule(permission=name, pattern="*", action="ask")
        for name in destructive_permissions
    ]
    return Ruleset(origin="knowledge_base", rules=ask_rules)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
    """Graph state for the permission-middleware harness."""

    # Declared with the ``add_messages`` reducer, so node updates to
    # ``messages`` are combined by it rather than overwriting the channel.
    messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _build_graph_with_permission_middleware(
    *,
    flags: AgentFeatureFlags,
    extra_rulesets: list[Ruleset] | None,
    checkpointer: InMemorySaver,
):
    """Compile a two-node graph that routes an ``rm`` tool call through the
    middleware built by the production ``build_permission_mw`` factory.

    The ``emit`` node appends an ``AIMessage`` carrying one ``rm`` tool call.
    The ``permission`` node then calls the middleware's ``_process`` directly
    (presumably the same logic its ``after_model`` hook runs — confirm against
    ``PermissionMiddleware``). When a rule says ``ask``, that call is expected
    to interrupt the graph with the LC HITL payload.

    Returns:
        A ``(compiled_graph, middleware)`` pair. ``middleware`` is whatever
        the factory returned, and may be ``None``.
    """
    pm = build_permission_mw(flags=flags, extra_rulesets=extra_rulesets)

    def node(_state: _State) -> dict[str, Any]:
        rm_call = {
            "name": "rm",
            "args": {"path": "/tmp/foo"},
            "id": "call-rm-1",
            "type": "tool_call",
        }
        return {"messages": [AIMessage(content="", tool_calls=[rm_call])]}

    def after_node(state: _State) -> dict[str, Any] | None:
        if pm is not None:
            # PermissionMiddleware._process ignores runtime; the test never relies
            # on the runtime context, so passing None keeps the harness lean.
            return pm._process(state, None)  # type: ignore[arg-type]
        return None

    builder = StateGraph(_State)
    builder.add_node("emit", node)
    builder.add_node("permission", after_node)
    builder.add_edge(START, "emit")
    builder.add_edge("emit", "permission")
    builder.add_edge("permission", END)
    return builder.compile(checkpointer=checkpointer), pm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_flag_off():
    """KB ruleset: ``rm`` must ask once even with ``enable_permission=False``.

    This is the keystone of the unification: the legacy ``interrupt_on``
    map fired regardless of ``enable_permission``, so the migrated rules
    must too. Otherwise users could opt out of "ask before rm".
    """
    saver = InMemorySaver()
    graph, pm = _build_graph_with_permission_middleware(
        flags=AgentFeatureFlags(enable_permission=False),
        extra_rulesets=[_kb_style_ruleset()],
        checkpointer=saver,
    )
    assert pm is not None, "extras must force the middleware on"

    config = {"configurable": {"thread_id": "kb-cloud-rm"}}
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    snap = await graph.aget_state(config)
    assert len(snap.interrupts) == 1, (
        "REGRESSION: KB ruleset should raise exactly one interrupt; got "
        f"{[i.value for i in snap.interrupts]!r}"
    )

    payload = snap.interrupts[0].value
    expected_requests = [{"name": "rm", "args": {"path": "/tmp/foo"}}]
    assert payload.get("action_requests") == expected_requests, (
        f"interrupt must carry the rm call in LC HITL shape; got {payload!r}"
    )
    assert payload.get("interrupt_type") == "permission_ask"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_kb_ruleset_resume_with_approve_lets_rm_through():
    """Resume with ``approve`` → call kept; the graph runs to completion."""
    saver = InMemorySaver()
    graph, _ = _build_graph_with_permission_middleware(
        flags=AgentFeatureFlags(enable_permission=False),
        extra_rulesets=[_kb_style_ruleset()],
        checkpointer=saver,
    )
    config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}}
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    approve_envelope = {"decisions": [{"type": "approve"}]}
    await graph.ainvoke(Command(resume=approve_envelope), config)

    final = await graph.aget_state(config)
    assert final.next == (), "graph must complete after approve"

    # Scan from the end for the most recent AI message.
    last_ai = None
    for message in reversed(final.values["messages"]):
        if isinstance(message, AIMessage):
            last_ai = message
            break
    assert last_ai is not None

    kept_calls = [tc["name"] for tc in last_ai.tool_calls]
    assert kept_calls == ["rm"], (
        "approved rm call must remain on the AIMessage so the tool can run"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_extras_with_permission_off_skips_middleware_entirely():
    """No extras + permission off → factory returns ``None`` (no engine).

    The legacy gating is preserved when no caller asks for rules: nothing
    runs, nothing pauses.
    """
    disabled = AgentFeatureFlags(enable_permission=False)
    middleware = build_permission_mw(flags=disabled, extra_rulesets=None)
    assert middleware is None
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""Regression: ``request_approval`` must emit the unified LC HITL wire shape.
|
||||
|
||||
Before this fix, self-gated approvals fired the SurfSense-specific
|
||||
``{type, action, context}`` shape which the parallel-HITL routing layer
|
||||
(``collect_pending_tool_calls``) does not recognize. In a parallel HITL
|
||||
scenario where one subagent used self-gated approvals (e.g. Gmail send)
|
||||
and another used middleware-gated approvals (e.g. Linear via
|
||||
``HumanInTheLoopMiddleware``), the routing layer would silently skip the
|
||||
self-gated interrupt and crash on resume with ``Decision count mismatch``.
|
||||
|
||||
This test pins the wire contract by running ``request_approval`` inside a
|
||||
real ``StateGraph`` and asserting the paused parent observes the LC HITL
|
||||
shape (``action_requests``, ``review_configs``, ``interrupt_type``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: list
|
||||
final_decision_type: str
|
||||
final_params: dict
|
||||
|
||||
|
||||
def _build_graph_calling_request_approval(checkpointer: InMemorySaver):
    """Compile a real one-node graph that delegates to ``request_approval``.

    The node copies the result's ``decision_type`` and ``params`` into state
    so tests can read them back from the final snapshot.
    """

    def gate_node(_state):
        outcome = request_approval(
            action_type="gmail_email_send",
            tool_name="send_gmail_email",
            params={"to": "alice@example.com", "subject": "hi"},
            context={"account": "alice@gmail.com"},
        )
        update = {
            "final_decision_type": outcome.decision_type,
            "final_params": outcome.params,
        }
        return update

    builder = StateGraph(_State)
    builder.add_node("gate", gate_node)
    builder.add_edge(START, "gate")
    builder.add_edge("gate", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_paused_interrupt_uses_lc_hitl_action_requests_shape():
    """The paused self-gated interrupt must carry the langchain HITL standard fields."""
    saver = InMemorySaver()
    graph = _build_graph_calling_request_approval(saver)
    config = {"configurable": {"thread_id": "self-gated-wire"}}
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    snap = await graph.aget_state(config)
    assert len(snap.interrupts) == 1, (
        f"expected one paused interrupt, got {len(snap.interrupts)}"
    )
    value = snap.interrupts[0].value
    assert isinstance(value, dict)

    # Standard LC HITL fields the routing layer reads.
    expected_requests = [
        {
            "name": "send_gmail_email",
            "args": {"to": "alice@example.com", "subject": "hi"},
        }
    ]
    assert value.get("action_requests") == expected_requests, (
        "REGRESSION: self-gated approval reverted to legacy SurfSense shape; "
        f"got {value!r}"
    )

    expected_review = [
        {
            "action_name": "send_gmail_email",
            "allowed_decisions": ["approve", "reject", "edit"],
        }
    ]
    assert value.get("review_configs") == expected_review

    assert value.get("interrupt_type") == "gmail_email_send", (
        "FE card discriminator must travel as ``interrupt_type``."
    )
    assert value.get("context") == {"account": "alice@gmail.com"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_with_lc_envelope_returns_hitl_result_with_edited_args():
    """Edit reply via the LC envelope must round-trip into ``HITLResult.params``."""
    saver = InMemorySaver()
    graph = _build_graph_calling_request_approval(saver)
    config = {"configurable": {"thread_id": "self-gated-resume"}}
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    # The reviewer edits only ``subject``, but the expected params still
    # include ``to``. NOTE(review): this implies ``request_approval`` overlays
    # edited args on the original params — confirm against the implementation.
    edited = {"to": "alice@example.com", "subject": "EDITED"}
    edit_decision = {"type": "edit", "edited_action": {"args": {"subject": "EDITED"}}}
    await graph.ainvoke(Command(resume={"decisions": [edit_decision]}), config)

    final = await graph.aget_state(config)
    values = final.values
    assert values.get("final_decision_type") == "edit"
    assert values.get("final_params") == edited
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reject_envelope_returns_rejected_hitl_result():
    """A reject reply must surface as ``decision_type == "reject"`` on the result.

    The harness graph records only ``decision_type`` and ``params``, so this
    test checks the decision type and does not inspect ``HITLResult.rejected``.
    """
    saver = InMemorySaver()
    graph = _build_graph_calling_request_approval(saver)
    config = {"configurable": {"thread_id": "self-gated-reject"}}
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    reject_envelope = {"decisions": [{"type": "reject", "feedback": "no"}]}
    await graph.ainvoke(Command(resume=reject_envelope), config)

    final = await graph.aget_state(config)
    assert final.values.get("final_decision_type") == "reject"
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
"""Unit contract for the unified LC HITL wire format.
|
||||
|
||||
Both the self-gated approval primitive (``request_approval``) and the
|
||||
middleware-gated permission ask (``PermissionMiddleware``) must serialize
|
||||
to the same wire shape so the parallel-HITL routing layer
|
||||
(``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call`` +
|
||||
``build_lg_resume_map``) sees one format.
|
||||
|
||||
These tests pin the shape:
|
||||
|
||||
- Builder always emits ``action_requests`` (1 entry) + ``review_configs``
|
||||
+ ``interrupt_type``; ``context`` rides through verbatim when present.
|
||||
- Parser tolerates the standard LC envelope, bare scalar strings, and
|
||||
unrecognized shapes (failing closed to ``reject``).
|
||||
- Edited args round-trip through both nested (``edited_action.args``) and
|
||||
flat (``args``) shapes without inventing values for the empty case.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.wire import (
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_EDIT,
|
||||
LC_DECISION_REJECT,
|
||||
SURFSENSE_DECISION_ALWAYS,
|
||||
build_lc_hitl_payload,
|
||||
parse_lc_envelope,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildLcHitlPayload:
    """Shape contract for ``build_lc_hitl_payload``."""

    def test_minimal_payload_has_one_action_request_and_one_review_config(self):
        allowed = [LC_DECISION_APPROVE, LC_DECISION_REJECT]
        payload = build_lc_hitl_payload(
            tool_name="send_email",
            args={"to": "x@y.z"},
            allowed_decisions=allowed,
            interrupt_type="gmail_email_send",
        )

        expected_requests = [{"name": "send_email", "args": {"to": "x@y.z"}}]
        expected_review = [
            {
                "action_name": "send_email",
                "allowed_decisions": [LC_DECISION_APPROVE, LC_DECISION_REJECT],
            }
        ]
        assert payload["action_requests"] == expected_requests
        assert payload["review_configs"] == expected_review
        assert payload["interrupt_type"] == "gmail_email_send"
        assert "context" not in payload, "context must be omitted when not provided"

    def test_none_args_normalized_to_empty_dict(self):
        """FE expects a stable shape; ``None`` would crash card rendering."""
        payload = build_lc_hitl_payload(
            tool_name="ping",
            args=None,  # type: ignore[arg-type]
            allowed_decisions=[LC_DECISION_APPROVE],
            interrupt_type="self_gated",
        )
        first_request = payload["action_requests"][0]
        assert first_request["args"] == {}

    def test_description_attached_only_when_provided(self):
        with_desc = build_lc_hitl_payload(
            tool_name="t",
            args={},
            allowed_decisions=[LC_DECISION_APPROVE],
            interrupt_type="x",
            description="please review",
        )
        without = build_lc_hitl_payload(
            tool_name="t",
            args={},
            allowed_decisions=[LC_DECISION_APPROVE],
            interrupt_type="x",
        )

        described_request = with_desc["action_requests"][0]
        bare_request = without["action_requests"][0]
        assert described_request["description"] == "please review"
        assert "description" not in bare_request

    def test_context_passed_through_verbatim(self):
        ctx = {"patterns": ["rm/*"], "rules": [], "always": ["rm/*"]}
        palette = [
            LC_DECISION_APPROVE,
            LC_DECISION_REJECT,
            SURFSENSE_DECISION_ALWAYS,
        ]
        payload = build_lc_hitl_payload(
            tool_name="rm",
            args={"path": "/tmp"},
            allowed_decisions=palette,
            interrupt_type="permission_ask",
            context=ctx,
        )
        assert payload["context"] == ctx

    def test_allowed_decisions_list_is_copied_not_aliased(self):
        """A caller mutating their original list must not corrupt the payload."""
        decisions = [LC_DECISION_APPROVE]
        payload = build_lc_hitl_payload(
            tool_name="t",
            args={},
            allowed_decisions=decisions,
            interrupt_type="x",
        )
        # Mutate the caller's list after the payload is built. The payload
        # must still hold the original single entry.
        decisions.append(LC_DECISION_REJECT)
        stored = payload["review_configs"][0]["allowed_decisions"]
        assert stored == [LC_DECISION_APPROVE]
|
||||
|
||||
|
||||
class TestParseLcEnvelope:
    """Parsing contract for ``parse_lc_envelope``."""

    def test_standard_lc_envelope_returns_typed_decision(self):
        envelope = {"decisions": [{"type": "approve"}]}
        parsed = parse_lc_envelope(envelope)
        assert parsed.decision_type == "approve"
        assert parsed.edited_args is None
        assert parsed.message is None

    def test_bare_scalar_string_passes_through_lowercased(self):
        upper_case = parse_lc_envelope("ALWAYS")
        lower_case = parse_lc_envelope("once")
        assert upper_case.decision_type == "always"
        assert lower_case.decision_type == "once"

    def test_non_dict_non_string_collapses_to_reject(self):
        """Failing closed: ambiguous input must never proceed."""
        for ambiguous in (42, None, ["bogus"]):
            assert parse_lc_envelope(ambiguous).decision_type == "reject"

    def test_missing_decision_type_collapses_to_reject(self):
        for untyped in ({"decisions": [{}]}, {"foo": "bar"}):
            assert parse_lc_envelope(untyped).decision_type == "reject"

    def test_edit_extracts_nested_args(self):
        decision = {
            "type": LC_DECISION_EDIT,
            "edited_action": {"args": {"to": "edited@y.z"}},
        }
        parsed = parse_lc_envelope({"decisions": [decision]})
        assert parsed.decision_type == "edit"
        assert parsed.edited_args == {"to": "edited@y.z"}

    def test_edit_falls_back_to_flat_args(self):
        decision = {"type": "edit", "args": {"k": "v"}}
        parsed = parse_lc_envelope({"decisions": [decision]})
        assert parsed.edited_args == {"k": "v"}

    def test_edit_with_empty_args_yields_none_edited(self):
        """Empty edited_args means "no edits" — caller treats as plain approve."""
        decision = {"type": "edit", "edited_action": {"args": {}}}
        parsed = parse_lc_envelope({"decisions": [decision]})
        assert parsed.edited_args is None

    def test_message_picked_from_either_feedback_or_message_field(self):
        with_feedback = parse_lc_envelope(
            {"decisions": [{"type": "reject", "feedback": "no thanks"}]}
        )
        with_message = parse_lc_envelope(
            {"decisions": [{"type": "reject", "message": "no thanks"}]}
        )
        # Both field spellings must land on the same attribute.
        assert with_feedback.message == "no thanks"
        assert with_message.message == "no thanks"

    def test_blank_message_treated_as_absent(self):
        decision = {"type": "reject", "message": "   "}
        parsed = parse_lc_envelope({"decisions": [decision]})
        assert parsed.message is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue