multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer

This commit is contained in:
CREDO23 2026-05-13 19:56:51 +02:00
parent 246dae40a8
commit fc2c5b6445
5 changed files with 513 additions and 30 deletions

View file

@ -21,7 +21,17 @@ _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
Each parallel subagent invocation lands in its own checkpoint slot keyed
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
The same call across the resume cycle keeps reading from the same snapshot
(``tool_call_id`` is stable per LLM-emitted call).
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
subgraph path and raises ``ValueError("Subgraph X not found")``.
"""
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
current_limit = merged.get("recursion_limit")
try:
@ -30,43 +40,68 @@ def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
current_int = 0
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
parent_thread_id = configurable.get("thread_id")
per_call_suffix = f"task:{runtime.tool_call_id}"
configurable["thread_id"] = (
f"{parent_thread_id}::{per_call_suffix}"
if parent_thread_id
else per_call_suffix
)
merged["configurable"] = configurable
return merged
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
"""Pop the resume payload; siblings share ``configurable`` by reference."""
"""Pop the resume payload for *this* call's ``tool_call_id``.
The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
so parallel sibling subagents (each with their own ``tool_call_id``) read
only their own decision and never race on a shared scalar.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return None
return configurable.pop("surfsense_resume_value", None)
by_tcid = configurable.get("surfsense_resume_value")
if not isinstance(by_tcid, dict):
return None
payload = by_tcid.pop(runtime.tool_call_id, None)
if not by_tcid:
configurable.pop("surfsense_resume_value", None)
return payload
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
"""True iff a resume payload is queued on this runtime (non-destructive)."""
"""True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive)."""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return False
return "surfsense_resume_value" in configurable
by_tcid = configurable.get("surfsense_resume_value")
if not isinstance(by_tcid, dict):
return False
return runtime.tool_call_id in by_tcid
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
``stream_resume_chat`` wakes the main agent with
``Command(resume={"decisions": [...]})`` so the propagated
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
parent task's ``null_resume`` pending write, which only gets consumed
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
``subagent_scratchpad parent_scratchpad.get_null_resume`` and picks up
the parent's still-live decisions — mismatching against a different number
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
propagated parent-level interrupt can return. langgraph stores that
payload as the parent task's ``null_resume`` pending write. The ``task``
tool then forwards this turn's slice into the subagent via its own
``Command(resume=...)``. While the subagent is mid-execution, any *new*
``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
approve/reject) walks ``subagent_scratchpad parent_scratchpad.get_null_resume``
and picks up the parent's still-live decisions — mismatching against a
different number of hanging tool calls and crashing
``HumanInTheLoopMiddleware``.
Draining the write here closes that cross-graph leak so subagent
interrupts pause cleanly and re-propagate as a fresh approval card.
interrupts pause cleanly and bubble back up as a fresh approval card.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None

View file

@ -0,0 +1,137 @@
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
The frontend submits decisions in the same order the SSE stream emitted
approval cards. When multiple parallel subagents are paused, the backend uses
this module to:
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
inside ``task_tool``'s catch-and-stamp block when a subagent's
``GraphInterrupt`` bubbles up through ``[a]task``.
2. Slice the flat ``decisions`` list against that ordered pending list to
produce the dict shape expected by ``consume_surfsense_resume``.
Both helpers are pure: callers own the state and the input decisions; we
return new structures and never mutate.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
from typing import Any
logger = logging.getLogger(__name__)
def slice_decisions_by_tool_call(
decisions: list[dict[str, Any]],
pending: Iterable[tuple[str, int]],
) -> dict[str, dict[str, Any]]:
"""Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.
Args:
decisions: Flat list of decisions in the order the SSE stream rendered
them.
pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
order. The slicer consumes ``decisions`` left-to-right.
Returns:
Per-``tool_call_id`` payload dict ready to be written to
``configurable["surfsense_resume_value"]``.
Raises:
ValueError: When the total expected action count differs from the
number of decisions provided. We fail loud rather than silently
dropping or padding so a frontend/backend contract drift surfaces
immediately.
"""
pending_list = list(pending)
expected = sum(count for _, count in pending_list)
if expected != len(decisions):
raise ValueError(
f"Decision count mismatch: pending tool calls expect "
f"{expected} actions but received {len(decisions)} decisions."
)
routed: dict[str, dict[str, Any]] = {}
cursor = 0
for tool_call_id, action_count in pending_list:
routed[tool_call_id] = {
"decisions": decisions[cursor : cursor + action_count]
}
cursor += action_count
return routed
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
"""Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
Reads ``state.interrupts`` (the bundle langgraph aggregated from each
paused subagent's propagated interrupt). Each interrupt value carries the
``tool_call_id`` that the parent's ``task`` tool was processing — see
``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
``except GraphInterrupt`` chokepoint.
Order is preserved from ``state.interrupts``, which is the order the SSE
stream emitted approval cards. The frontend submits decisions in that
same order, so the slicer can consume them left-to-right.
Interrupts without a ``tool_call_id`` are skipped they were not
produced by our task-routing layer (e.g. parent-side HITL middleware on
a different tool); ``stream_resume_chat`` is not responsible for routing
those.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` attribute).
Returns:
Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
scalar-style ``interrupt("...")`` values that were wrapped as
``{"value": ..., "tool_call_id": ...}``.
Raises:
ValueError: When an interrupt value carries a ``tool_call_id`` but
the action count cannot be determined (contract bug every
propagated value should be either a HITL bundle or a wrapped
scalar).
"""
pending: list[tuple[str, int]] = []
for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
logger.warning(
"[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
idx,
type(value).__name__,
)
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
# Should not happen post-stamping; flag loudly if a regression
# ever lets an unstamped value reach the parent state.
logger.warning(
"[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
idx,
sorted(value.keys()),
)
continue
action_requests = value.get("action_requests")
if isinstance(action_requests, list):
pending.append((tool_call_id, len(action_requests)))
continue
if "value" in value:
pending.append((tool_call_id, 1))
continue
raise ValueError(
f"Interrupt for tool_call_id={tool_call_id!r} has no "
"``action_requests`` list and is not a wrapped scalar value; "
"cannot determine action count for resume routing."
)
return pending