diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1d5e1aa1a..eb30b994e 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -76,6 +76,9 @@ from app.services.chat_session_state_service import ( from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.streaming.graph_stream.event_stream import stream_output +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + all_interrupt_values, +) from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap from app.utils.user_message_multimodal import build_human_message_content @@ -98,47 +101,6 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int: return min(delay, TURN_CANCELLING_MAX_DELAY_MS) -def _first_interrupt_value(state: Any) -> dict[str, Any] | None: - """Return the first LangGraph interrupt payload across all snapshot tasks.""" - - def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None: - if isinstance(candidate, dict): - value = candidate.get("value", candidate) - return value if isinstance(value, dict) else None - value = getattr(candidate, "value", None) - if isinstance(value, dict): - return value - if isinstance(candidate, (list, tuple)): - for item in candidate: - extracted = _extract_interrupt_value(item) - if extracted is not None: - return extracted - return None - - for task in getattr(state, "tasks", ()) or (): - try: - interrupts = getattr(task, "interrupts", ()) or () - except (AttributeError, IndexError, TypeError): - interrupts = () - if not interrupts: - extracted = _extract_interrupt_value(task) - if extracted is not None: - return extracted - continue - for interrupt_item in interrupts: - extracted = _extract_interrupt_value(interrupt_item) - if extracted is not None: - return extracted - try: - state_interrupts = getattr(state, "interrupts", ()) or () - except (AttributeError, IndexError, TypeError): - state_interrupts = () - extracted = _extract_interrupt_value(state_interrupts) - if extracted is not None: - return extracted - return None - - def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. @@ -301,7 +263,6 @@ def extract_todos_from_deepagents(command_output) -> dict: class StreamResult: accumulated_text: str = "" is_interrupted: bool = False - interrupt_value: dict[str, Any] | None = None sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False request_id: str | None = None @@ -915,11 +876,15 @@ async def _stream_agent_events( result.accumulated_text = accumulated_text _log_file_contract("turn_outcome", result) - interrupt_value = _first_interrupt_value(state) - if interrupt_value is not None: + pending_values = all_interrupt_values(state) + if pending_values: result.is_interrupted = True - result.interrupt_value = interrupt_value - yield streaming_service.format_interrupt_request(result.interrupt_value) + # One frame per paused subagent so each parallel HITL renders its own + # approval card on the wire. Order matches ``state.interrupts``, which + # the resume slicer in ``checkpointed_subagent_middleware.resume_routing`` + # consumes in the same order — keeping emit and resume in lock-step. + for interrupt_value in pending_values: + yield streaming_service.format_interrupt_request(interrupt_value) async def stream_new_chat( diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py index 40404e9d0..391f14f24 100644 --- a/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py @@ -10,7 +10,6 @@ from typing import Any class StreamingResult: accumulated_text: str = "" is_interrupted: bool = False - interrupt_value: dict[str, Any] | None = None sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False request_id: str | None = None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py index dca099b3f..f4b00431c 100644 --- a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py @@ -1,12 +1,30 @@ -"""Read the first interrupt payload from a LangGraph state snapshot.""" +"""Read every pending interrupt payload from a LangGraph state snapshot. + +The chat-stream emit loop yields one ``data-interrupt-request`` SSE frame per +pending interrupt so parallel HITL across siblings stays addressable on the +wire (the resume slicer in ``checkpointed_subagent_middleware.resume_routing`` +correlates each frame back to the right paused subagent via the stamped +``tool_call_id``). This helper produces that flat, ordered list. +""" from __future__ import annotations from typing import Any -def first_interrupt_value(state: Any) -> dict[str, Any] | None: - """Return the first interrupt payload across all snapshot tasks.""" +def all_interrupt_values(state: Any) -> list[dict[str, Any]]: + """Return every interrupt payload across the snapshot, in traversal order. + + Walks ``state.tasks[*].interrupts`` first (langgraph's per-task buckets, + which carry one interrupt per paused subagent) and falls back to + ``state.interrupts`` when the per-task lists are empty. Order matches the + snapshot's iteration order so the emit-time order on the SSE stream agrees + with ``collect_pending_tool_calls`` consumption order on resume. + + Defensive against malformed snapshots: tasks/interrupts that raise on + attribute access are skipped silently. Non-dict values are skipped — the + chat-stream contract requires structured interrupt payloads. + """ def _extract(candidate: Any) -> dict[str, Any] | None: if isinstance(candidate, dict): @@ -15,33 +33,32 @@ def first_interrupt_value(state: Any) -> dict[str, Any] | None: value = getattr(candidate, "value", None) if isinstance(value, dict): return value - if isinstance(candidate, list | tuple): - for item in candidate: - extracted = _extract(item) - if extracted is not None: - return extracted return None + values: list[dict[str, Any]] = [] + saw_task_interrupt = False + for task in getattr(state, "tasks", ()) or (): try: interrupts = getattr(task, "interrupts", ()) or () except (AttributeError, IndexError, TypeError): interrupts = () - if not interrupts: - extracted = _extract(task) - if extracted is not None: - return extracted - continue - for interrupt_item in interrupts: - extracted = _extract(interrupt_item) - if extracted is not None: - return extracted + if interrupts: + saw_task_interrupt = True + for interrupt_item in interrupts: + extracted = _extract(interrupt_item) + if extracted is not None: + values.append(extracted) + + if saw_task_interrupt: + return values try: state_interrupts = getattr(state, "interrupts", ()) or () except (AttributeError, IndexError, TypeError): state_interrupts = () - extracted = _extract(state_interrupts) - if extracted is not None: - return extracted - return None + for interrupt_item in state_interrupts: + extracted = _extract(interrupt_item) + if extracted is not None: + values.append(extracted) + return values diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py new file mode 100644 index 000000000..348e49a4a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py @@ -0,0 +1,210 @@ +"""Real-graph contract: ``all_interrupt_values`` surfaces every pending interrupt. + +The chat-stream emit loop must yield one ``data-interrupt-request`` SSE frame +per paused subagent, in the same order ``state.interrupts`` reports them — +that's also the order the resume slicer consumes decisions. These tests pin +that contract against a **real** paused parent graph built via +:class:`~langgraph.types.Send` fan-out (no synthetic state mocks). +""" + +from __future__ import annotations + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + all_interrupt_values, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + messages: list + tcid: str + desc: str + + +def _build_pausing_subagent(checkpointer: InMemorySaver): + def approve_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "do_thing", "args": {"x": 1}, "description": ""} + ], + "review_configs": [{}], + } + ) + return {"messages": [AIMessage(content=f"got:{decision}")]} + + g = StateGraph(_SubState) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_graph_dispatching_two_tasks_via_send( + task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer +): + def fanout_edge(_state) -> list[Send]: + return [ + Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}), + Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}), + ] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type="approver", runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout_edge, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_returns_every_pending_interrupt_for_two_paused_subagents(): + """Two parallel subagents -> ``all_interrupt_values`` returns 2 dicts.""" + checkpointer = InMemorySaver() + subagent = _build_pausing_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + + parent_config = { + "configurable": {"thread_id": "all-iv-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + state = await parent.aget_state(parent_config) + + values = all_interrupt_values(state) + + assert isinstance(values, list) + assert len(values) == 2, ( + f"REGRESSION: expected one value per pending subagent, got " + f"{len(values)}: {values!r}" + ) + stamps = [v.get("tool_call_id") for v in values] + assert sorted(stamps) == ["parent-tcid-A", "parent-tcid-B"] + for v in values: + assert isinstance(v.get("action_requests"), list) + assert len(v["action_requests"]) == 1 + + +@pytest.mark.asyncio +async def test_preserves_state_interrupts_traversal_order(): + """Order returned by inspector must match ``state.interrupts`` order. + + The resume slicer consumes decisions left-to-right against + ``collect_pending_tool_calls(state)`` which walks ``state.interrupts`` + in iteration order — so the inspector (which drives the *emit* order) + must agree with that traversal or the slice and the wire fall out of sync. + """ + checkpointer = InMemorySaver() + subagent = _build_pausing_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + parent_config = { + "configurable": {"thread_id": "order-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + state = await parent.aget_state(parent_config) + + inspector_order = [v["tool_call_id"] for v in all_interrupt_values(state)] + state_order = [ + i.value["tool_call_id"] + for i in state.interrupts + if isinstance(getattr(i, "value", None), dict) + and "tool_call_id" in i.value + ] + + assert inspector_order == state_order, ( + f"inspector order {inspector_order!r} diverged from state.interrupts " + f"order {state_order!r}; the resume slicer would mis-route decisions." + ) + + +@pytest.mark.asyncio +async def test_returns_empty_list_when_nothing_paused(): + """A graph that completes normally produces no interrupts to surface.""" + + def done_node(_state): + return {"messages": [AIMessage(content="done")]} + + g = StateGraph(_SubState) + g.add_node("done", done_node) + g.add_edge(START, "done") + g.add_edge("done", END) + graph = g.compile(checkpointer=InMemorySaver()) + config = {"configurable": {"thread_id": "no-pause-thread"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + state = await graph.aget_state(config) + + assert all_interrupt_values(state) == [] + + +@pytest.mark.asyncio +async def test_single_paused_subagent_returns_a_list_of_one(): + """Single-pause case must still return a list (not unwrap to a dict).""" + + def approve_node(_state): + decision = interrupt( + { + "action_requests": [{"name": "x", "args": {}, "description": ""}], + "review_configs": [{}], + "tool_call_id": "lonely-tcid", + } + ) + return {"messages": [AIMessage(content=f"got:{decision}")]} + + g = StateGraph(_SubState) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + graph = g.compile(checkpointer=InMemorySaver()) + config = {"configurable": {"thread_id": "single-thread"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + state = await graph.aget_state(config) + + values = all_interrupt_values(state) + + assert isinstance(values, list) + assert len(values) == 1 + assert values[0].get("tool_call_id") == "lonely-tcid" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py index d598de492..8fde773e3 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py @@ -23,7 +23,6 @@ from app.tasks.chat.stream_new_chat import ( _emit_stream_terminal_error as old_emit_terminal_error, _extract_chunk_parts as old_extract_chunk_parts, _extract_resolved_file_path as old_extract_resolved_file_path, - _first_interrupt_value as old_first_interrupt_value, _tool_output_has_error as old_tool_output_has_error, _tool_output_to_text as old_tool_output_to_text, ) @@ -36,9 +35,6 @@ from app.tasks.chat.streaming.errors.emitter import ( from app.tasks.chat.streaming.helpers.chunk_parts import ( extract_chunk_parts as new_extract_chunk_parts, ) -from app.tasks.chat.streaming.helpers.interrupt_inspector import ( - first_interrupt_value as new_first_interrupt_value, -) from app.tasks.chat.streaming.helpers.tool_output import ( extract_resolved_file_path as new_extract_resolved_file_path, tool_output_has_error as new_tool_output_has_error, @@ -105,52 +101,6 @@ def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) -# ---------------------------------------------------------- interrupt inspector - - -@dataclass -class _Interrupt: - value: dict[str, Any] - - -@dataclass -class _Task: - interrupts: tuple[Any, ...] = () - - -@dataclass -class _State: - tasks: tuple[Any, ...] = () - interrupts: tuple[Any, ...] = () - - -_INTERRUPT_CASES: list[Any] = [ - _State(), - _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)), - # Multiple tasks: must return the FIRST one in iteration order. - _State( - tasks=( - _Task(interrupts=(_Interrupt(value={"name": "first"}),)), - _Task(interrupts=(_Interrupt(value={"name": "second"}),)), - ) - ), - # Empty task interrupts -> falls back to root state.interrupts. - _State( - tasks=(_Task(interrupts=()),), - interrupts=(_Interrupt(value={"name": "root"}),), - ), - # Interrupts as plain dicts (not wrapper objects). - _State(interrupts=({"value": {"name": "dict_root"}},)), - # A defective task whose `.interrupts` raises - must be tolerated. - _State(tasks=(object(),)), -] - - -@pytest.mark.parametrize("state", _INTERRUPT_CASES) -def test_first_interrupt_value_matches_old_implementation(state: Any) -> None: - assert new_first_interrupt_value(state) == old_first_interrupt_value(state) - - # ----------------------------------------------------------- error classifier