mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
multi_agent_chat/middleware: tighten heterogeneous slice arithmetic to (2,3) bundles
This commit is contained in:
parent
668b89927b
commit
d69d2cc1fc
1 changed files with 44 additions and 66 deletions
|
|
@ -15,15 +15,19 @@ This module pins:
|
||||||
only its slice in the original order.
|
only its slice in the original order.
|
||||||
2. **Per-decision metadata pass-through** — ``message`` and ``edited_action``
|
2. **Per-decision metadata pass-through** — ``message`` and ``edited_action``
|
||||||
payloads must reach the subagent intact (not just the ``type`` discriminator).
|
payloads must reach the subagent intact (not just the ``type`` discriminator).
|
||||||
3. **Mixed bundle sizes** — two paused subagents with different
|
3. **Off-by-one-sensitive bundle sizes** — both paused subagents have action
|
||||||
``len(action_requests)`` correctly account for the slice boundary
|
counts ``> 1`` (``2`` and ``3``). With those sizes a buggy
|
||||||
(sub-A gets 1 decision, sub-B gets 2 from a flat list of 3).
|
``cursor += 1`` slicer (instead of ``cursor += action_count``) produces a
|
||||||
|
different B-slice from the correct one, so this test catches the most
|
||||||
|
common refactor mistake. A ``(1, 2)`` configuration would silently pass
|
||||||
|
such a bug because ``+= 1`` and ``+= count`` are arithmetically identical
|
||||||
|
when ``count == 1``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Annotated, Any
|
from typing import Annotated
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain.tools import ToolRuntime
|
from langchain.tools import ToolRuntime
|
||||||
|
|
@ -131,56 +135,33 @@ def _parent_dispatching_two_subagents(
|
||||||
return g.compile(checkpointer=checkpointer)
|
return g.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|
||||||
def _captured_payloads_by_content_marker(
|
|
||||||
final_state, *, marker_keys: list[str]
|
|
||||||
) -> dict[str, dict[str, Any]]:
|
|
||||||
"""Extract per-subagent resume payloads from the parent's final messages.
|
|
||||||
|
|
||||||
Each subagent emitted ``AIMessage(json.dumps(payload))``. We tag them by
|
|
||||||
looking for a marker in the inner action-request name (``act_{i}``) plus
|
|
||||||
the per-decision content — but that's brittle. Instead we just collect
|
|
||||||
every JSON payload and let the test match by content.
|
|
||||||
"""
|
|
||||||
payloads: list[dict[str, Any]] = []
|
|
||||||
for msg in getattr(final_state, "values", {}).get("messages", []) or []:
|
|
||||||
content = getattr(msg, "content", None)
|
|
||||||
if not isinstance(content, str):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
parsed = json.loads(content)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
if isinstance(parsed, dict) and "decisions" in parsed:
|
|
||||||
payloads.append(parsed)
|
|
||||||
|
|
||||||
by_marker: dict[str, dict[str, Any]] = {}
|
|
||||||
for marker in marker_keys:
|
|
||||||
for p in payloads:
|
|
||||||
text = json.dumps(p, sort_keys=True)
|
|
||||||
if marker in text:
|
|
||||||
by_marker[marker] = p
|
|
||||||
break
|
|
||||||
return by_marker
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_intact():
|
async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_intact():
|
||||||
"""Mixed approve/reject/edit decisions across two parallel subagents.
|
"""Mixed approve/reject/edit decisions across two parallel subagents.
|
||||||
|
|
||||||
Setup:
|
Setup chosen so the slicer's cursor arithmetic is sensitive to off-by-one
|
||||||
- Sub-A pauses with a 1-action bundle (``act_0``).
|
refactors:
|
||||||
- Sub-B pauses with a 2-action bundle (``act_0``, ``act_1``).
|
- Sub-A pauses with a 2-action bundle (``act_0``, ``act_1``).
|
||||||
|
- Sub-B pauses with a 3-action bundle (``act_0``, ``act_1``, ``act_2``).
|
||||||
- Parent ends up with 2 pending interrupts (one per subagent).
|
- Parent ends up with 2 pending interrupts (one per subagent).
|
||||||
|
|
||||||
The frontend submits a flat ``[A_approve, B_reject, B_edit]`` list; our
|
With both counts ``> 1``, a buggy ``cursor += 1`` (instead of
|
||||||
slicer must split into ``{tcid_A: [A_approve], tcid_B: [B_reject, B_edit]}``
|
``cursor += action_count``) produces a different B-slice from the correct
|
||||||
and the bridge must forward each subagent's slice intact — including the
|
one, so the assertions catch it. A ``(1, 2)`` configuration would not
|
||||||
``message`` on the reject and the ``edited_action.args`` on the edit.
|
because ``+= 1`` and ``+= count`` are arithmetically identical when
|
||||||
|
``count == 1``.
|
||||||
|
|
||||||
|
The frontend submits a flat
|
||||||
|
``[A_approve, A_reject, B_edit, B_approve, B_reject]`` list with distinct
|
||||||
|
``message`` and ``edited_action`` payloads; our slicer must split into
|
||||||
|
``{tcid_A: [A_approve, A_reject], tcid_B: [B_edit, B_approve, B_reject]}``
|
||||||
|
and the bridge must forward each subagent's slice intact — including all
|
||||||
|
metadata, in original order.
|
||||||
"""
|
"""
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
sub_a = _build_capturing_subagent(checkpointer, action_count=1)
|
sub_a = _build_capturing_subagent(checkpointer, action_count=2)
|
||||||
sub_b = _build_capturing_subagent(checkpointer, action_count=2)
|
sub_b = _build_capturing_subagent(checkpointer, action_count=3)
|
||||||
|
|
||||||
task_tool = build_task_tool_with_parent_config(
|
task_tool = build_task_tool_with_parent_config(
|
||||||
[
|
[
|
||||||
|
|
@ -211,23 +192,25 @@ async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_
|
||||||
|
|
||||||
pending = collect_pending_tool_calls(paused_state)
|
pending = collect_pending_tool_calls(paused_state)
|
||||||
pending_by_tcid = dict(pending)
|
pending_by_tcid = dict(pending)
|
||||||
assert pending_by_tcid == {"tcid-A": 1, "tcid-B": 2}, (
|
assert pending_by_tcid == {"tcid-A": 2, "tcid-B": 3}, (
|
||||||
f"REGRESSION: action-count accounting wrong; got {pending_by_tcid!r}"
|
f"REGRESSION: action-count accounting wrong; got {pending_by_tcid!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
a_approve = {"type": "approve"}
|
a_approve = {"type": "approve"}
|
||||||
b_reject = {"type": "reject", "message": "no thanks for B[0]"}
|
a_reject = {"type": "reject", "message": "A[1] looks redundant"}
|
||||||
b_edit = {
|
b_edit = {
|
||||||
"type": "edit",
|
"type": "edit",
|
||||||
"edited_action": {"name": "act_1", "args": {"i": 1, "edited": True}},
|
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
|
||||||
}
|
}
|
||||||
flat_decisions = [a_approve, b_reject, b_edit]
|
b_approve = {"type": "approve"}
|
||||||
|
b_reject = {"type": "reject", "message": "B[2] needs more context"}
|
||||||
|
flat_decisions = [a_approve, a_reject, b_edit, b_approve, b_reject]
|
||||||
|
|
||||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||||
|
|
||||||
assert by_tool_call_id == {
|
assert by_tool_call_id == {
|
||||||
"tcid-A": {"decisions": [a_approve]},
|
"tcid-A": {"decisions": [a_approve, a_reject]},
|
||||||
"tcid-B": {"decisions": [b_reject, b_edit]},
|
"tcid-B": {"decisions": [b_edit, b_approve, b_reject]},
|
||||||
}, f"REGRESSION: slicer mis-routed decisions: {by_tool_call_id!r}"
|
}, f"REGRESSION: slicer mis-routed decisions: {by_tool_call_id!r}"
|
||||||
|
|
||||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||||
|
|
@ -240,17 +223,7 @@ async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_
|
||||||
f"REGRESSION: leftover pending interrupts after resume: {final_state.interrupts!r}"
|
f"REGRESSION: leftover pending interrupts after resume: {final_state.interrupts!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
captured = _captured_payloads_by_content_marker(
|
payloads: list[dict] = []
|
||||||
final_state,
|
|
||||||
marker_keys=["no thanks for B[0]", '"i": 1, "edited": true'],
|
|
||||||
)
|
|
||||||
payload_b = captured.get("no thanks for B[0]")
|
|
||||||
assert payload_b is not None, "could not locate sub-B's captured payload"
|
|
||||||
assert payload_b == {
|
|
||||||
"decisions": [b_reject, b_edit]
|
|
||||||
}, f"REGRESSION: sub-B received wrong payload: {payload_b!r}"
|
|
||||||
|
|
||||||
payloads = []
|
|
||||||
for msg in final_state.values.get("messages", []) or []:
|
for msg in final_state.values.get("messages", []) or []:
|
||||||
content = getattr(msg, "content", None)
|
content = getattr(msg, "content", None)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
|
|
@ -258,11 +231,16 @@ async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_
|
||||||
payloads.append(json.loads(content))
|
payloads.append(json.loads(content))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
payload_a = next(
|
|
||||||
(p for p in payloads if p == {"decisions": [a_approve]}), None
|
expected_a = {"decisions": [a_approve, a_reject]}
|
||||||
|
expected_b = {"decisions": [b_edit, b_approve, b_reject]}
|
||||||
|
|
||||||
|
assert expected_a in payloads, (
|
||||||
|
f"REGRESSION: sub-A did not receive its 2-decision slice in original order; "
|
||||||
|
f"payloads seen: {payloads!r}"
|
||||||
)
|
)
|
||||||
assert payload_a is not None, (
|
assert expected_b in payloads, (
|
||||||
f"REGRESSION: sub-A did not receive its single approve in isolation; "
|
f"REGRESSION: sub-B did not receive its 3-decision slice in original order; "
|
||||||
f"payloads seen: {payloads!r}"
|
f"payloads seen: {payloads!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue