From 0fd87ccb7f5415a1b3a7280a131c65e4f92a5317 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 13 May 2026 20:59:57 +0200 Subject: [PATCH] chat/stream_resume: key Command(resume=...) by Interrupt.id for parallel HITL --- .../resume_routing.py | 50 +++- .../app/tasks/chat/stream_new_chat.py | 7 +- .../test_parallel_resume_command_keying.py | 230 ++++++++++++++++++ 3 files changed, 285 insertions(+), 2 deletions(-) create mode 100644 surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py index 4e75b3c36..6fa79c764 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py @@ -11,8 +11,11 @@ this module to: ``GraphInterrupt`` bubbles up through ``[a]task``. 2. Slice the flat ``decisions`` list against that ordered pending list to produce the dict shape expected by ``consume_surfsense_resume``. +3. Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as + the parent-level ``Command(resume={interrupt_id: payload})`` input — the + only shape langgraph accepts when multiple interrupts are pending. -Both helpers are pure: callers own the state and the input decisions; we +All helpers are pure: callers own the state and the input decisions; we return new structures and never mutate. 
""" @@ -135,3 +138,48 @@ def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]: ) return pending + + +def build_lg_resume_map( + state: Any, by_tool_call_id: dict[str, dict[str, Any]] +) -> dict[str, dict[str, Any]]: + """Map ``Interrupt.id → resume_payload`` for langgraph's multi-interrupt resume. + + ``stream_resume_chat`` builds ``by_tool_call_id`` via + :func:`slice_decisions_by_tool_call`. Langgraph's ``Command(resume=...)`` + requires ``Interrupt.id`` keys (not our ``tool_call_id`` stamps) when the + parent state has multiple pending interrupts. This pure helper re-keys the + slice without mutating it, and skips entries that can't be paired (no + stamp, no slice) so contract drift surfaces as a count mismatch at the + call site instead of a silent mis-route. + + The two key spaces serve two different consumers: + - ``surfsense_resume_value`` (keyed by ``tool_call_id``): read by the + subagent bridge inside ``task_tool``. + - ``Command(resume=...)`` (keyed by ``Interrupt.id``): read by langgraph's + pregel to wake each pending interrupt site. + + Args: + state: A langgraph ``StateSnapshot`` (or any object with an + ``interrupts`` iterable). + by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`. + + Returns: + Dict ready to be passed as ``Command(resume=)``. 
+ """ + out: dict[str, dict[str, Any]] = {} + for interrupt_obj in getattr(state, "interrupts", ()) or (): + value = getattr(interrupt_obj, "value", None) + if not isinstance(value, dict): + continue + tool_call_id = value.get("tool_call_id") + if not isinstance(tool_call_id, str): + continue + interrupt_id = getattr(interrupt_obj, "id", None) + if not isinstance(interrupt_id, str): + continue + payload = by_tool_call_id.get(tool_call_id) + if payload is None: + continue + out[interrupt_id] = payload + return out diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index eb30b994e..a87d9f2d1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2829,6 +2829,7 @@ async def stream_resume_chat( from langgraph.types import Command from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, collect_pending_tool_calls, slice_decisions_by_tool_call, ) @@ -2847,6 +2848,10 @@ async def stream_resume_chat( len(pending), ) routed_resume_value = slice_decisions_by_tool_call(decisions, pending) + # Langgraph rejects scalar ``Command(resume=...)`` when multiple + # interrupts are pending (parallel HITL); the mapped form works + # for the single-pause case too, so we always use it. 
+ lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value) config = { "configurable": { @@ -2938,7 +2943,7 @@ async def stream_resume_chat( async for sse in _stream_agent_events( agent=agent, config=config, - input_data=Command(resume={"decisions": decisions}), + input_data=Command(resume=lg_resume_map), streaming_service=streaming_service, result=stream_result, step_prefix="thinking-resume", diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py new file mode 100644 index 000000000..125e0744a --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py @@ -0,0 +1,230 @@ +"""Real-graph contract: parallel resume must key ``Command(resume=...)`` by ``Interrupt.id``. + +When the parent state has multiple pending interrupts, langgraph rejects a +scalar ``Command(resume=v)`` with:: + + RuntimeError: When there are multiple pending interrupts, you must specify + the interrupt id when resuming. + +The fix is to map each ``Interrupt.id`` from ``state.interrupts`` to the +per-subagent slice — orthogonal to our ``tool_call_id``-keyed +``surfsense_resume_value`` side-channel (different consumer: langgraph's +pregel vs. our subagent bridge). + +This test reproduces the production failure with a real two-task parallel +``Send`` parent graph, exercises the full resume cycle, and asserts both +subagents complete cleanly. 
+""" + +from __future__ import annotations + +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + # ``add_messages`` reducer matches production agent state shape and is + # required when two parallel ``Send`` branches both write to ``messages`` + # in the same superstep (post-resume both subagents return their own + # ``{"messages": [...]}``). Without a reducer langgraph raises + # ``InvalidUpdateError: At key 'messages': Can receive only one value``. 
+ messages: Annotated[list, add_messages] + tcid: str + desc: str + + +def _build_pausing_subagent(checkpointer: InMemorySaver): + def approve_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "do_thing", "args": {"x": 1}, "description": ""} + ], + "review_configs": [{}], + } + ) + return {"messages": [AIMessage(content=f"got:{decision}")]} + + g = StateGraph(_SubState) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_graph_dispatching_two_tasks_via_send( + task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer +): + def fanout_edge(_state) -> list[Send]: + return [ + Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}), + Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}), + ] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type="approver", runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout_edge, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error(): + """Confirm the production failure mode: scalar resume on multi-pending state explodes. + + This is a contract pin: if langgraph relaxes the requirement in a future + release, this test starts failing and we know we can simplify + ``stream_resume_chat``. Until then, the keyed form is mandatory. 
+ """ + checkpointer = InMemorySaver() + subagent = _build_pausing_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + config: dict = { + "configurable": {"thread_id": "parallel-resume-scalar"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + with pytest.raises(RuntimeError, match="multiple pending interrupts"): + await parent.ainvoke(Command(resume={"decisions": ["A"]}), config) + + +@pytest.mark.asyncio +async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subagents(): + """Production-shape resume: builds the langgraph-keyed map and resumes both subagents. + + Mirrors what ``stream_resume_chat`` does: collects pending interrupts, + slices the flat decisions list by ``tool_call_id``, builds the + ``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes. + The expected post-condition is that both subagents pop their own + decision (via the ``surfsense_resume_value`` side-channel) and run to + completion — no RuntimeError, no leaked pending interrupts. 
+ """ + checkpointer = InMemorySaver() + subagent = _build_pausing_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + tcid_a = "parent-tcid-A" + tcid_b = "parent-tcid-B" + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a=tcid_a, + tool_call_id_b=tcid_b, + checkpointer=checkpointer, + ) + config: dict = { + "configurable": {"thread_id": "parallel-resume-keyed"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused_state = await parent.aget_state(config) + assert len(paused_state.interrupts) == 2, "fixture broken: expected 2 paused subagents" + + pending = collect_pending_tool_calls(paused_state) + flat_decisions = [{"type": "approve"}, {"type": "approve"}] + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id) + + assert len(lg_resume_map) == 2, ( + f"expected one entry per pending interrupt id, got {lg_resume_map!r}" + ) + assert all(isinstance(k, str) for k in lg_resume_map), ( + f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}" + ) + + # Wire the side-channel exactly like ``stream_resume_chat`` does. + config["configurable"]["surfsense_resume_value"] = by_tool_call_id + + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final_state = await parent.aget_state(config) + assert not final_state.interrupts, ( + f"expected no leftover pending interrupts after resume, got " + f"{final_state.interrupts!r}" + ) + + +def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps(): + """Unstamped interrupts can't be routed; we don't fabricate keys for them. 
+ + If a regression lets an unstamped interrupt reach the parent state, the + empty map propagates to the call site and surfaces as a clear count + mismatch instead of a silent mis-route. + """ + from types import SimpleNamespace + + fake_interrupt = SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}) + state = SimpleNamespace(interrupts=(fake_interrupt,)) + + assert build_lg_resume_map(state, {"some-tcid": {"decisions": ["x"]}}) == {} + + +def test_build_lg_resume_map_skips_interrupts_without_corresponding_slice(): + """Skip rather than silently mis-route when the slice and interrupts disagree. + + Only emit a resume entry when both an interrupt id and a tool_call_id + slice are present; a mismatch indicates upstream contract drift and + should not be papered over. + """ + from types import SimpleNamespace + + state = SimpleNamespace( + interrupts=( + SimpleNamespace( + id="i-A", + value={"action_requests": [{}], "tool_call_id": "tcid-A"}, + ), + SimpleNamespace( + id="i-B", + value={"action_requests": [{}], "tool_call_id": "tcid-B"}, + ), + ) + ) + + out = build_lg_resume_map(state, {"tcid-A": {"decisions": ["only-A"]}}) + assert out == {"i-A": {"decisions": ["only-A"]}}