multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer

This commit is contained in:
CREDO23 2026-05-13 19:56:51 +02:00
parent 246dae40a8
commit fc2c5b6445
5 changed files with 513 additions and 30 deletions

View file

@ -0,0 +1,154 @@
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
SSE stream emitted approval cards. When multiple parallel subagents are paused,
the backend slices that flat list into per-``tool_call_id`` payloads so each
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` which is
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
through ``[a]task`` to build the ordered ``pending`` list the slicer needs.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
class TestSliceDecisionsByToolCall:
def test_splits_flat_decisions_across_two_pending_tool_calls(self):
decisions = [
{"type": "approve"},
{"type": "edit", "edited_action": {"name": "edited-b1"}},
{"type": "reject"},
{"type": "approve"},
{"type": "approve"},
]
pending = [
("tcid-A", 3),
("tcid-B", 2),
]
routed = slice_decisions_by_tool_call(decisions, pending)
assert routed == {
"tcid-A": {"decisions": decisions[0:3]},
"tcid-B": {"decisions": decisions[3:5]},
}
def test_raises_when_decision_count_less_than_total_actions(self):
decisions = [{"type": "approve"}, {"type": "approve"}]
pending = [("tcid-A", 3), ("tcid-B", 2)]
with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
slice_decisions_by_tool_call(decisions, pending)
def test_raises_when_decision_count_greater_than_total_actions(self):
decisions = [{"type": "approve"}] * 6
pending = [("tcid-A", 3), ("tcid-B", 2)]
with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
slice_decisions_by_tool_call(decisions, pending)
def test_handles_single_pending_tool_call(self):
decisions = [{"type": "approve"}, {"type": "reject"}]
pending = [("tcid-only", 2)]
routed = slice_decisions_by_tool_call(decisions, pending)
assert routed == {"tcid-only": {"decisions": decisions}}
def test_returns_empty_dict_for_no_pending(self):
routed = slice_decisions_by_tool_call([], [])
assert routed == {}
def _interrupt_with(tool_call_id: str, action_count: int):
return SimpleNamespace(
id=f"i-{tool_call_id}",
value={
"action_requests": [{"name": "n", "args": {}}] * action_count,
"review_configs": [{}] * action_count,
"tool_call_id": tool_call_id,
},
)
class TestCollectPendingToolCalls:
def test_single_pending_returns_one_pair(self):
state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))
assert collect_pending_tool_calls(state) == [("tcid-only", 3)]
def test_multiple_pending_preserves_state_order(self):
"""Order must match what the SSE stream emitted (= state.interrupts order)."""
state = SimpleNamespace(
interrupts=(
_interrupt_with("tcid-A", 2),
_interrupt_with("tcid-B", 3),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]
def test_empty_when_no_interrupts(self):
state = SimpleNamespace(interrupts=())
assert collect_pending_tool_calls(state) == []
def test_skips_interrupts_without_tool_call_id(self):
"""Defensive: interrupts not produced by our propagation layer are ignored.
``stream_resume_chat`` only owns the ``task``-routing slice; non-task
interrupts (e.g. parent-side HITL middleware on a different tool) are
not the slicer's responsibility.
"""
state = SimpleNamespace(
interrupts=(
_interrupt_with("tcid-A", 2),
SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
_interrupt_with("tcid-B", 1),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]
def test_handles_scalar_value_interrupt(self):
"""Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.
These have no ``action_requests`` count them as a single action so
the frontend submits exactly one decision per such interrupt.
"""
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"value": "approve?", "tool_call_id": "tcid-A"},
),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 1)]
def test_raises_when_interrupt_value_missing_action_count_keys(self):
"""An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"tool_call_id": "tcid-A", "weird_shape": True},
),
)
)
with pytest.raises(ValueError, match="action_requests"):
collect_pending_tool_calls(state)

View file

@ -1,4 +1,4 @@
"""Resume side-channel must be read exactly once per turn."""
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
from __future__ import annotations
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
)
def _runtime_with_config(config: dict) -> ToolRuntime:
def _runtime_with_config(
config: dict, *, tool_call_id: str = "tcid-test"
) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config,
stream_writer=None,
tool_call_id="tcid-test",
tool_call_id=tool_call_id,
store=None,
)
class TestConsumeSurfsenseResume:
def test_pops_value_on_first_call(self):
def test_pops_only_entry_matching_runtime_tool_call_id(self):
configurable = {
"surfsense_resume_value": {
"tcid-A": {"decisions": ["approve"]},
"tcid-B": {"decisions": ["reject"]},
}
}
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
{"configurable": configurable}, tool_call_id="tcid-A"
)
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
def test_second_call_returns_none(self):
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
runtime = _runtime_with_config({"configurable": configurable})
def test_popping_one_entry_leaves_siblings_untouched(self):
configurable = {
"surfsense_resume_value": {
"tcid-A": {"decisions": ["approve"]},
"tcid-B": {"decisions": ["reject"]},
}
}
runtime_a = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
consume_surfsense_resume(runtime)
consume_surfsense_resume(runtime_a)
assert configurable["surfsense_resume_value"] == {
"tcid-B": {"decisions": ["reject"]}
}
def test_returns_none_when_no_entry_for_this_tool_call(self):
runtime = _runtime_with_config(
{
"configurable": {
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
}
},
tool_call_id="tcid-A",
)
assert consume_surfsense_resume(runtime) is None
assert "surfsense_resume_value" not in configurable
def test_returns_none_when_no_payload_queued(self):
runtime = _runtime_with_config({"configurable": {}})
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
assert consume_surfsense_resume(runtime) is None
def test_drops_empty_dict_after_last_entry_consumed(self):
configurable = {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
runtime = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
consume_surfsense_resume(runtime)
assert "surfsense_resume_value" not in configurable
class TestHasSurfsenseResume:
def test_true_when_payload_queued(self):
def test_true_when_entry_for_this_tool_call_present(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": "approve"}}
{
"configurable": {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
},
tool_call_id="tcid-A",
)
assert has_surfsense_resume(runtime) is True
def test_false_when_entry_for_other_tool_call_only(self):
runtime = _runtime_with_config(
{
"configurable": {
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
}
},
tool_call_id="tcid-A",
)
assert has_surfsense_resume(runtime) is False
def test_does_not_consume_payload(self):
configurable = {"surfsense_resume_value": "approve"}
runtime = _runtime_with_config({"configurable": configurable})
configurable = {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
runtime = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
has_surfsense_resume(runtime)
assert configurable == {"surfsense_resume_value": "approve"}
assert configurable["surfsense_resume_value"] == {
"tcid-A": {"decisions": ["approve"]}
}
def test_false_when_payload_absent(self):
runtime = _runtime_with_config({"configurable": {}})

View file

@ -0,0 +1,94 @@
"""Per-call ``thread_id`` derivation for nested subagent invocations.
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
checkpoint slots so their nested pregel runs do not stomp on each other or on
the parent's checkpoint state. The slot key is derived from the runtime's
``tool_call_id`` so the same call across the resume cycle keeps reading from
the same snapshot.
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
is the primary checkpoint key and is free-form, so it's the right primitive.
"""
from __future__ import annotations
from langchain.tools import ToolRuntime
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
subagent_invoke_config,
)
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config or {},
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
class TestSubagentInvokeThreadId:
def test_sets_per_call_thread_id_under_parent(self):
runtime = _runtime(
tool_call_id="tcid-A",
config={"configurable": {"thread_id": "t1"}},
)
sub_config = subagent_invoke_config(runtime)
assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"
def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
"""A subagent that itself spawns a subagent must keep nesting cleanly."""
runtime = _runtime(
tool_call_id="tcid-inner",
config={
"configurable": {
"thread_id": "t1::task:tcid-outer",
}
},
)
sub_config = subagent_invoke_config(runtime)
assert (
sub_config["configurable"]["thread_id"]
== "t1::task:tcid-outer::task:tcid-inner"
)
def test_different_tool_call_ids_produce_different_thread_ids(self):
config = {"configurable": {"thread_id": "t1"}}
rt_a = _runtime(tool_call_id="tcid-A", config=config)
rt_b = _runtime(tool_call_id="tcid-B", config=config)
tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]
assert tid_a != tid_b
def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
"""Resume bridge needs to find the snapshot it primed earlier."""
config = {"configurable": {"thread_id": "t1"}}
rt_first = _runtime(tool_call_id="tcid-A", config=config)
rt_second = _runtime(tool_call_id="tcid-A", config=config)
tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]
assert tid_first == tid_second
def test_does_not_mutate_caller_config(self):
"""Repeated calls must not accumulate suffixes onto the parent's config."""
original_thread_id = "t1"
config = {"configurable": {"thread_id": original_thread_id}}
runtime = _runtime(tool_call_id="tcid-A", config=config)
subagent_invoke_config(runtime)
subagent_invoke_config(runtime)
assert config["configurable"]["thread_id"] == original_thread_id