mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: real-graph regression test for all-reject parallel routing
This commit is contained in:
parent
ca57b2106e
commit
8e10f38f32
1 changed files with 216 additions and 0 deletions
|
|
@ -0,0 +1,216 @@
|
|||
"""Real-graph contract: all-reject decisions route correctly across parallel subagents.
|
||||
|
||||
Heterogeneous routing is covered by ``test_parallel_heterogeneous_decisions``.
|
||||
This module pins the narrower edge case where **every** card on **every**
|
||||
paused subagent is rejected.
|
||||
|
||||
Why a separate pin:
|
||||
|
||||
1. **No approval-bias in the slicer.** A future "if no approvals, short-circuit
|
||||
resume" optimization would be tempting (skips a langgraph round-trip) and
|
||||
would also silently break this scenario. Pin it.
|
||||
2. **``message`` metadata pass-through across a run of rejects.** The reject
|
||||
``message`` is the user-visible reason ("looks suspicious", "duplicate",
|
||||
etc.). Losing it would silently swallow user intent — the worst UX
|
||||
failure mode for HITL. Heterogeneous covers one reject; here we verify a
|
||||
sequence of rejects survives the slicer + bridge with distinct messages
|
||||
intact and in order.
|
||||
3. **All branches complete with no leftover pending.** Even when nothing was
|
||||
approved, the parent must drain every paused subagent so the SSE stream
|
||||
can close cleanly. A bug that left one ``Interrupt.id`` un-keyed would
|
||||
strand the conversation in "pending" forever.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
    """State schema for the recording subagent graph.

    ``total=False`` makes every key optional, so the graph can start from
    an empty or partial state.
    """

    # Plain list with no reducer annotation attached.
    messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
    """State schema for the parent graph that fans out to ``call_task``.

    Besides the shared ``messages`` channel, each dispatched branch carries
    the three per-call fields that ``call_task`` reads to invoke the task tool.
    """

    # Message channel merged across branches via the ``add_messages`` reducer.
    messages: Annotated[list, add_messages]
    # Tool-call id handed to ``ToolRuntime`` for this branch.
    tcid: str
    # Task description forwarded to the task tool.
    desc: str
    # Name of the subagent the task tool should run.
    subtype: str
|
||||
|
||||
|
||||
def _build_recording_subagent(checkpointer: InMemorySaver, *, action_count: int):
    """Compile a one-node subagent that interrupts and echoes what it is resumed with.

    The single ``hitl`` node calls ``interrupt`` with ``action_count`` action
    requests (named ``act_0`` .. ``act_{n-1}``) plus one matching review config
    per action. Whatever ``interrupt`` returns is JSON-serialized with sorted
    keys into an ``AIMessage``, so a test can identify each subagent's slice
    purely by message content.

    Args:
        checkpointer: Saver the compiled subagent graph is bound to.
        action_count: Number of actions (and review configs) in the interrupt
            payload.

    Returns:
        The compiled ``START -> hitl -> END`` graph.
    """

    def hitl_node(_state):
        action_names = [f"act_{idx}" for idx in range(action_count)]
        review_request = {
            "action_requests": [
                {"name": name, "args": {"i": idx}, "description": ""}
                for idx, name in enumerate(action_names)
            ],
            "review_configs": [
                {
                    "action_name": name,
                    "allowed_decisions": ["approve", "reject", "edit"],
                }
                for name in action_names
            ],
        }
        resume_value = interrupt(review_request)
        # sort_keys keeps the serialized form stable regardless of dict order.
        recorded = AIMessage(content=json.dumps(resume_value, sort_keys=True))
        return {"messages": [recorded]}

    builder = StateGraph(_SubState)
    builder.add_node("hitl", hitl_node)
    builder.add_edge(START, "hitl")
    builder.add_edge("hitl", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_two_branches(task_tool, *, dispatches, checkpointer):
    """Compile a parent graph that runs ``task_tool`` once per dispatch via ``Send``.

    Args:
        task_tool: Tool object whose ``coroutine`` is awaited with
            ``description``, ``subagent_type`` and ``runtime`` keyword arguments.
        dispatches: Iterable of per-branch state dicts; each must provide the
            ``tcid``, ``desc`` and ``subtype`` keys read by ``call_task``.
        checkpointer: Saver the compiled parent graph is bound to.

    Returns:
        The compiled graph: ``START`` conditionally fans out to ``call_task``,
        which then leads to ``END``.
    """

    def fanout(_state) -> list[Send]:
        # One Send per dispatch, each targeting the same node with its own state.
        sends = []
        for dispatch in dispatches:
            sends.append(Send("call_task", dispatch))
        return sends

    async def call_task(state: _DispatchState, config: RunnableConfig):
        runtime = ToolRuntime(
            state=state,
            config=config,
            context=None,
            stream_writer=None,
            tool_call_id=state["tcid"],
            store=None,
        )
        result = await task_tool.coroutine(
            description=state["desc"],
            subagent_type=state["subtype"],
            runtime=runtime,
        )
        return result

    builder = StateGraph(_DispatchState)
    builder.add_node("call_task", call_task)
    builder.add_conditional_edges(START, fanout, ["call_task"])
    builder.add_edge("call_task", END)
    return builder.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _decode_json_object_payloads(messages) -> list[dict]:
    """Return the JSON-object payloads found in string message contents.

    Messages whose ``content`` is not a string, is not valid JSON, or decodes
    to something other than a JSON object (list, number, string, ``null``) are
    skipped. Skipping non-objects keeps later ``.get("decisions")`` lookups
    from raising ``AttributeError`` and masking the real regression message.

    Args:
        messages: Iterable of message-like objects, or ``None``.

    Returns:
        Decoded dict payloads, in message order.
    """
    decoded: list[dict] = []
    for msg in messages or []:
        content = getattr(msg, "content", None)
        if not isinstance(content, str):
            continue
        try:
            candidate = json.loads(content)
        except json.JSONDecodeError:
            # Non-JSON content (e.g. the "seed" human message) is expected.
            continue
        if isinstance(candidate, dict):
            decoded.append(candidate)
    return decoded


@pytest.mark.asyncio
async def test_all_reject_decisions_route_to_each_subagent_with_messages_intact():
    """All cards rejected across two parallel subagents — order and messages preserved.

    Setup mirrors a real "user reviews two parallel ticket creations and
    rejects everything with distinct reasons":

    - Sub-A pauses with 2 actions.
    - Sub-B pauses with 1 action.
    - Flat decisions: 3 rejects, each with a unique ``message``.

    Asserts each subagent receives only its slice, in original order,
    with every ``message`` intact and no ``edited_action`` fields fabricated.
    """
    # Parent and both subagents share one in-memory checkpointer.
    checkpointer = InMemorySaver()

    sub_a = _build_recording_subagent(checkpointer, action_count=2)
    sub_b = _build_recording_subagent(checkpointer, action_count=1)

    task_tool = build_task_tool_with_parent_config(
        [
            {"name": "agent-a", "description": "first", "runnable": sub_a},
            {"name": "agent-b", "description": "second", "runnable": sub_b},
        ]
    )

    parent = _parent_two_branches(
        task_tool,
        dispatches=[
            {"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"},
            {"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"},
        ],
        checkpointer=checkpointer,
    )

    config: dict = {
        "configurable": {"thread_id": "all-reject-thread"},
        "recursion_limit": 100,
    }
    # First run: both branches are expected to stop at their interrupt.
    await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)

    paused_state = await parent.aget_state(config)
    assert len(paused_state.interrupts) == 2, (
        f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}"
    )

    a_reject_0 = {"type": "reject", "message": "A[0] looks suspicious"}
    a_reject_1 = {"type": "reject", "message": "A[1] duplicates A[0]"}
    b_reject_0 = {"type": "reject", "message": "B[0] needs more context"}
    flat_decisions = [a_reject_0, a_reject_1, b_reject_0]

    pending = collect_pending_tool_calls(paused_state)
    by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)

    assert by_tool_call_id == {
        "tcid-A": {"decisions": [a_reject_0, a_reject_1]},
        "tcid-B": {"decisions": [b_reject_0]},
    }, f"REGRESSION: slicer mis-routed all-reject decisions: {by_tool_call_id!r}"

    # NOTE(review): presumably the task tool reads this key on resume to pick
    # each subagent's slice — confirm against task_tool.
    config["configurable"]["surfsense_resume_value"] = by_tool_call_id
    lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)

    await parent.ainvoke(Command(resume=lg_resume_map), config)

    final_state = await parent.aget_state(config)
    assert not final_state.interrupts, (
        f"REGRESSION: leftover pending interrupts after all-reject resume: "
        f"{final_state.interrupts!r}"
    )

    # Only JSON-object payloads are kept, so the checks below cannot crash on
    # a message whose content happens to be valid JSON of another type.
    payloads = _decode_json_object_payloads(final_state.values.get("messages"))

    expected_a = {"decisions": [a_reject_0, a_reject_1]}
    expected_b = {"decisions": [b_reject_0]}

    assert expected_a in payloads, (
        f"REGRESSION: sub-A did not receive its 2-reject slice in order; "
        f"payloads seen: {payloads!r}"
    )
    assert expected_b in payloads, (
        f"REGRESSION: sub-B did not receive its single reject; "
        f"payloads seen: {payloads!r}"
    )

    for p in payloads:
        for d in p.get("decisions", []):
            assert "edited_action" not in d, (
                f"REGRESSION: spurious ``edited_action`` on a reject — "
                f"slicer/bridge mutated payload: {d!r}"
            )
|
||||
Loading…
Add table
Add a link
Reference in a new issue