multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer

This commit is contained in:
CREDO23 2026-05-13 19:56:51 +02:00
parent 246dae40a8
commit fc2c5b6445
5 changed files with 513 additions and 30 deletions

View file

@ -21,7 +21,17 @@ _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
Each parallel subagent invocation lands in its own checkpoint slot keyed
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
The same call across the resume cycle keeps reading from the same snapshot
(``tool_call_id`` is stable per LLM-emitted call).
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
subgraph path and raises ``ValueError("Subgraph X not found")``.
"""
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
current_limit = merged.get("recursion_limit")
try:
@ -30,43 +40,68 @@ def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
current_int = 0
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
parent_thread_id = configurable.get("thread_id")
per_call_suffix = f"task:{runtime.tool_call_id}"
configurable["thread_id"] = (
f"{parent_thread_id}::{per_call_suffix}"
if parent_thread_id
else per_call_suffix
)
merged["configurable"] = configurable
return merged
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
"""Pop the resume payload; siblings share ``configurable`` by reference."""
"""Pop the resume payload for *this* call's ``tool_call_id``.
The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
so parallel sibling subagents (each with their own ``tool_call_id``) read
only their own decision and never race on a shared scalar.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return None
return configurable.pop("surfsense_resume_value", None)
by_tcid = configurable.get("surfsense_resume_value")
if not isinstance(by_tcid, dict):
return None
payload = by_tcid.pop(runtime.tool_call_id, None)
if not by_tcid:
configurable.pop("surfsense_resume_value", None)
return payload
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
"""True iff a resume payload is queued on this runtime (non-destructive)."""
"""True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive)."""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return False
return "surfsense_resume_value" in configurable
by_tcid = configurable.get("surfsense_resume_value")
if not isinstance(by_tcid, dict):
return False
return runtime.tool_call_id in by_tcid
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
``stream_resume_chat`` wakes the main agent with
``Command(resume={"decisions": [...]})`` so the propagated
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
parent task's ``null_resume`` pending write, which only gets consumed
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
``subagent_scratchpad parent_scratchpad.get_null_resume`` and picks up
the parent's still-live decisions — mismatching against a different number
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
propagated parent-level interrupt can return. langgraph stores that
payload as the parent task's ``null_resume`` pending write. The ``task``
tool then forwards this turn's slice into the subagent via its own
``Command(resume=...)``. While the subagent is mid-execution, any *new*
``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
approve/reject) walks ``subagent_scratchpad parent_scratchpad.get_null_resume``
and picks up the parent's still-live decisions — mismatching against a
different number of hanging tool calls and crashing
``HumanInTheLoopMiddleware``.
Draining the write here closes that cross-graph leak so subagent
interrupts pause cleanly and re-propagate as a fresh approval card.
interrupts pause cleanly and bubble back up as a fresh approval card.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None

View file

@ -0,0 +1,137 @@
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
The frontend submits decisions in the same order the SSE stream emitted
approval cards. When multiple parallel subagents are paused, the backend uses
this module to:
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
inside ``task_tool``'s catch-and-stamp block when a subagent's
``GraphInterrupt`` bubbles up through ``[a]task``.
2. Slice the flat ``decisions`` list against that ordered pending list to
produce the dict shape expected by ``consume_surfsense_resume``.
Both helpers are pure: callers own the state and the input decisions; we
return new structures and never mutate.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
from typing import Any
logger = logging.getLogger(__name__)
def slice_decisions_by_tool_call(
decisions: list[dict[str, Any]],
pending: Iterable[tuple[str, int]],
) -> dict[str, dict[str, Any]]:
"""Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.
Args:
decisions: Flat list of decisions in the order the SSE stream rendered
them.
pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
order. The slicer consumes ``decisions`` left-to-right.
Returns:
Per-``tool_call_id`` payload dict ready to be written to
``configurable["surfsense_resume_value"]``.
Raises:
ValueError: When the total expected action count differs from the
number of decisions provided. We fail loud rather than silently
dropping or padding so a frontend/backend contract drift surfaces
immediately.
"""
pending_list = list(pending)
expected = sum(count for _, count in pending_list)
if expected != len(decisions):
raise ValueError(
f"Decision count mismatch: pending tool calls expect "
f"{expected} actions but received {len(decisions)} decisions."
)
routed: dict[str, dict[str, Any]] = {}
cursor = 0
for tool_call_id, action_count in pending_list:
routed[tool_call_id] = {
"decisions": decisions[cursor : cursor + action_count]
}
cursor += action_count
return routed
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
"""Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
Reads ``state.interrupts`` (the bundle langgraph aggregated from each
paused subagent's propagated interrupt). Each interrupt value carries the
``tool_call_id`` that the parent's ``task`` tool was processing — see
``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
``except GraphInterrupt`` chokepoint.
Order is preserved from ``state.interrupts``, which is the order the SSE
stream emitted approval cards. The frontend submits decisions in that
same order, so the slicer can consume them left-to-right.
Interrupts without a ``tool_call_id`` are skipped they were not
produced by our task-routing layer (e.g. parent-side HITL middleware on
a different tool); ``stream_resume_chat`` is not responsible for routing
those.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` attribute).
Returns:
Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
scalar-style ``interrupt("...")`` values that were wrapped as
``{"value": ..., "tool_call_id": ...}``.
Raises:
ValueError: When an interrupt value carries a ``tool_call_id`` but
the action count cannot be determined (contract bug every
propagated value should be either a HITL bundle or a wrapped
scalar).
"""
pending: list[tuple[str, int]] = []
for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
logger.warning(
"[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
idx,
type(value).__name__,
)
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
# Should not happen post-stamping; flag loudly if a regression
# ever lets an unstamped value reach the parent state.
logger.warning(
"[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
idx,
sorted(value.keys()),
)
continue
action_requests = value.get("action_requests")
if isinstance(action_requests, list):
pending.append((tool_call_id, len(action_requests)))
continue
if "value" in value:
pending.append((tool_call_id, 1))
continue
raise ValueError(
f"Interrupt for tool_call_id={tool_call_id!r} has no "
"``action_requests`` list and is not a wrapped scalar value; "
"cannot determine action count for resume routing."
)
return pending

View file

@ -0,0 +1,154 @@
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
SSE stream emitted approval cards. When multiple parallel subagents are paused,
the backend slices that flat list into per-``tool_call_id`` payloads so each
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` which is
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
through ``[a]task`` to build the ordered ``pending`` list the slicer needs.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
class TestSliceDecisionsByToolCall:
def test_splits_flat_decisions_across_two_pending_tool_calls(self):
decisions = [
{"type": "approve"},
{"type": "edit", "edited_action": {"name": "edited-b1"}},
{"type": "reject"},
{"type": "approve"},
{"type": "approve"},
]
pending = [
("tcid-A", 3),
("tcid-B", 2),
]
routed = slice_decisions_by_tool_call(decisions, pending)
assert routed == {
"tcid-A": {"decisions": decisions[0:3]},
"tcid-B": {"decisions": decisions[3:5]},
}
def test_raises_when_decision_count_less_than_total_actions(self):
decisions = [{"type": "approve"}, {"type": "approve"}]
pending = [("tcid-A", 3), ("tcid-B", 2)]
with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
slice_decisions_by_tool_call(decisions, pending)
def test_raises_when_decision_count_greater_than_total_actions(self):
decisions = [{"type": "approve"}] * 6
pending = [("tcid-A", 3), ("tcid-B", 2)]
with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
slice_decisions_by_tool_call(decisions, pending)
def test_handles_single_pending_tool_call(self):
decisions = [{"type": "approve"}, {"type": "reject"}]
pending = [("tcid-only", 2)]
routed = slice_decisions_by_tool_call(decisions, pending)
assert routed == {"tcid-only": {"decisions": decisions}}
def test_returns_empty_dict_for_no_pending(self):
routed = slice_decisions_by_tool_call([], [])
assert routed == {}
def _interrupt_with(tool_call_id: str, action_count: int):
return SimpleNamespace(
id=f"i-{tool_call_id}",
value={
"action_requests": [{"name": "n", "args": {}}] * action_count,
"review_configs": [{}] * action_count,
"tool_call_id": tool_call_id,
},
)
class TestCollectPendingToolCalls:
def test_single_pending_returns_one_pair(self):
state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))
assert collect_pending_tool_calls(state) == [("tcid-only", 3)]
def test_multiple_pending_preserves_state_order(self):
"""Order must match what the SSE stream emitted (= state.interrupts order)."""
state = SimpleNamespace(
interrupts=(
_interrupt_with("tcid-A", 2),
_interrupt_with("tcid-B", 3),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]
def test_empty_when_no_interrupts(self):
state = SimpleNamespace(interrupts=())
assert collect_pending_tool_calls(state) == []
def test_skips_interrupts_without_tool_call_id(self):
"""Defensive: interrupts not produced by our propagation layer are ignored.
``stream_resume_chat`` only owns the ``task``-routing slice; non-task
interrupts (e.g. parent-side HITL middleware on a different tool) are
not the slicer's responsibility.
"""
state = SimpleNamespace(
interrupts=(
_interrupt_with("tcid-A", 2),
SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
_interrupt_with("tcid-B", 1),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]
def test_handles_scalar_value_interrupt(self):
"""Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.
These have no ``action_requests`` count them as a single action so
the frontend submits exactly one decision per such interrupt.
"""
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"value": "approve?", "tool_call_id": "tcid-A"},
),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 1)]
def test_raises_when_interrupt_value_missing_action_count_keys(self):
"""An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"tool_call_id": "tcid-A", "weird_shape": True},
),
)
)
with pytest.raises(ValueError, match="action_requests"):
collect_pending_tool_calls(state)

View file

@ -1,4 +1,4 @@
"""Resume side-channel must be read exactly once per turn."""
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
from __future__ import annotations
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
)
def _runtime_with_config(config: dict) -> ToolRuntime:
def _runtime_with_config(
config: dict, *, tool_call_id: str = "tcid-test"
) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config,
stream_writer=None,
tool_call_id="tcid-test",
tool_call_id=tool_call_id,
store=None,
)
class TestConsumeSurfsenseResume:
def test_pops_value_on_first_call(self):
def test_pops_only_entry_matching_runtime_tool_call_id(self):
configurable = {
"surfsense_resume_value": {
"tcid-A": {"decisions": ["approve"]},
"tcid-B": {"decisions": ["reject"]},
}
}
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
{"configurable": configurable}, tool_call_id="tcid-A"
)
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
def test_second_call_returns_none(self):
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
runtime = _runtime_with_config({"configurable": configurable})
def test_popping_one_entry_leaves_siblings_untouched(self):
configurable = {
"surfsense_resume_value": {
"tcid-A": {"decisions": ["approve"]},
"tcid-B": {"decisions": ["reject"]},
}
}
runtime_a = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
consume_surfsense_resume(runtime)
consume_surfsense_resume(runtime_a)
assert configurable["surfsense_resume_value"] == {
"tcid-B": {"decisions": ["reject"]}
}
def test_returns_none_when_no_entry_for_this_tool_call(self):
runtime = _runtime_with_config(
{
"configurable": {
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
}
},
tool_call_id="tcid-A",
)
assert consume_surfsense_resume(runtime) is None
assert "surfsense_resume_value" not in configurable
def test_returns_none_when_no_payload_queued(self):
runtime = _runtime_with_config({"configurable": {}})
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
assert consume_surfsense_resume(runtime) is None
def test_drops_empty_dict_after_last_entry_consumed(self):
configurable = {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
runtime = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
consume_surfsense_resume(runtime)
assert "surfsense_resume_value" not in configurable
class TestHasSurfsenseResume:
def test_true_when_payload_queued(self):
def test_true_when_entry_for_this_tool_call_present(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": "approve"}}
{
"configurable": {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
},
tool_call_id="tcid-A",
)
assert has_surfsense_resume(runtime) is True
def test_false_when_entry_for_other_tool_call_only(self):
runtime = _runtime_with_config(
{
"configurable": {
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
}
},
tool_call_id="tcid-A",
)
assert has_surfsense_resume(runtime) is False
def test_does_not_consume_payload(self):
configurable = {"surfsense_resume_value": "approve"}
runtime = _runtime_with_config({"configurable": configurable})
configurable = {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
runtime = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
has_surfsense_resume(runtime)
assert configurable == {"surfsense_resume_value": "approve"}
assert configurable["surfsense_resume_value"] == {
"tcid-A": {"decisions": ["approve"]}
}
def test_false_when_payload_absent(self):
runtime = _runtime_with_config({"configurable": {}})

View file

@ -0,0 +1,94 @@
"""Per-call ``thread_id`` derivation for nested subagent invocations.
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
checkpoint slots so their nested pregel runs do not stomp on each other or on
the parent's checkpoint state. The slot key is derived from the runtime's
``tool_call_id`` so the same call across the resume cycle keeps reading from
the same snapshot.
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
is the primary checkpoint key and is free-form, so it's the right primitive.
"""
from __future__ import annotations
from langchain.tools import ToolRuntime
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
subagent_invoke_config,
)
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config or {},
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
class TestSubagentInvokeThreadId:
def test_sets_per_call_thread_id_under_parent(self):
runtime = _runtime(
tool_call_id="tcid-A",
config={"configurable": {"thread_id": "t1"}},
)
sub_config = subagent_invoke_config(runtime)
assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"
def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
"""A subagent that itself spawns a subagent must keep nesting cleanly."""
runtime = _runtime(
tool_call_id="tcid-inner",
config={
"configurable": {
"thread_id": "t1::task:tcid-outer",
}
},
)
sub_config = subagent_invoke_config(runtime)
assert (
sub_config["configurable"]["thread_id"]
== "t1::task:tcid-outer::task:tcid-inner"
)
def test_different_tool_call_ids_produce_different_thread_ids(self):
config = {"configurable": {"thread_id": "t1"}}
rt_a = _runtime(tool_call_id="tcid-A", config=config)
rt_b = _runtime(tool_call_id="tcid-B", config=config)
tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]
assert tid_a != tid_b
def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
"""Resume bridge needs to find the snapshot it primed earlier."""
config = {"configurable": {"thread_id": "t1"}}
rt_first = _runtime(tool_call_id="tcid-A", config=config)
rt_second = _runtime(tool_call_id="tcid-A", config=config)
tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]
assert tid_first == tid_second
def test_does_not_mutate_caller_config(self):
"""Repeated calls must not accumulate suffixes onto the parent's config."""
original_thread_id = "t1"
config = {"configurable": {"thread_id": original_thread_id}}
runtime = _runtime(tool_call_id="tcid-A", config=config)
subagent_invoke_config(runtime)
subagent_invoke_config(runtime)
assert config["configurable"]["thread_id"] == original_thread_id