From a04b2e88bdaeea624c0d192b3259bbd3482bc717 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 7 May 2026 17:06:17 +0200 Subject: [PATCH] wire orchestrator streaming context path and align event relay outputs --- .../chat/streaming/orchestration/__init__.py | 12 +- .../streaming/orchestration/event_stream.py | 6 +- .../chat/streaming/orchestration/input.py | 4 +- .../streaming/orchestration/orchestrator.py | 142 ++++++++---------- .../chat/streaming/orchestration/output.py | 5 +- .../tasks/chat/streaming/relay/event_relay.py | 4 +- .../test_orchestration_event_stream.py | 16 +- .../test_orchestrator_stream_chat.py | 14 +- 8 files changed, 94 insertions(+), 109 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestration/__init__.py b/surfsense_backend/app/tasks/chat/streaming/orchestration/__init__.py index 6f683a410..b1a201fd3 100644 --- a/surfsense_backend/app/tasks/chat/streaming/orchestration/__init__.py +++ b/surfsense_backend/app/tasks/chat/streaming/orchestration/__init__.py @@ -1,11 +1,11 @@ """Composable orchestration pieces for chat streaming.""" -from app.tasks.chat.streaming.orchestration.event_stream import stream_agent_events -from app.tasks.chat.streaming.orchestration.input import StreamExecutionInput -from app.tasks.chat.streaming.orchestration.output import StreamOutput +from app.tasks.chat.streaming.orchestration.event_stream import stream_output +from app.tasks.chat.streaming.orchestration.input import StreamingContext +from app.tasks.chat.streaming.orchestration.output import StreamingResult __all__ = [ - "StreamExecutionInput", - "StreamOutput", - "stream_agent_events", + "StreamingContext", + "StreamingResult", + "stream_output", ] diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestration/event_stream.py b/surfsense_backend/app/tasks/chat/streaming/orchestration/event_stream.py index 369883c3a..fc8c13027 100644 --- a/surfsense_backend/app/tasks/chat/streaming/orchestration/event_stream.py +++ b/surfsense_backend/app/tasks/chat/streaming/orchestration/event_stream.py @@ -6,18 +6,18 @@ from collections.abc import AsyncIterator from typing import Any from app.agents.new_chat.feature_flags import get_flags -from app.tasks.chat.streaming.orchestration.output import StreamOutput +from app.tasks.chat.streaming.orchestration.output import StreamingResult from app.tasks.chat.streaming.relay.event_relay import EventRelay from app.tasks.chat.streaming.relay.state import AgentEventRelayState -async def stream_agent_events( +async def stream_output( *, agent: Any, config: dict[str, Any], input_data: Any, streaming_service: Any, - result: StreamOutput, + result: StreamingResult, step_prefix: str = "thinking", initial_step_id: str | None = None, initial_step_title: str = "", diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestration/input.py b/surfsense_backend/app/tasks/chat/streaming/orchestration/input.py index 13d43b612..45a33d435 100644 --- a/surfsense_backend/app/tasks/chat/streaming/orchestration/input.py +++ b/surfsense_backend/app/tasks/chat/streaming/orchestration/input.py @@ -7,8 +7,8 @@ from typing import Any @dataclass(frozen=True) -class StreamExecutionInput: - """Container for dependencies required by ``stream_agent_events``.""" +class StreamingContext: + """Container for dependencies required by ``stream_output``.""" agent: Any config: dict[str, Any] diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py index 1e32e7f5a..b40083f42 100644 --- a/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py @@ -1,9 +1,4 @@ """Top-level chat streaming entrypoints. - -For now these orchestrator functions are thin compatibility wrappers around the -current ``stream_new_chat`` / ``stream_resume_chat`` implementations. Routing -calls through this module lets us cut over to the fully modular event relay in -one place later without touching API routes again. """ from __future__ import annotations @@ -14,9 +9,47 @@ from typing import Any, Literal from app.agents.new_chat.filesystem_selection import FilesystemSelection from app.db import ChatVisibility from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat -from app.tasks.chat.streaming.orchestration.event_stream import stream_agent_events -from app.tasks.chat.streaming.orchestration.input import StreamExecutionInput -from app.tasks.chat.streaming.orchestration.output import StreamOutput +from app.tasks.chat.streaming.orchestration.event_stream import stream_output +from app.tasks.chat.streaming.orchestration.input import StreamingContext +from app.tasks.chat.streaming.orchestration.output import StreamingResult + + +def _build_streaming_result( + *, + chat_id: int, + request_id: str | None, + filesystem_selection: FilesystemSelection | None, + suffix: str, +) -> StreamingResult: + return StreamingResult( + request_id=request_id, + turn_id=f"{chat_id}:{suffix}", + filesystem_mode=(filesystem_selection.mode.value if filesystem_selection else "cloud"), + client_platform=( + filesystem_selection.client_platform.value if filesystem_selection else "web" + ), + ) + + +async def _stream_output_with_streaming_context( + *, + streaming_context: StreamingContext, + result: StreamingResult, +) -> AsyncGenerator[str, None]: + async for frame in stream_output( + agent=streaming_context.agent, + config=streaming_context.config, + input_data=streaming_context.input_data, + streaming_service=streaming_context.streaming_service, + result=result, + step_prefix=streaming_context.step_prefix, + initial_step_id=streaming_context.initial_step_id, + initial_step_title=streaming_context.initial_step_title, + initial_step_items=streaming_context.initial_step_items, + content_builder=streaming_context.content_builder, + runtime_context=streaming_context.runtime_context, + ): + yield frame async def stream_chat( @@ -37,34 +70,19 @@ async def stream_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, - orchestration_input: StreamExecutionInput | None = None, + streaming_context: StreamingContext | None = None, ) -> AsyncGenerator[str, None]: """Stream a new chat turn through the current production pipeline.""" - if orchestration_input is not None: - result = StreamOutput( + if streaming_context is not None: + result = _build_streaming_result( + chat_id=chat_id, request_id=request_id, - turn_id=f"{chat_id}:orchestrator", - filesystem_mode=( - filesystem_selection.mode.value if filesystem_selection else "cloud" - ), - client_platform=( - filesystem_selection.client_platform.value - if filesystem_selection - else "web" - ), + filesystem_selection=filesystem_selection, + suffix="orchestrator", ) - async for frame in stream_agent_events( - agent=orchestration_input.agent, - config=orchestration_input.config, - input_data=orchestration_input.input_data, - streaming_service=orchestration_input.streaming_service, + async for frame in _stream_output_with_streaming_context( + streaming_context=streaming_context, result=result, - step_prefix=orchestration_input.step_prefix, - initial_step_id=orchestration_input.initial_step_id, - initial_step_title=orchestration_input.initial_step_title, - initial_step_items=orchestration_input.initial_step_items, - content_builder=orchestration_input.content_builder, - runtime_context=orchestration_input.runtime_context, ): yield frame return @@ -101,34 +119,19 @@ async def stream_resume( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, disabled_tools: list[str] | None = None, - orchestration_input: StreamExecutionInput | None = None, + streaming_context: StreamingContext | None = None, ) -> AsyncGenerator[str, None]: """Resume an interrupted chat turn through the current production pipeline.""" - if orchestration_input is not None: - result = StreamOutput( + if streaming_context is not None: + result = _build_streaming_result( + chat_id=chat_id, request_id=request_id, - turn_id=f"{chat_id}:orchestrator-resume", - filesystem_mode=( - filesystem_selection.mode.value if filesystem_selection else "cloud" - ), - client_platform=( - filesystem_selection.client_platform.value - if filesystem_selection - else "web" - ), + filesystem_selection=filesystem_selection, + suffix="orchestrator-resume", ) - async for frame in stream_agent_events( - agent=orchestration_input.agent, - config=orchestration_input.config, - input_data=orchestration_input.input_data, - streaming_service=orchestration_input.streaming_service, + async for frame in _stream_output_with_streaming_context( + streaming_context=streaming_context, result=result, - step_prefix=orchestration_input.step_prefix, - initial_step_id=orchestration_input.initial_step_id, - initial_step_title=orchestration_input.initial_step_title, - initial_step_items=orchestration_input.initial_step_items, - content_builder=orchestration_input.content_builder, - runtime_context=orchestration_input.runtime_context, ): yield frame return @@ -166,34 +169,19 @@ async def stream_regenerate( request_id: str | None = None, user_image_data_urls: list[str] | None = None, flow: Literal["new", "regenerate"] = "regenerate", - orchestration_input: StreamExecutionInput | None = None, + streaming_context: StreamingContext | None = None, ) -> AsyncGenerator[str, None]: """Regenerate an assistant turn through the current production pipeline.""" - if orchestration_input is not None: - result = StreamOutput( + if streaming_context is not None: + result = _build_streaming_result( + chat_id=chat_id, request_id=request_id, - turn_id=f"{chat_id}:orchestrator-regenerate", - filesystem_mode=( - filesystem_selection.mode.value if filesystem_selection else "cloud" - ), - client_platform=( - filesystem_selection.client_platform.value - if filesystem_selection - else "web" - ), + filesystem_selection=filesystem_selection, + suffix="orchestrator-regenerate", ) - async for frame in stream_agent_events( - agent=orchestration_input.agent, - config=orchestration_input.config, - input_data=orchestration_input.input_data, - streaming_service=orchestration_input.streaming_service, + async for frame in _stream_output_with_streaming_context( + streaming_context=streaming_context, result=result, - step_prefix=orchestration_input.step_prefix, - initial_step_id=orchestration_input.initial_step_id, - initial_step_title=orchestration_input.initial_step_title, - initial_step_items=orchestration_input.initial_step_items, - content_builder=orchestration_input.content_builder, - runtime_context=orchestration_input.runtime_context, ): yield frame return diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestration/output.py b/surfsense_backend/app/tasks/chat/streaming/orchestration/output.py index 0c4870ec4..60f8ee6ee 100644 --- a/surfsense_backend/app/tasks/chat/streaming/orchestration/output.py +++ b/surfsense_backend/app/tasks/chat/streaming/orchestration/output.py @@ -7,7 +7,7 @@ from typing import Any @dataclass -class StreamOutput: +class StreamingResult: accumulated_text: str = "" is_interrupted: bool = False interrupt_value: dict[str, Any] | None = None @@ -27,6 +27,3 @@ class StreamOutput: assistant_message_id: int | None = None content_builder: Any | None = field(default=None, repr=False) - -# Backwards-compatible alias while imports migrate. -StreamResult = StreamOutput diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py b/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py index 072baac72..c8aebd99c 100644 --- a/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py +++ b/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py @@ -16,7 +16,7 @@ from app.tasks.chat.streaming.handlers.custom_event_dispatch import ( ) from app.tasks.chat.streaming.handlers.tool_end import iter_tool_end_frames from app.tasks.chat.streaming.handlers.tool_start import iter_tool_start_frames -from app.tasks.chat.streaming.orchestration.output import StreamOutput +from app.tasks.chat.streaming.orchestration.output import StreamingResult from app.tasks.chat.streaming.relay.state import AgentEventRelayState from app.tasks.chat.streaming.relay.thinking_step_completion import ( complete_active_thinking_step, @@ -52,7 +52,7 @@ class EventRelay: events: AsyncIterator[dict[str, Any]], *, state: AgentEventRelayState, - result: StreamOutput, + result: StreamingResult, step_prefix: str = "thinking", content_builder: Any | None = None, config: dict[str, Any] | None = None, diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py index bd154e6a0..b17d82293 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py @@ -7,8 +7,8 @@ from typing import Any import pytest -from app.tasks.chat.streaming.orchestration import stream_agent_events -from app.tasks.chat.streaming.orchestration.output import StreamOutput +from app.tasks.chat.streaming.orchestration import stream_output +from app.tasks.chat.streaming.orchestration.output import StreamingResult pytestmark = pytest.mark.unit @@ -56,7 +56,7 @@ async def _collect(stream: Any) -> list[str]: return out -async def test_stream_agent_events_emits_text_lifecycle_and_updates_result() -> None: +async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None: service = _StreamingService() agent = _Agent( [ @@ -64,10 +64,10 @@ async def test_stream_agent_events_emits_text_lifecycle_and_updates_result() -> {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}}, ] ) - result = StreamOutput() + result = StreamingResult() frames = await _collect( - stream_agent_events( + stream_output( agent=agent, config={"configurable": {"thread_id": "t-1"}}, input_data={"messages": []}, @@ -86,7 +86,7 @@ async def test_stream_agent_events_emits_text_lifecycle_and_updates_result() -> assert result.agent_called_update_memory is False -async def test_stream_agent_events_passes_runtime_context_to_agent() -> None: +async def test_stream_output_passes_runtime_context_to_agent() -> None: service = _StreamingService() class _ContextAwareAgent: async def astream_events(self, input_data: Any, **kwargs: Any): @@ -95,10 +95,10 @@ async def test_stream_agent_events_passes_runtime_context_to_agent() -> None: yield {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(text)}} agent = _ContextAwareAgent() - result = StreamOutput() + result = StreamingResult() frames = await _collect( - stream_agent_events( + stream_output( agent=agent, config={"configurable": {"thread_id": "t-2"}}, input_data={"messages": []}, diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py index d9cd7951f..b84193cb7 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py @@ -7,7 +7,7 @@ from typing import Any import pytest -from app.tasks.chat.streaming.orchestration import StreamExecutionInput +from app.tasks.chat.streaming.orchestration import StreamingContext from app.tasks.chat.streaming.orchestration.orchestrator import ( stream_chat, stream_regenerate, @@ -60,7 +60,7 @@ async def _collect(stream: Any) -> list[str]: return out -async def test_stream_chat_uses_orchestration_input_path() -> None: +async def test_stream_chat_uses_streaming_context_path() -> None: service = _StreamingService() agent = _Agent( [ @@ -73,7 +73,7 @@ async def test_stream_chat_uses_orchestration_input_path() -> None: user_query="ignored-here", search_space_id=1, chat_id=77, - orchestration_input=StreamExecutionInput( + streaming_context=StreamingContext( agent=agent, config={"configurable": {"thread_id": "thread-1"}}, input_data={"messages": []}, @@ -90,7 +90,7 @@ async def test_stream_chat_uses_orchestration_input_path() -> None: ] -async def test_stream_resume_uses_orchestration_input_path() -> None: +async def test_stream_resume_uses_streaming_context_path() -> None: service = _StreamingService() agent = _Agent([{"event": "on_chat_model_stream", "data": {"chunk": _Chunk("r")}}]) @@ -99,7 +99,7 @@ async def test_stream_resume_uses_orchestration_input_path() -> None: chat_id=9, search_space_id=1, decisions=[], - orchestration_input=StreamExecutionInput( + streaming_context=StreamingContext( agent=agent, config={"configurable": {"thread_id": "thread-r"}}, input_data={"messages": []}, @@ -115,7 +115,7 @@ async def test_stream_resume_uses_orchestration_input_path() -> None: ] -async def test_stream_regenerate_uses_orchestration_input_path() -> None: +async def test_stream_regenerate_uses_streaming_context_path() -> None: service = _StreamingService() agent = _Agent([{"event": "on_chat_model_stream", "data": {"chunk": _Chunk("g")}}]) @@ -124,7 +124,7 @@ async def test_stream_regenerate_uses_orchestration_input_path() -> None: user_query="q", search_space_id=1, chat_id=2, - orchestration_input=StreamExecutionInput( + streaming_context=StreamingContext( agent=agent, config={"configurable": {"thread_id": "thread-g"}}, input_data={"messages": []},