mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
chat/stream_resume: key Command(resume=...) by Interrupt.id for parallel HITL
This commit is contained in:
parent
c06dd6e8ba
commit
0fd87ccb7f
3 changed files with 285 additions and 2 deletions
|
|
@ -11,8 +11,11 @@ this module to:
|
||||||
``GraphInterrupt`` bubbles up through ``[a]task``.
|
``GraphInterrupt`` bubbles up through ``[a]task``.
|
||||||
2. Slice the flat ``decisions`` list against that ordered pending list to
|
2. Slice the flat ``decisions`` list against that ordered pending list to
|
||||||
produce the dict shape expected by ``consume_surfsense_resume``.
|
produce the dict shape expected by ``consume_surfsense_resume``.
|
||||||
|
3. Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as
|
||||||
|
the parent-level ``Command(resume={interrupt_id: payload})`` input — the
|
||||||
|
only shape langgraph accepts when multiple interrupts are pending.
|
||||||
|
|
||||||
Both helpers are pure: callers own the state and the input decisions; we
|
All helpers are pure: callers own the state and the input decisions; we
|
||||||
return new structures and never mutate.
|
return new structures and never mutate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -135,3 +138,48 @@ def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
|
||||||
)
|
)
|
||||||
|
|
||||||
return pending
|
return pending
|
||||||
|
|
||||||
|
|
||||||
|
def build_lg_resume_map(
|
||||||
|
state: Any, by_tool_call_id: dict[str, dict[str, Any]]
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Map ``Interrupt.id → resume_payload`` for langgraph's multi-interrupt resume.
|
||||||
|
|
||||||
|
``stream_resume_chat`` builds ``by_tool_call_id`` via
|
||||||
|
:func:`slice_decisions_by_tool_call`. Langgraph's ``Command(resume=...)``
|
||||||
|
requires ``Interrupt.id`` keys (not our ``tool_call_id`` stamps) when the
|
||||||
|
parent state has multiple pending interrupts. This pure helper re-keys the
|
||||||
|
slice without mutating it, and skips entries that can't be paired (no
|
||||||
|
stamp, no slice) so contract drift surfaces as a count mismatch at the
|
||||||
|
call site instead of a silent mis-route.
|
||||||
|
|
||||||
|
The two key spaces serve two different consumers:
|
||||||
|
- ``surfsense_resume_value`` (keyed by ``tool_call_id``): read by the
|
||||||
|
subagent bridge inside ``task_tool``.
|
||||||
|
- ``Command(resume=...)`` (keyed by ``Interrupt.id``): read by langgraph's
|
||||||
|
pregel to wake each pending interrupt site.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: A langgraph ``StateSnapshot`` (or any object with an
|
||||||
|
``interrupts`` iterable).
|
||||||
|
by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict ready to be passed as ``Command(resume=<this>)``.
|
||||||
|
"""
|
||||||
|
out: dict[str, dict[str, Any]] = {}
|
||||||
|
for interrupt_obj in getattr(state, "interrupts", ()) or ():
|
||||||
|
value = getattr(interrupt_obj, "value", None)
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
continue
|
||||||
|
tool_call_id = value.get("tool_call_id")
|
||||||
|
if not isinstance(tool_call_id, str):
|
||||||
|
continue
|
||||||
|
interrupt_id = getattr(interrupt_obj, "id", None)
|
||||||
|
if not isinstance(interrupt_id, str):
|
||||||
|
continue
|
||||||
|
payload = by_tool_call_id.get(tool_call_id)
|
||||||
|
if payload is None:
|
||||||
|
continue
|
||||||
|
out[interrupt_id] = payload
|
||||||
|
return out
|
||||||
|
|
|
||||||
|
|
@ -2829,6 +2829,7 @@ async def stream_resume_chat(
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||||
|
build_lg_resume_map,
|
||||||
collect_pending_tool_calls,
|
collect_pending_tool_calls,
|
||||||
slice_decisions_by_tool_call,
|
slice_decisions_by_tool_call,
|
||||||
)
|
)
|
||||||
|
|
@ -2847,6 +2848,10 @@ async def stream_resume_chat(
|
||||||
len(pending),
|
len(pending),
|
||||||
)
|
)
|
||||||
routed_resume_value = slice_decisions_by_tool_call(decisions, pending)
|
routed_resume_value = slice_decisions_by_tool_call(decisions, pending)
|
||||||
|
# Langgraph rejects scalar ``Command(resume=...)`` when multiple
|
||||||
|
# interrupts are pending (parallel HITL); the mapped form works
|
||||||
|
# for the single-pause case too, so we always use it.
|
||||||
|
lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
|
|
@ -2938,7 +2943,7 @@ async def stream_resume_chat(
|
||||||
async for sse in _stream_agent_events(
|
async for sse in _stream_agent_events(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
config=config,
|
config=config,
|
||||||
input_data=Command(resume={"decisions": decisions}),
|
input_data=Command(resume=lg_resume_map),
|
||||||
streaming_service=streaming_service,
|
streaming_service=streaming_service,
|
||||||
result=stream_result,
|
result=stream_result,
|
||||||
step_prefix="thinking-resume",
|
step_prefix="thinking-resume",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,230 @@
|
||||||
|
"""Real-graph contract: parallel resume must key ``Command(resume=...)`` by ``Interrupt.id``.
|
||||||
|
|
||||||
|
When the parent state has multiple pending interrupts, langgraph rejects a
|
||||||
|
scalar ``Command(resume=v)`` with::
|
||||||
|
|
||||||
|
RuntimeError: When there are multiple pending interrupts, you must specify
|
||||||
|
the interrupt id when resuming.
|
||||||
|
|
||||||
|
The fix is to map each ``Interrupt.id`` from ``state.interrupts`` to the
|
||||||
|
per-subagent slice — orthogonal to our ``tool_call_id``-keyed
|
||||||
|
``surfsense_resume_value`` side-channel (different consumer: langgraph's
|
||||||
|
pregel vs. our subagent bridge).
|
||||||
|
|
||||||
|
This test reproduces the production failure with a real two-task parallel
|
||||||
|
``Send`` parent graph, exercises the full resume cycle, and asserts both
|
||||||
|
subagents complete cleanly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langgraph.types import Command, Send, interrupt
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||||
|
build_lg_resume_map,
|
||||||
|
collect_pending_tool_calls,
|
||||||
|
slice_decisions_by_tool_call,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||||
|
build_task_tool_with_parent_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _SubState(TypedDict, total=False):
|
||||||
|
messages: list
|
||||||
|
|
||||||
|
|
||||||
|
class _DispatchState(TypedDict, total=False):
|
||||||
|
# ``add_messages`` reducer matches production agent state shape and is
|
||||||
|
# required when two parallel ``Send`` branches both write to ``messages``
|
||||||
|
# in the same superstep (post-resume both subagents return their own
|
||||||
|
# ``{"messages": [...]}``). Without a reducer langgraph raises
|
||||||
|
# ``InvalidUpdateError: At key 'messages': Can receive only one value``.
|
||||||
|
messages: Annotated[list, add_messages]
|
||||||
|
tcid: str
|
||||||
|
desc: str
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pausing_subagent(checkpointer: InMemorySaver):
|
||||||
|
def approve_node(_state):
|
||||||
|
decision = interrupt(
|
||||||
|
{
|
||||||
|
"action_requests": [
|
||||||
|
{"name": "do_thing", "args": {"x": 1}, "description": ""}
|
||||||
|
],
|
||||||
|
"review_configs": [{}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
||||||
|
|
||||||
|
g = StateGraph(_SubState)
|
||||||
|
g.add_node("approve", approve_node)
|
||||||
|
g.add_edge(START, "approve")
|
||||||
|
g.add_edge("approve", END)
|
||||||
|
return g.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|
||||||
|
def _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
|
||||||
|
):
|
||||||
|
def fanout_edge(_state) -> list[Send]:
|
||||||
|
return [
|
||||||
|
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
|
||||||
|
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||||
|
rt = ToolRuntime(
|
||||||
|
state=state,
|
||||||
|
config=config,
|
||||||
|
context=None,
|
||||||
|
stream_writer=None,
|
||||||
|
tool_call_id=state["tcid"],
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
return await task_tool.coroutine(
|
||||||
|
description=state["desc"], subagent_type="approver", runtime=rt
|
||||||
|
)
|
||||||
|
|
||||||
|
g = StateGraph(_DispatchState)
|
||||||
|
g.add_node("call_task", call_task)
|
||||||
|
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||||
|
g.add_edge("call_task", END)
|
||||||
|
return g.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error():
|
||||||
|
"""Confirm the production failure mode: scalar resume on multi-pending state explodes.
|
||||||
|
|
||||||
|
This is a contract pin: if langgraph relaxes the requirement in a future
|
||||||
|
release, this test starts passing and we know we can simplify
|
||||||
|
``stream_resume_chat``. Until then, the keyed form is mandatory.
|
||||||
|
"""
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
subagent = _build_pausing_subagent(checkpointer)
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||||
|
)
|
||||||
|
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
task_tool,
|
||||||
|
tool_call_id_a="parent-tcid-A",
|
||||||
|
tool_call_id_b="parent-tcid-B",
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
)
|
||||||
|
config: dict = {
|
||||||
|
"configurable": {"thread_id": "parallel-resume-scalar"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="multiple pending interrupts"):
|
||||||
|
await parent.ainvoke(Command(resume={"decisions": ["A"]}), config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subagents():
|
||||||
|
"""Production-shape resume: builds the langgraph-keyed map and resumes both subagents.
|
||||||
|
|
||||||
|
Mirrors what ``stream_resume_chat`` does: collects pending interrupts,
|
||||||
|
slices the flat decisions list by ``tool_call_id``, builds the
|
||||||
|
``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes.
|
||||||
|
The expected post-condition is that both subagents pop their own
|
||||||
|
decision (via the ``surfsense_resume_value`` side-channel) and run to
|
||||||
|
completion — no RuntimeError, no leaked pending interrupts.
|
||||||
|
"""
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
subagent = _build_pausing_subagent(checkpointer)
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||||
|
)
|
||||||
|
tcid_a = "parent-tcid-A"
|
||||||
|
tcid_b = "parent-tcid-B"
|
||||||
|
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
task_tool,
|
||||||
|
tool_call_id_a=tcid_a,
|
||||||
|
tool_call_id_b=tcid_b,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
)
|
||||||
|
config: dict = {
|
||||||
|
"configurable": {"thread_id": "parallel-resume-keyed"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
|
||||||
|
paused_state = await parent.aget_state(config)
|
||||||
|
assert len(paused_state.interrupts) == 2, "fixture broken: expected 2 paused subagents"
|
||||||
|
|
||||||
|
pending = collect_pending_tool_calls(paused_state)
|
||||||
|
flat_decisions = [{"type": "approve"}, {"type": "approve"}]
|
||||||
|
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||||
|
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
|
||||||
|
|
||||||
|
assert len(lg_resume_map) == 2, (
|
||||||
|
f"expected one entry per pending interrupt id, got {lg_resume_map!r}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(k, str) for k in lg_resume_map), (
|
||||||
|
f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wire the side-channel exactly like ``stream_resume_chat`` does.
|
||||||
|
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||||
|
|
||||||
|
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||||
|
|
||||||
|
final_state = await parent.aget_state(config)
|
||||||
|
assert not final_state.interrupts, (
|
||||||
|
f"expected no leftover pending interrupts after resume, got "
|
||||||
|
f"{final_state.interrupts!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps():
|
||||||
|
"""Unstamped interrupts can't be routed; we don't fabricate keys for them.
|
||||||
|
|
||||||
|
If a regression lets an unstamped interrupt reach the parent state, the
|
||||||
|
empty map propagates to the call site and surfaces as a clear count
|
||||||
|
mismatch instead of a silent mis-route.
|
||||||
|
"""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
fake_interrupt = SimpleNamespace(id="i-foreign", value={"action_requests": [{}]})
|
||||||
|
state = SimpleNamespace(interrupts=(fake_interrupt,))
|
||||||
|
|
||||||
|
assert build_lg_resume_map(state, {"some-tcid": {"decisions": ["x"]}}) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_lg_resume_map_skips_interrupts_without_corresponding_slice():
|
||||||
|
"""Skip rather than silently mis-route when the slice and interrupts disagree.
|
||||||
|
|
||||||
|
Only emit a resume entry when both an interrupt id and a tool_call_id
|
||||||
|
slice are present; a mismatch indicates upstream contract drift and
|
||||||
|
should not be papered over.
|
||||||
|
"""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
state = SimpleNamespace(
|
||||||
|
interrupts=(
|
||||||
|
SimpleNamespace(
|
||||||
|
id="i-A",
|
||||||
|
value={"action_requests": [{}], "tool_call_id": "tcid-A"},
|
||||||
|
),
|
||||||
|
SimpleNamespace(
|
||||||
|
id="i-B",
|
||||||
|
value={"action_requests": [{}], "tool_call_id": "tcid-B"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
out = build_lg_resume_map(state, {"tcid-A": {"decisions": ["only-A"]}})
|
||||||
|
assert out == {"i-A": {"decisions": ["only-A"]}}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue