chat/stream_new_chat: emit one SSE frame per pending interrupt

This commit is contained in:
CREDO23 2026-05-13 20:59:48 +02:00
parent 583ac83735
commit c06dd6e8ba
5 changed files with 259 additions and 118 deletions

View file

@ -76,6 +76,9 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
from app.utils.user_message_multimodal import build_human_message_content
@ -98,47 +101,6 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
if isinstance(candidate, dict):
value = candidate.get("value", candidate)
return value if isinstance(value, dict) else None
value = getattr(candidate, "value", None)
if isinstance(value, dict):
return value
if isinstance(candidate, (list, tuple)):
for item in candidate:
extracted = _extract_interrupt_value(item)
if extracted is not None:
return extracted
return None
for task in getattr(state, "tasks", ()) or ():
try:
interrupts = getattr(task, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
interrupts = ()
if not interrupts:
extracted = _extract_interrupt_value(task)
if extracted is not None:
return extracted
continue
for interrupt_item in interrupts:
extracted = _extract_interrupt_value(interrupt_item)
if extracted is not None:
return extracted
try:
state_interrupts = getattr(state, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
state_interrupts = ()
extracted = _extract_interrupt_value(state_interrupts)
if extracted is not None:
return extracted
return None
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
@ -301,7 +263,6 @@ def extract_todos_from_deepagents(command_output) -> dict:
class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
@ -915,11 +876,15 @@ async def _stream_agent_events(
result.accumulated_text = accumulated_text
_log_file_contract("turn_outcome", result)
interrupt_value = _first_interrupt_value(state)
if interrupt_value is not None:
pending_values = all_interrupt_values(state)
if pending_values:
result.is_interrupted = True
result.interrupt_value = interrupt_value
yield streaming_service.format_interrupt_request(result.interrupt_value)
# One frame per paused subagent so each parallel HITL renders its own
# approval card on the wire. Order matches ``state.interrupts``, which
# the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
# consumes in the same order — keeping emit and resume in lock-step.
for interrupt_value in pending_values:
yield streaming_service.format_interrupt_request(interrupt_value)
async def stream_new_chat(

View file

@ -10,7 +10,6 @@ from typing import Any
class StreamingResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None

View file

@ -1,12 +1,30 @@
"""Read the first interrupt payload from a LangGraph state snapshot."""
"""Read every pending interrupt payload from a LangGraph state snapshot.
The chat-stream emit loop yields one ``data-interrupt-request`` SSE frame per
pending interrupt so parallel HITL across siblings stays addressable on the
wire (the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
correlates each frame back to the right paused subagent via the stamped
``tool_call_id``). This helper produces that flat, ordered list.
"""
from __future__ import annotations
from typing import Any
def first_interrupt_value(state: Any) -> dict[str, Any] | None:
"""Return the first interrupt payload across all snapshot tasks."""
def all_interrupt_values(state: Any) -> list[dict[str, Any]]:
"""Return every interrupt payload across the snapshot, in traversal order.
Walks ``state.tasks[*].interrupts`` first (langgraph's per-task buckets,
which carry one interrupt per paused subagent) and falls back to
``state.interrupts`` when the per-task lists are empty. Order matches the
snapshot's iteration order so the emit-time order on the SSE stream agrees
with ``collect_pending_tool_calls`` consumption order on resume.
Defensive against malformed snapshots: tasks/interrupts that raise on
attribute access are skipped silently. Non-dict values are skipped the
chat-stream contract requires structured interrupt payloads.
"""
def _extract(candidate: Any) -> dict[str, Any] | None:
if isinstance(candidate, dict):
@ -15,33 +33,32 @@ def first_interrupt_value(state: Any) -> dict[str, Any] | None:
value = getattr(candidate, "value", None)
if isinstance(value, dict):
return value
if isinstance(candidate, list | tuple):
for item in candidate:
extracted = _extract(item)
if extracted is not None:
return extracted
return None
values: list[dict[str, Any]] = []
saw_task_interrupt = False
for task in getattr(state, "tasks", ()) or ():
try:
interrupts = getattr(task, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
interrupts = ()
if not interrupts:
extracted = _extract(task)
if extracted is not None:
return extracted
continue
for interrupt_item in interrupts:
extracted = _extract(interrupt_item)
if extracted is not None:
return extracted
if interrupts:
saw_task_interrupt = True
for interrupt_item in interrupts:
extracted = _extract(interrupt_item)
if extracted is not None:
values.append(extracted)
if saw_task_interrupt:
return values
try:
state_interrupts = getattr(state, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
state_interrupts = ()
extracted = _extract(state_interrupts)
if extracted is not None:
return extracted
return None
for interrupt_item in state_interrupts:
extracted = _extract(interrupt_item)
if extracted is not None:
values.append(extracted)
return values

View file

@ -0,0 +1,210 @@
"""Real-graph contract: ``all_interrupt_values`` surfaces every pending interrupt.
The chat-stream emit loop must yield one ``data-interrupt-request`` SSE frame
per paused subagent, in the same order ``state.interrupts`` reports them
that's also the order the resume slicer consumes decisions. These tests pin
that contract against a **real** paused parent graph built via
:class:`~langgraph.types.Send` fan-out (no synthetic state mocks).
"""
from __future__ import annotations
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
messages: list
tcid: str
desc: str
def _build_pausing_subagent(checkpointer: InMemorySaver):
def approve_node(_state):
decision = interrupt(
{
"action_requests": [
{"name": "do_thing", "args": {"x": 1}, "description": ""}
],
"review_configs": [{}],
}
)
return {"messages": [AIMessage(content=f"got:{decision}")]}
g = StateGraph(_SubState)
g.add_node("approve", approve_node)
g.add_edge(START, "approve")
g.add_edge("approve", END)
return g.compile(checkpointer=checkpointer)
def _parent_graph_dispatching_two_tasks_via_send(
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
):
def fanout_edge(_state) -> list[Send]:
return [
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type="approver", runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout_edge, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_returns_every_pending_interrupt_for_two_paused_subagents():
"""Two parallel subagents -> ``all_interrupt_values`` returns 2 dicts."""
checkpointer = InMemorySaver()
subagent = _build_pausing_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a="parent-tcid-A",
tool_call_id_b="parent-tcid-B",
checkpointer=checkpointer,
)
parent_config = {
"configurable": {"thread_id": "all-iv-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
state = await parent.aget_state(parent_config)
values = all_interrupt_values(state)
assert isinstance(values, list)
assert len(values) == 2, (
f"REGRESSION: expected one value per pending subagent, got "
f"{len(values)}: {values!r}"
)
stamps = [v.get("tool_call_id") for v in values]
assert sorted(stamps) == ["parent-tcid-A", "parent-tcid-B"]
for v in values:
assert isinstance(v.get("action_requests"), list)
assert len(v["action_requests"]) == 1
@pytest.mark.asyncio
async def test_preserves_state_interrupts_traversal_order():
"""Order returned by inspector must match ``state.interrupts`` order.
The resume slicer consumes decisions left-to-right against
``collect_pending_tool_calls(state)`` which walks ``state.interrupts``
in iteration order so the inspector (which drives the *emit* order)
must agree with that traversal or the slice and the wire fall out of sync.
"""
checkpointer = InMemorySaver()
subagent = _build_pausing_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a="parent-tcid-A",
tool_call_id_b="parent-tcid-B",
checkpointer=checkpointer,
)
parent_config = {
"configurable": {"thread_id": "order-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
state = await parent.aget_state(parent_config)
inspector_order = [v["tool_call_id"] for v in all_interrupt_values(state)]
state_order = [
i.value["tool_call_id"]
for i in state.interrupts
if isinstance(getattr(i, "value", None), dict)
and "tool_call_id" in i.value
]
assert inspector_order == state_order, (
f"inspector order {inspector_order!r} diverged from state.interrupts "
f"order {state_order!r}; the resume slicer would mis-route decisions."
)
@pytest.mark.asyncio
async def test_returns_empty_list_when_nothing_paused():
"""A graph that completes normally produces no interrupts to surface."""
def done_node(_state):
return {"messages": [AIMessage(content="done")]}
g = StateGraph(_SubState)
g.add_node("done", done_node)
g.add_edge(START, "done")
g.add_edge("done", END)
graph = g.compile(checkpointer=InMemorySaver())
config = {"configurable": {"thread_id": "no-pause-thread"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
state = await graph.aget_state(config)
assert all_interrupt_values(state) == []
@pytest.mark.asyncio
async def test_single_paused_subagent_returns_a_list_of_one():
"""Single-pause case must still return a list (not unwrap to a dict)."""
def approve_node(_state):
decision = interrupt(
{
"action_requests": [{"name": "x", "args": {}, "description": ""}],
"review_configs": [{}],
"tool_call_id": "lonely-tcid",
}
)
return {"messages": [AIMessage(content=f"got:{decision}")]}
g = StateGraph(_SubState)
g.add_node("approve", approve_node)
g.add_edge(START, "approve")
g.add_edge("approve", END)
graph = g.compile(checkpointer=InMemorySaver())
config = {"configurable": {"thread_id": "single-thread"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
state = await graph.aget_state(config)
values = all_interrupt_values(state)
assert isinstance(values, list)
assert len(values) == 1
assert values[0].get("tool_call_id") == "lonely-tcid"

View file

@ -23,7 +23,6 @@ from app.tasks.chat.stream_new_chat import (
_emit_stream_terminal_error as old_emit_terminal_error,
_extract_chunk_parts as old_extract_chunk_parts,
_extract_resolved_file_path as old_extract_resolved_file_path,
_first_interrupt_value as old_first_interrupt_value,
_tool_output_has_error as old_tool_output_has_error,
_tool_output_to_text as old_tool_output_to_text,
)
@ -36,9 +35,6 @@ from app.tasks.chat.streaming.errors.emitter import (
from app.tasks.chat.streaming.helpers.chunk_parts import (
extract_chunk_parts as new_extract_chunk_parts,
)
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
first_interrupt_value as new_first_interrupt_value,
)
from app.tasks.chat.streaming.helpers.tool_output import (
extract_resolved_file_path as new_extract_resolved_file_path,
tool_output_has_error as new_tool_output_has_error,
@ -105,52 +101,6 @@ def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
# ---------------------------------------------------------- interrupt inspector
@dataclass
class _Interrupt:
value: dict[str, Any]
@dataclass
class _Task:
interrupts: tuple[Any, ...] = ()
@dataclass
class _State:
tasks: tuple[Any, ...] = ()
interrupts: tuple[Any, ...] = ()
_INTERRUPT_CASES: list[Any] = [
_State(),
_State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
# Multiple tasks: must return the FIRST one in iteration order.
_State(
tasks=(
_Task(interrupts=(_Interrupt(value={"name": "first"}),)),
_Task(interrupts=(_Interrupt(value={"name": "second"}),)),
)
),
# Empty task interrupts -> falls back to root state.interrupts.
_State(
tasks=(_Task(interrupts=()),),
interrupts=(_Interrupt(value={"name": "root"}),),
),
# Interrupts as plain dicts (not wrapper objects).
_State(interrupts=({"value": {"name": "dict_root"}},)),
# A defective task whose `.interrupts` raises - must be tolerated.
_State(tasks=(object(),)),
]
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
assert new_first_interrupt_value(state) == old_first_interrupt_value(state)
# ----------------------------------------------------------- error classifier