mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 01:02:39 +02:00
Add StreamingService and interrupt correlation for chat streams.
This commit is contained in:
parent
fc429d8702
commit
fef7621d96
3 changed files with 518 additions and 0 deletions
20
surfsense_backend/app/services/streaming/__init__.py
Normal file
20
surfsense_backend/app/services/streaming/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Single-responsibility split of the streaming SSE protocol.
|
||||||
|
|
||||||
|
Layout:
|
||||||
|
* ``envelope/`` - SSE wire framing + ID generators
|
||||||
|
* ``emitter/`` - identity of the agent that emitted an event + runtime registry
|
||||||
|
* ``events/`` - one module per SSE event family
|
||||||
|
* ``service.py`` - composition root used by the orchestrator
|
||||||
|
* ``interrupt_correlation.py`` - id-aware lookup over LangGraph state
|
||||||
|
|
||||||
|
Naming on the wire:
|
||||||
|
* AI SDK protocol fields keep their existing camelCase
|
||||||
|
(``toolCallId``, ``messageId``, ``inputTextDelta``, ``langchainToolCallId``).
|
||||||
|
* Every SurfSense-added field uses ``snake_case``, including the
|
||||||
|
top-level ``emitted_by`` envelope and all inner ``data`` payloads.
|
||||||
|
|
||||||
|
Production keeps using ``app.services.new_streaming_service`` and
|
||||||
|
``app.tasks.chat.stream_new_chat`` until the cutover phase.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
"""Id-aware lookup of pending LangGraph interrupts (replaces first-wins)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PendingInterrupt:
|
||||||
|
interrupt_id: str | None
|
||||||
|
value: dict[str, Any]
|
||||||
|
source_task_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def list_pending_interrupts(state: Any) -> list[PendingInterrupt]:
|
||||||
|
out: list[PendingInterrupt] = []
|
||||||
|
|
||||||
|
for task in getattr(state, "tasks", None) or ():
|
||||||
|
task_id = _safe_str(getattr(task, "id", None))
|
||||||
|
for it in getattr(task, "interrupts", None) or ():
|
||||||
|
value = _coerce_interrupt_value(it)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
interrupt_id = _safe_str(getattr(it, "id", None))
|
||||||
|
out.append(
|
||||||
|
PendingInterrupt(
|
||||||
|
interrupt_id=interrupt_id,
|
||||||
|
value=value,
|
||||||
|
source_task_id=task_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for it in getattr(state, "interrupts", None) or ():
|
||||||
|
value = _coerce_interrupt_value(it)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
interrupt_id = _safe_str(getattr(it, "id", None))
|
||||||
|
out.append(PendingInterrupt(interrupt_id=interrupt_id, value=value))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def get_pending_interrupt_by_id(
|
||||||
|
state: Any, interrupt_id: str
|
||||||
|
) -> PendingInterrupt | None:
|
||||||
|
for pending in list_pending_interrupts(state):
|
||||||
|
if pending.interrupt_id == interrupt_id:
|
||||||
|
return pending
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_pending_interrupt_for_tool_call(
|
||||||
|
state: Any, tool_call_id: str
|
||||||
|
) -> PendingInterrupt | None:
|
||||||
|
for pending in list_pending_interrupts(state):
|
||||||
|
actions = pending.value.get("action_requests")
|
||||||
|
if not isinstance(actions, list):
|
||||||
|
continue
|
||||||
|
for action in actions:
|
||||||
|
if not isinstance(action, dict):
|
||||||
|
continue
|
||||||
|
if action.get("tool_call_id") == tool_call_id:
|
||||||
|
return pending
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def first_pending_interrupt(state: Any) -> PendingInterrupt | None:
|
||||||
|
"""Explicit opt-in to legacy first-wins; prefer the id-aware helpers above."""
|
||||||
|
pending = list_pending_interrupts(state)
|
||||||
|
return pending[0] if pending else None
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_interrupt_value(item: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
return item if item else None
|
||||||
|
value = getattr(item, "value", None)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value if value else None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_str(value: Any) -> str | None:
|
||||||
|
return value if isinstance(value, str) and value else None
|
||||||
414
surfsense_backend/app/services/streaming/service.py
Normal file
414
surfsense_backend/app/services/streaming/service.py
Normal file
|
|
@ -0,0 +1,414 @@
|
||||||
|
"""Composition root: bundles every formatter + a per-invocation emitter registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from . import envelope
|
||||||
|
from .emitter import Emitter, EmitterRegistry
|
||||||
|
from .events import (
|
||||||
|
action_log,
|
||||||
|
data,
|
||||||
|
error,
|
||||||
|
interrupt,
|
||||||
|
lifecycle,
|
||||||
|
reasoning,
|
||||||
|
source,
|
||||||
|
subagent_lifecycle,
|
||||||
|
text,
|
||||||
|
tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingService:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._message_id: str | None = None
|
||||||
|
self.emitter_registry = EmitterRegistry()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message_id(self) -> str | None:
|
||||||
|
return self._message_id
|
||||||
|
|
||||||
|
def begin_message(self, message_id: str | None = None) -> str:
|
||||||
|
self._message_id = message_id or envelope.generate_message_id()
|
||||||
|
return self._message_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_text_id() -> str:
|
||||||
|
return envelope.generate_text_id()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_reasoning_id() -> str:
|
||||||
|
return envelope.generate_reasoning_id()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_tool_call_id() -> str:
|
||||||
|
return envelope.generate_tool_call_id()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_subagent_run_id() -> str:
|
||||||
|
return envelope.generate_subagent_run_id()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_response_headers() -> dict[str, str]:
|
||||||
|
return envelope.get_response_headers()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_done() -> str:
|
||||||
|
return envelope.format_done()
|
||||||
|
|
||||||
|
def resolve_emitter(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str | None,
|
||||||
|
parent_ids: Iterable[str] | None,
|
||||||
|
) -> Emitter:
|
||||||
|
return self.emitter_registry.resolve(run_id=run_id, parent_ids=parent_ids)
|
||||||
|
|
||||||
|
def format_message_start(
|
||||||
|
self,
|
||||||
|
message_id: str | None = None,
|
||||||
|
*,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
chosen = self.begin_message(message_id)
|
||||||
|
return lifecycle.format_message_start(chosen, emitter=emitter)
|
||||||
|
|
||||||
|
def format_message_finish(self, *, emitter: Emitter | None = None) -> str:
|
||||||
|
return lifecycle.format_message_finish(emitter=emitter)
|
||||||
|
|
||||||
|
def format_step_start(self, *, emitter: Emitter | None = None) -> str:
|
||||||
|
return lifecycle.format_step_start(emitter=emitter)
|
||||||
|
|
||||||
|
def format_step_finish(self, *, emitter: Emitter | None = None) -> str:
|
||||||
|
return lifecycle.format_step_finish(emitter=emitter)
|
||||||
|
|
||||||
|
def format_text_start(
|
||||||
|
self, text_id: str, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return text.format_text_start(text_id, emitter=emitter)
|
||||||
|
|
||||||
|
def format_text_delta(
|
||||||
|
self, text_id: str, delta: str, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return text.format_text_delta(text_id, delta, emitter=emitter)
|
||||||
|
|
||||||
|
def format_text_end(
|
||||||
|
self, text_id: str, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return text.format_text_end(text_id, emitter=emitter)
|
||||||
|
|
||||||
|
def format_reasoning_start(
|
||||||
|
self, reasoning_id: str, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return reasoning.format_reasoning_start(reasoning_id, emitter=emitter)
|
||||||
|
|
||||||
|
def format_reasoning_delta(
|
||||||
|
self,
|
||||||
|
reasoning_id: str,
|
||||||
|
delta: str,
|
||||||
|
*,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return reasoning.format_reasoning_delta(reasoning_id, delta, emitter=emitter)
|
||||||
|
|
||||||
|
def format_reasoning_end(
|
||||||
|
self, reasoning_id: str, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return reasoning.format_reasoning_end(reasoning_id, emitter=emitter)
|
||||||
|
|
||||||
|
def format_tool_input_start(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return tool.format_tool_input_start(
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_tool_input_delta(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
input_text_delta: str,
|
||||||
|
*,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return tool.format_tool_input_delta(
|
||||||
|
tool_call_id, input_text_delta, emitter=emitter
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_tool_input_available(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return tool.format_tool_input_available(
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
input_data,
|
||||||
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_tool_output_available(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
output: Any,
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return tool.format_tool_output_available(
|
||||||
|
tool_call_id,
|
||||||
|
output,
|
||||||
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_source_url(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
source_id: str | None = None,
|
||||||
|
title: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return source.format_source_url(
|
||||||
|
url, source_id=source_id, title=title, emitter=emitter
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_source_document(
|
||||||
|
self,
|
||||||
|
source_id: str,
|
||||||
|
*,
|
||||||
|
media_type: str = "file",
|
||||||
|
title: str | None = None,
|
||||||
|
description: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return source.format_source_document(
|
||||||
|
source_id,
|
||||||
|
media_type=media_type,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_file(
|
||||||
|
self, url: str, media_type: str, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return source.format_file(url, media_type, emitter=emitter)
|
||||||
|
|
||||||
|
def format_data(
|
||||||
|
self, data_type: str, payload: Any, *, emitter: Emitter | None = None
|
||||||
|
) -> str:
|
||||||
|
return data.format_data(data_type, payload, emitter=emitter)
|
||||||
|
|
||||||
|
def format_terminal_info(
|
||||||
|
self,
|
||||||
|
text_value: str,
|
||||||
|
*,
|
||||||
|
message_type: str = "info",
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_terminal_info(
|
||||||
|
text_value, message_type=message_type, emitter=emitter
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_further_questions(
|
||||||
|
self,
|
||||||
|
questions: list[str],
|
||||||
|
*,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_further_questions(questions, emitter=emitter)
|
||||||
|
|
||||||
|
def format_thinking_step(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
step_id: str,
|
||||||
|
title: str,
|
||||||
|
status: str = "in_progress",
|
||||||
|
items: list[str] | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_thinking_step(
|
||||||
|
step_id=step_id,
|
||||||
|
title=title,
|
||||||
|
status=status,
|
||||||
|
items=items,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_thread_title_update(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: int,
|
||||||
|
title: str,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_thread_title_update(
|
||||||
|
thread_id=thread_id, title=title, emitter=emitter
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_turn_info(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
chat_turn_id: str,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_turn_info(chat_turn_id=chat_turn_id, emitter=emitter)
|
||||||
|
|
||||||
|
def format_turn_status(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_turn_status(status=status, emitter=emitter)
|
||||||
|
|
||||||
|
def format_user_message_id(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message_id: str,
|
||||||
|
turn_id: str,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_user_message_id(
|
||||||
|
message_id=message_id, turn_id=turn_id, emitter=emitter
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_assistant_message_id(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message_id: str,
|
||||||
|
turn_id: str,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return data.format_assistant_message_id(
|
||||||
|
message_id=message_id, turn_id=turn_id, emitter=emitter
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_error(
|
||||||
|
self,
|
||||||
|
error_text: str,
|
||||||
|
*,
|
||||||
|
error_code: str | None = None,
|
||||||
|
extra: dict[str, Any] | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return error.format_error(
|
||||||
|
error_text,
|
||||||
|
error_code=error_code,
|
||||||
|
extra=extra,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_interrupt_request(
|
||||||
|
self,
|
||||||
|
interrupt_value: dict[str, Any],
|
||||||
|
*,
|
||||||
|
interrupt_id: str | None = None,
|
||||||
|
pending_interrupt_count: int | None = None,
|
||||||
|
chat_turn_id: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return interrupt.format_interrupt_request(
|
||||||
|
interrupt_value,
|
||||||
|
interrupt_id=interrupt_id,
|
||||||
|
pending_interrupt_count=pending_interrupt_count,
|
||||||
|
chat_turn_id=chat_turn_id,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_subagent_start(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
subagent_run_id: str,
|
||||||
|
subagent_type: str,
|
||||||
|
parent_tool_call_id: str,
|
||||||
|
chat_turn_id: str | None = None,
|
||||||
|
description: str | None = None,
|
||||||
|
started_at: str | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return subagent_lifecycle.format_subagent_start(
|
||||||
|
subagent_run_id=subagent_run_id,
|
||||||
|
subagent_type=subagent_type,
|
||||||
|
parent_tool_call_id=parent_tool_call_id,
|
||||||
|
chat_turn_id=chat_turn_id,
|
||||||
|
description=description,
|
||||||
|
started_at=started_at,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_subagent_finish(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
subagent_run_id: str,
|
||||||
|
subagent_type: str,
|
||||||
|
parent_tool_call_id: str,
|
||||||
|
status: str = "completed",
|
||||||
|
ended_at: str | None = None,
|
||||||
|
duration_ms: int | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return subagent_lifecycle.format_subagent_finish(
|
||||||
|
subagent_run_id=subagent_run_id,
|
||||||
|
subagent_type=subagent_type,
|
||||||
|
parent_tool_call_id=parent_tool_call_id,
|
||||||
|
status=status,
|
||||||
|
ended_at=ended_at,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_subagent_error(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
subagent_run_id: str,
|
||||||
|
subagent_type: str,
|
||||||
|
parent_tool_call_id: str,
|
||||||
|
error_text: str,
|
||||||
|
error_type: str | None = None,
|
||||||
|
ended_at: str | None = None,
|
||||||
|
duration_ms: int | None = None,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return subagent_lifecycle.format_subagent_error(
|
||||||
|
subagent_run_id=subagent_run_id,
|
||||||
|
subagent_type=subagent_type,
|
||||||
|
parent_tool_call_id=parent_tool_call_id,
|
||||||
|
error_text=error_text,
|
||||||
|
error_type=error_type,
|
||||||
|
ended_at=ended_at,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
emitter=emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_action_log(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
*,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return action_log.format_action_log(payload, emitter=emitter)
|
||||||
|
|
||||||
|
def format_action_log_updated(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
*,
|
||||||
|
emitter: Emitter | None = None,
|
||||||
|
) -> str:
|
||||||
|
return action_log.format_action_log_updated(payload, emitter=emitter)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue