chat/stream_new_chat: emit one SSE frame per pending interrupt

This commit is contained in:
CREDO23 2026-05-13 20:59:48 +02:00
parent 583ac83735
commit c06dd6e8ba
5 changed files with 259 additions and 118 deletions

View file

@ -76,6 +76,9 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
from app.utils.user_message_multimodal import build_human_message_content
@ -98,47 +101,6 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
if isinstance(candidate, dict):
value = candidate.get("value", candidate)
return value if isinstance(value, dict) else None
value = getattr(candidate, "value", None)
if isinstance(value, dict):
return value
if isinstance(candidate, (list, tuple)):
for item in candidate:
extracted = _extract_interrupt_value(item)
if extracted is not None:
return extracted
return None
for task in getattr(state, "tasks", ()) or ():
try:
interrupts = getattr(task, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
interrupts = ()
if not interrupts:
extracted = _extract_interrupt_value(task)
if extracted is not None:
return extracted
continue
for interrupt_item in interrupts:
extracted = _extract_interrupt_value(interrupt_item)
if extracted is not None:
return extracted
try:
state_interrupts = getattr(state, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
state_interrupts = ()
extracted = _extract_interrupt_value(state_interrupts)
if extracted is not None:
return extracted
return None
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
@ -301,7 +263,6 @@ def extract_todos_from_deepagents(command_output) -> dict:
class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
@ -915,11 +876,15 @@ async def _stream_agent_events(
result.accumulated_text = accumulated_text
_log_file_contract("turn_outcome", result)
interrupt_value = _first_interrupt_value(state)
if interrupt_value is not None:
pending_values = all_interrupt_values(state)
if pending_values:
result.is_interrupted = True
result.interrupt_value = interrupt_value
yield streaming_service.format_interrupt_request(result.interrupt_value)
# One frame per paused subagent so each parallel HITL renders its own
# approval card on the wire. Order matches ``state.interrupts``, which
# the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
# consumes in the same order — keeping emit and resume in lock-step.
for interrupt_value in pending_values:
yield streaming_service.format_interrupt_request(interrupt_value)
async def stream_new_chat(