mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: tighten parallel-keying test with heterogeneous bundles and per-slice assertions
This commit is contained in:
parent
d69d2cc1fc
commit
a36b15b834
1 changed files with 92 additions and 21 deletions
|
|
@ -13,11 +13,17 @@ pregel vs. our subagent bridge).
|
||||||
|
|
||||||
This test reproduces the production failure with a real two-task parallel
|
This test reproduces the production failure with a real two-task parallel
|
||||||
``Send`` parent graph, exercises the full resume cycle, and asserts both
|
``Send`` parent graph, exercises the full resume cycle, and asserts both
|
||||||
subagents complete cleanly.
|
subagents complete cleanly with their per-subagent slice intact.
|
||||||
|
|
||||||
|
Bundle sizes are chosen heterogeneous (``2`` and ``3``) so the assertions
|
||||||
|
also catch slicer arithmetic regressions (e.g., ``cursor += 1`` instead of
|
||||||
|
``cursor += action_count``). A symmetric ``(1, 1)`` configuration would
|
||||||
|
silently pass such a bug because the slices would coincide.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -53,19 +59,31 @@ class _DispatchState(TypedDict, total=False):
|
||||||
messages: Annotated[list, add_messages]
|
messages: Annotated[list, add_messages]
|
||||||
tcid: str
|
tcid: str
|
||||||
desc: str
|
desc: str
|
||||||
|
subtype: str
|
||||||
|
|
||||||
|
|
||||||
def _build_pausing_subagent(checkpointer: InMemorySaver):
|
def _build_pausing_subagent(checkpointer: InMemorySaver, *, action_count: int):
|
||||||
|
"""Subagent that pauses with an ``action_count``-action HITL bundle.
|
||||||
|
|
||||||
|
On resume it captures the decision payload as a JSON-serialized
|
||||||
|
``AIMessage`` content so the test can inspect exactly which slice
|
||||||
|
reached this subagent — the strongest assertion against slicer
|
||||||
|
routing regressions.
|
||||||
|
"""
|
||||||
|
|
||||||
def approve_node(_state):
|
def approve_node(_state):
|
||||||
decision = interrupt(
|
decision = interrupt(
|
||||||
{
|
{
|
||||||
"action_requests": [
|
"action_requests": [
|
||||||
{"name": "do_thing", "args": {"x": 1}, "description": ""}
|
{"name": f"act_{i}", "args": {"i": i}, "description": ""}
|
||||||
|
for i in range(action_count)
|
||||||
],
|
],
|
||||||
"review_configs": [{}],
|
"review_configs": [{} for _ in range(action_count)],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
return {
|
||||||
|
"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]
|
||||||
|
}
|
||||||
|
|
||||||
g = StateGraph(_SubState)
|
g = StateGraph(_SubState)
|
||||||
g.add_node("approve", approve_node)
|
g.add_node("approve", approve_node)
|
||||||
|
|
@ -79,8 +97,14 @@ def _parent_graph_dispatching_two_tasks_via_send(
|
||||||
):
|
):
|
||||||
def fanout_edge(_state) -> list[Send]:
|
def fanout_edge(_state) -> list[Send]:
|
||||||
return [
|
return [
|
||||||
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
|
Send(
|
||||||
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
|
"call_task",
|
||||||
|
{"tcid": tool_call_id_a, "desc": "approve A", "subtype": "agent-a"},
|
||||||
|
),
|
||||||
|
Send(
|
||||||
|
"call_task",
|
||||||
|
{"tcid": tool_call_id_b, "desc": "approve B", "subtype": "agent-b"},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||||
|
|
@ -93,7 +117,7 @@ def _parent_graph_dispatching_two_tasks_via_send(
|
||||||
store=None,
|
store=None,
|
||||||
)
|
)
|
||||||
return await task_tool.coroutine(
|
return await task_tool.coroutine(
|
||||||
description=state["desc"], subagent_type="approver", runtime=rt
|
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||||
)
|
)
|
||||||
|
|
||||||
g = StateGraph(_DispatchState)
|
g = StateGraph(_DispatchState)
|
||||||
|
|
@ -103,6 +127,22 @@ def _parent_graph_dispatching_two_tasks_via_send(
|
||||||
return g.compile(checkpointer=checkpointer)
|
return g.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_two_subagents_task_tool(checkpointer: InMemorySaver):
|
||||||
|
"""Register two subagents under distinct names with heterogeneous bundle sizes.
|
||||||
|
|
||||||
|
Sub-A: 2-action bundle. Sub-B: 3-action bundle. Both ``> 1`` so the slice
|
||||||
|
arithmetic is sensitive to off-by-one mistakes.
|
||||||
|
"""
|
||||||
|
sub_a = _build_pausing_subagent(checkpointer, action_count=2)
|
||||||
|
sub_b = _build_pausing_subagent(checkpointer, action_count=3)
|
||||||
|
return build_task_tool_with_parent_config(
|
||||||
|
[
|
||||||
|
{"name": "agent-a", "description": "first", "runnable": sub_a},
|
||||||
|
{"name": "agent-b", "description": "second", "runnable": sub_b},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error():
|
async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error():
|
||||||
"""Confirm the production failure mode: scalar resume on multi-pending state explodes.
|
"""Confirm the production failure mode: scalar resume on multi-pending state explodes.
|
||||||
|
|
@ -112,10 +152,7 @@ async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_erro
|
||||||
``stream_resume_chat``. Until then, the keyed form is mandatory.
|
``stream_resume_chat``. Until then, the keyed form is mandatory.
|
||||||
"""
|
"""
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
subagent = _build_pausing_subagent(checkpointer)
|
task_tool = _build_two_subagents_task_tool(checkpointer)
|
||||||
task_tool = build_task_tool_with_parent_config(
|
|
||||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
|
||||||
)
|
|
||||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||||
task_tool,
|
task_tool,
|
||||||
tool_call_id_a="parent-tcid-A",
|
tool_call_id_a="parent-tcid-A",
|
||||||
|
|
@ -139,15 +176,17 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag
|
||||||
Mirrors what ``stream_resume_chat`` does: collects pending interrupts,
|
Mirrors what ``stream_resume_chat`` does: collects pending interrupts,
|
||||||
slices the flat decisions list by ``tool_call_id``, builds the
|
slices the flat decisions list by ``tool_call_id``, builds the
|
||||||
``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes.
|
``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes.
|
||||||
The expected post-condition is that both subagents pop their own
|
|
||||||
decision (via the ``surfsense_resume_value`` side-channel) and run to
|
Post-conditions checked:
|
||||||
completion — no RuntimeError, no leaked pending interrupts.
|
1. The langgraph-keyed map has exactly one entry per pending interrupt
|
||||||
|
id (``str`` keys, count matches).
|
||||||
|
2. Both subagents complete with no leftover pending interrupts.
|
||||||
|
3. **Each subagent receives its exact slice in the original order** —
|
||||||
|
this catches slicer arithmetic regressions (e.g., ``cursor += 1``)
|
||||||
|
that wouldn't surface by checking only "no leftover pending".
|
||||||
"""
|
"""
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
subagent = _build_pausing_subagent(checkpointer)
|
task_tool = _build_two_subagents_task_tool(checkpointer)
|
||||||
task_tool = build_task_tool_with_parent_config(
|
|
||||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
|
||||||
)
|
|
||||||
tcid_a = "parent-tcid-A"
|
tcid_a = "parent-tcid-A"
|
||||||
tcid_b = "parent-tcid-B"
|
tcid_b = "parent-tcid-B"
|
||||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
|
@ -166,7 +205,20 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag
|
||||||
assert len(paused_state.interrupts) == 2, "fixture broken: expected 2 paused subagents"
|
assert len(paused_state.interrupts) == 2, "fixture broken: expected 2 paused subagents"
|
||||||
|
|
||||||
pending = collect_pending_tool_calls(paused_state)
|
pending = collect_pending_tool_calls(paused_state)
|
||||||
flat_decisions = [{"type": "approve"}, {"type": "approve"}]
|
assert dict(pending) == {tcid_a: 2, tcid_b: 3}, (
|
||||||
|
f"fixture broken: heterogeneous bundle sizes not detected; got {pending!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
a_d0 = {"type": "approve"}
|
||||||
|
a_d1 = {"type": "reject", "message": "A[1] is redundant"}
|
||||||
|
b_d0 = {
|
||||||
|
"type": "edit",
|
||||||
|
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
|
||||||
|
}
|
||||||
|
b_d1 = {"type": "approve"}
|
||||||
|
b_d2 = {"type": "reject", "message": "B[2] needs more context"}
|
||||||
|
flat_decisions = [a_d0, a_d1, b_d0, b_d1, b_d2]
|
||||||
|
|
||||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||||
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
|
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
|
||||||
|
|
||||||
|
|
@ -177,7 +229,6 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag
|
||||||
f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}"
|
f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wire the side-channel exactly like ``stream_resume_chat`` does.
|
|
||||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||||
|
|
||||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||||
|
|
@ -188,6 +239,26 @@ async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subag
|
||||||
f"{final_state.interrupts!r}"
|
f"{final_state.interrupts!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
payloads: list[dict] = []
|
||||||
|
for msg in final_state.values.get("messages", []) or []:
|
||||||
|
content = getattr(msg, "content", None)
|
||||||
|
if isinstance(content, str):
|
||||||
|
try:
|
||||||
|
payloads.append(json.loads(content))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
expected_a = {"decisions": [a_d0, a_d1]}
|
||||||
|
expected_b = {"decisions": [b_d0, b_d1, b_d2]}
|
||||||
|
assert expected_a in payloads, (
|
||||||
|
f"REGRESSION: sub-A did not receive its 2-decision slice in order; "
|
||||||
|
f"payloads seen: {payloads!r}"
|
||||||
|
)
|
||||||
|
assert expected_b in payloads, (
|
||||||
|
f"REGRESSION: sub-B did not receive its 3-decision slice in order; "
|
||||||
|
f"payloads seen: {payloads!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps():
|
def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps():
|
||||||
"""Unstamped interrupts can't be routed; we don't fabricate keys for them.
|
"""Unstamped interrupts can't be routed; we don't fabricate keys for them.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue