diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py index 125e0744a..458a2539b 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py @@ -13,11 +13,17 @@ pregel vs. our subagent bridge). This test reproduces the production failure with a real two-task parallel ``Send`` parent graph, exercises the full resume cycle, and asserts both -subagents complete cleanly. +subagents complete cleanly with their per-subagent slice intact. + +Bundle sizes are chosen heterogeneous (``2`` and ``3``) so the assertions +also catch slicer arithmetic regressions (e.g., ``cursor += 1`` instead of +``cursor += action_count``). A symmetric ``(1, 1)`` configuration would +silently pass such a bug because the slices would coincide. """ from __future__ import annotations +import json from typing import Annotated import pytest @@ -53,19 +59,31 @@ class _DispatchState(TypedDict, total=False): messages: Annotated[list, add_messages] tcid: str desc: str + subtype: str -def _build_pausing_subagent(checkpointer: InMemorySaver): +def _build_pausing_subagent(checkpointer: InMemorySaver, *, action_count: int): + """Subagent that pauses with an ``action_count``-action HITL bundle. + + On resume it captures the decision payload as a JSON-serialized + ``AIMessage`` content so the test can inspect exactly which slice + reached this subagent — the strongest assertion against slicer + routing regressions. + """ + def approve_node(_state): decision = interrupt( { "action_requests": [ - {"name": "do_thing", "args": {"x": 1}, "description": ""} + {"name": f"act_{i}", "args": {"i": i}, "description": ""} + for i in range(action_count) ], - "review_configs": [{}], + "review_configs": [{} for _ in range(action_count)], } ) - return {"messages": [AIMessage(content=f"got:{decision}")]} + return { + "messages": [AIMessage(content=json.dumps(decision, sort_keys=True))] + } g = StateGraph(_SubState) g.add_node("approve", approve_node) @@ -79,8 +97,14 @@ def _parent_graph_dispatching_two_tasks_via_send( ): def fanout_edge(_state) -> list[Send]: return [ - Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}), - Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}), + Send( + "call_task", + {"tcid": tool_call_id_a, "desc": "approve A", "subtype": "agent-a"}, + ), + Send( + "call_task", + {"tcid": tool_call_id_b, "desc": "approve B", "subtype": "agent-b"}, + ), ] async def call_task(state: _DispatchState, config: RunnableConfig): @@ -93,7 +117,7 @@ def _parent_graph_dispatching_two_tasks_via_send( store=None, ) return await task_tool.coroutine( - description=state["desc"], subagent_type="approver", runtime=rt + description=state["desc"], subagent_type=state["subtype"], runtime=rt ) g = StateGraph(_DispatchState) @@ -103,6 +127,22 @@ def _parent_graph_dispatching_two_tasks_via_send( return g.compile(checkpointer=checkpointer) +def _build_two_subagents_task_tool(checkpointer: InMemorySaver): + """Register two subagents under distinct names with heterogeneous bundle sizes. + + Sub-A: 2-action bundle. Sub-B: 3-action bundle. Both ``> 1`` so the slice + arithmetic is sensitive to off-by-one mistakes. + """ + sub_a = _build_pausing_subagent(checkpointer, action_count=2) + sub_b = _build_pausing_subagent(checkpointer, action_count=3) + return build_task_tool_with_parent_config( + [ + {"name": "agent-a", "description": "first", "runnable": sub_a}, + {"name": "agent-b", "description": "second", "runnable": sub_b}, + ] + ) + + @pytest.mark.asyncio async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error(): """Confirm the production failure mode: scalar resume on multi-pending state explodes. @@ -112,10 +152,7 @@ async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_erro ``stream_resume_chat``. Until then, the keyed form is mandatory. """ checkpointer = InMemorySaver() - subagent = _build_pausing_subagent(checkpointer) - task_tool = build_task_tool_with_parent_config( - [{"name": "approver", "description": "approves", "runnable": subagent}] - ) + task_tool = _build_two_subagents_task_tool(checkpointer) parent = _parent_graph_dispatching_two_tasks_via_send( task_tool, tool_call_id_a="parent-tcid-A", @@ -139,15 +176,17 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag Mirrors what ``stream_resume_chat`` does: collects pending interrupts, slices the flat decisions list by ``tool_call_id``, builds the ``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes. - The expected post-condition is that both subagents pop their own - decision (via the ``surfsense_resume_value`` side-channel) and run to - completion — no RuntimeError, no leaked pending interrupts. + + Post-conditions checked: + 1. The langgraph-keyed map has exactly one entry per pending interrupt + id (``str`` keys, count matches). + 2. Both subagents complete with no leftover pending interrupts. + 3. **Each subagent receives its exact slice in the original order** — + this catches slicer arithmetic regressions (e.g., ``cursor += 1``) + that wouldn't surface by checking only "no leftover pending". """ checkpointer = InMemorySaver() - subagent = _build_pausing_subagent(checkpointer) - task_tool = build_task_tool_with_parent_config( - [{"name": "approver", "description": "approves", "runnable": subagent}] - ) + task_tool = _build_two_subagents_task_tool(checkpointer) tcid_a = "parent-tcid-A" tcid_b = "parent-tcid-B" parent = _parent_graph_dispatching_two_tasks_via_send( @@ -166,7 +205,20 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag assert len(paused_state.interrupts) == 2, "fixture broken: expected 2 paused subagents" pending = collect_pending_tool_calls(paused_state) - flat_decisions = [{"type": "approve"}, {"type": "approve"}] + assert dict(pending) == {tcid_a: 2, tcid_b: 3}, ( + f"fixture broken: heterogeneous bundle sizes not detected; got {pending!r}" + ) + + a_d0 = {"type": "approve"} + a_d1 = {"type": "reject", "message": "A[1] is redundant"} + b_d0 = { + "type": "edit", + "edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}}, + } + b_d1 = {"type": "approve"} + b_d2 = {"type": "reject", "message": "B[2] needs more context"} + flat_decisions = [a_d0, a_d1, b_d0, b_d1, b_d2] + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id) @@ -177,7 +229,6 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}" ) - # Wire the side-channel exactly like ``stream_resume_chat`` does. config["configurable"]["surfsense_resume_value"] = by_tool_call_id await parent.ainvoke(Command(resume=lg_resume_map), config) @@ -188,6 +239,26 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag f"{final_state.interrupts!r}" ) + payloads: list[dict] = [] + for msg in final_state.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + expected_a = {"decisions": [a_d0, a_d1]} + expected_b = {"decisions": [b_d0, b_d1, b_d2]} + assert expected_a in payloads, ( + f"REGRESSION: sub-A did not receive its 2-decision slice in order; " + f"payloads seen: {payloads!r}" + ) + assert expected_b in payloads, ( + f"REGRESSION: sub-B did not receive its 3-decision slice in order; " + f"payloads seen: {payloads!r}" + ) + def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps(): """Unstamped interrupts can't be routed; we don't fabricate keys for them.