chat/stream_resume: key Command(resume=...) by Interrupt.id for parallel HITL

This commit is contained in:
CREDO23 2026-05-13 20:59:57 +02:00
parent c06dd6e8ba
commit 0fd87ccb7f
3 changed files with 285 additions and 2 deletions

View file

@ -11,8 +11,11 @@ this module to:
``GraphInterrupt`` bubbles up through ``[a]task``.
2. Slice the flat ``decisions`` list against that ordered pending list to
produce the dict shape expected by ``consume_surfsense_resume``.
3. Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as
the parent-level ``Command(resume={interrupt_id: payload})`` input the
only shape langgraph accepts when multiple interrupts are pending.
Both helpers are pure: callers own the state and the input decisions; we
All helpers are pure: callers own the state and the input decisions; we
return new structures and never mutate.
"""
@ -135,3 +138,48 @@ def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
)
return pending
def build_lg_resume_map(
state: Any, by_tool_call_id: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
"""Map ``Interrupt.id → resume_payload`` for langgraph's multi-interrupt resume.
``stream_resume_chat`` builds ``by_tool_call_id`` via
:func:`slice_decisions_by_tool_call`. Langgraph's ``Command(resume=...)``
requires ``Interrupt.id`` keys (not our ``tool_call_id`` stamps) when the
parent state has multiple pending interrupts. This pure helper re-keys the
slice without mutating it, and skips entries that can't be paired (no
stamp, no slice) so contract drift surfaces as a count mismatch at the
call site instead of a silent mis-route.
The two key spaces serve two different consumers:
- ``surfsense_resume_value`` (keyed by ``tool_call_id``): read by the
subagent bridge inside ``task_tool``.
- ``Command(resume=...)`` (keyed by ``Interrupt.id``): read by langgraph's
pregel to wake each pending interrupt site.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` iterable).
by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`.
Returns:
Dict ready to be passed as ``Command(resume=<this>)``.
"""
out: dict[str, dict[str, Any]] = {}
for interrupt_obj in getattr(state, "interrupts", ()) or ():
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
continue
interrupt_id = getattr(interrupt_obj, "id", None)
if not isinstance(interrupt_id, str):
continue
payload = by_tool_call_id.get(tool_call_id)
if payload is None:
continue
out[interrupt_id] = payload
return out