multi_agent_chat/middleware: tighten heterogeneous slice arithmetic to (2,3) bundles

This commit is contained in:
CREDO23 2026-05-14 10:05:04 +02:00
parent 668b89927b
commit d69d2cc1fc

View file

@ -15,15 +15,19 @@ This module pins:
only its slice in the original order.
2. **Per-decision metadata pass-through** ``message`` and ``edited_action``
payloads must reach the subagent intact (not just the ``type`` discriminator).
3. **Mixed bundle sizes** two paused subagents with different
``len(action_requests)`` correctly account for the slice boundary
(sub-A gets 1 decision, sub-B gets 2 from a flat list of 3).
3. **Off-by-one-sensitive bundle sizes** both paused subagents have action
counts ``> 1`` (``2`` and ``3``). With those sizes a buggy
``cursor += 1`` slicer (instead of ``cursor += action_count``) produces a
different B-slice from the correct one, so this test catches the most
common refactor mistake. A ``(1, 2)`` configuration would silently pass
such a bug because ``+= 1`` and ``+= count`` are arithmetically identical
when ``count == 1``.
"""
from __future__ import annotations
import json
from typing import Annotated, Any
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
@ -131,56 +135,33 @@ def _parent_dispatching_two_subagents(
return g.compile(checkpointer=checkpointer)
def _captured_payloads_by_content_marker(
final_state, *, marker_keys: list[str]
) -> dict[str, dict[str, Any]]:
"""Extract per-subagent resume payloads from the parent's final messages.
Each subagent emitted ``AIMessage(json.dumps(payload))``. We tag them by
looking for a marker in the inner action-request name (``act_{i}``) plus
the per-decision content but that's brittle. Instead we just collect
every JSON payload and let the test match by content.
"""
payloads: list[dict[str, Any]] = []
for msg in getattr(final_state, "values", {}).get("messages", []) or []:
content = getattr(msg, "content", None)
if not isinstance(content, str):
continue
try:
parsed = json.loads(content)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict) and "decisions" in parsed:
payloads.append(parsed)
by_marker: dict[str, dict[str, Any]] = {}
for marker in marker_keys:
for p in payloads:
text = json.dumps(p, sort_keys=True)
if marker in text:
by_marker[marker] = p
break
return by_marker
@pytest.mark.asyncio
async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_intact():
"""Mixed approve/reject/edit decisions across two parallel subagents.
Setup:
- Sub-A pauses with a 1-action bundle (``act_0``).
- Sub-B pauses with a 2-action bundle (``act_0``, ``act_1``).
Setup chosen so the slicer's cursor arithmetic is sensitive to off-by-one
refactors:
- Sub-A pauses with a 2-action bundle (``act_0``, ``act_1``).
- Sub-B pauses with a 3-action bundle (``act_0``, ``act_1``, ``act_2``).
- Parent ends up with 2 pending interrupts (one per subagent).
The frontend submits a flat ``[A_approve, B_reject, B_edit]`` list; our
slicer must split into ``{tcid_A: [A_approve], tcid_B: [B_reject, B_edit]}``
and the bridge must forward each subagent's slice intact — including the
``message`` on the reject and the ``edited_action.args`` on the edit.
With both counts ``> 1``, a buggy ``cursor += 1`` (instead of
``cursor += action_count``) produces a different B-slice from the correct
one, so the assertions catch it. A ``(1, 2)`` configuration would not
because ``+= 1`` and ``+= count`` are arithmetically identical when
``count == 1``.
The frontend submits a flat
``[A_approve, A_reject, B_edit, B_approve, B_reject]`` list with distinct
``message`` and ``edited_action`` payloads; our slicer must split into
``{tcid_A: [A_approve, A_reject], tcid_B: [B_edit, B_approve, B_reject]}``
and the bridge must forward each subagent's slice intact — including all
metadata, in original order.
"""
checkpointer = InMemorySaver()
sub_a = _build_capturing_subagent(checkpointer, action_count=1)
sub_b = _build_capturing_subagent(checkpointer, action_count=2)
sub_a = _build_capturing_subagent(checkpointer, action_count=2)
sub_b = _build_capturing_subagent(checkpointer, action_count=3)
task_tool = build_task_tool_with_parent_config(
[
@ -211,23 +192,25 @@ async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_
pending = collect_pending_tool_calls(paused_state)
pending_by_tcid = dict(pending)
assert pending_by_tcid == {"tcid-A": 1, "tcid-B": 2}, (
assert pending_by_tcid == {"tcid-A": 2, "tcid-B": 3}, (
f"REGRESSION: action-count accounting wrong; got {pending_by_tcid!r}"
)
a_approve = {"type": "approve"}
b_reject = {"type": "reject", "message": "no thanks for B[0]"}
a_reject = {"type": "reject", "message": "A[1] looks redundant"}
b_edit = {
"type": "edit",
"edited_action": {"name": "act_1", "args": {"i": 1, "edited": True}},
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
}
flat_decisions = [a_approve, b_reject, b_edit]
b_approve = {"type": "approve"}
b_reject = {"type": "reject", "message": "B[2] needs more context"}
flat_decisions = [a_approve, a_reject, b_edit, b_approve, b_reject]
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
assert by_tool_call_id == {
"tcid-A": {"decisions": [a_approve]},
"tcid-B": {"decisions": [b_reject, b_edit]},
"tcid-A": {"decisions": [a_approve, a_reject]},
"tcid-B": {"decisions": [b_edit, b_approve, b_reject]},
}, f"REGRESSION: slicer mis-routed decisions: {by_tool_call_id!r}"
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
@ -240,17 +223,7 @@ async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_
f"REGRESSION: leftover pending interrupts after resume: {final_state.interrupts!r}"
)
captured = _captured_payloads_by_content_marker(
final_state,
marker_keys=["no thanks for B[0]", '"i": 1, "edited": true'],
)
payload_b = captured.get("no thanks for B[0]")
assert payload_b is not None, "could not locate sub-B's captured payload"
assert payload_b == {
"decisions": [b_reject, b_edit]
}, f"REGRESSION: sub-B received wrong payload: {payload_b!r}"
payloads = []
payloads: list[dict] = []
for msg in final_state.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
@ -258,11 +231,16 @@ async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_
payloads.append(json.loads(content))
except json.JSONDecodeError:
pass
payload_a = next(
(p for p in payloads if p == {"decisions": [a_approve]}), None
expected_a = {"decisions": [a_approve, a_reject]}
expected_b = {"decisions": [b_edit, b_approve, b_reject]}
assert expected_a in payloads, (
f"REGRESSION: sub-A did not receive its 2-decision slice in original order; "
f"payloads seen: {payloads!r}"
)
assert payload_a is not None, (
f"REGRESSION: sub-A did not receive its single approve in isolation; "
assert expected_b in payloads, (
f"REGRESSION: sub-B did not receive its 3-decision slice in original order; "
f"payloads seen: {payloads!r}"
)