diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py index ac7abc6f4..1e32e7f5a 100644 --- a/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/orchestration/orchestrator.py @@ -101,8 +101,38 @@ async def stream_resume( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, disabled_tools: list[str] | None = None, + orchestration_input: StreamExecutionInput | None = None, ) -> AsyncGenerator[str, None]: """Resume an interrupted chat turn through the current production pipeline.""" + if orchestration_input is not None: + result = StreamOutput( + request_id=request_id, + turn_id=f"{chat_id}:orchestrator-resume", + filesystem_mode=( + filesystem_selection.mode.value if filesystem_selection else "cloud" + ), + client_platform=( + filesystem_selection.client_platform.value + if filesystem_selection + else "web" + ), + ) + async for frame in stream_agent_events( + agent=orchestration_input.agent, + config=orchestration_input.config, + input_data=orchestration_input.input_data, + streaming_service=orchestration_input.streaming_service, + result=result, + step_prefix=orchestration_input.step_prefix, + initial_step_id=orchestration_input.initial_step_id, + initial_step_title=orchestration_input.initial_step_title, + initial_step_items=orchestration_input.initial_step_items, + content_builder=orchestration_input.content_builder, + runtime_context=orchestration_input.runtime_context, + ): + yield frame + return + async for chunk in stream_resume_chat( chat_id=chat_id, search_space_id=search_space_id, @@ -136,8 +166,38 @@ async def stream_regenerate( request_id: str | None = None, user_image_data_urls: list[str] | None = None, flow: Literal["new", "regenerate"] = "regenerate", + orchestration_input: StreamExecutionInput | None = None, ) -> AsyncGenerator[str, None]: """Regenerate an assistant turn through the current production pipeline.""" + if orchestration_input is not None: + result = StreamOutput( + request_id=request_id, + turn_id=f"{chat_id}:orchestrator-regenerate", + filesystem_mode=( + filesystem_selection.mode.value if filesystem_selection else "cloud" + ), + client_platform=( + filesystem_selection.client_platform.value + if filesystem_selection + else "web" + ), + ) + async for frame in stream_agent_events( + agent=orchestration_input.agent, + config=orchestration_input.config, + input_data=orchestration_input.input_data, + streaming_service=orchestration_input.streaming_service, + result=result, + step_prefix=orchestration_input.step_prefix, + initial_step_id=orchestration_input.initial_step_id, + initial_step_title=orchestration_input.initial_step_title, + initial_step_items=orchestration_input.initial_step_items, + content_builder=orchestration_input.content_builder, + runtime_context=orchestration_input.runtime_context, + ): + yield frame + return + async for chunk in stream_new_chat( user_query=user_query, search_space_id=search_space_id, diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py index e0a1877a8..bd154e6a0 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestration_event_stream.py @@ -84,15 +84,20 @@ async def test_stream_agent_events_emits_text_lifecycle_and_updates_result() -> ] assert result.accumulated_text == "Hello world" assert result.agent_called_update_memory is False - assert agent.calls[0][1]["version"] == "v2" async def test_stream_agent_events_passes_runtime_context_to_agent() -> None: service = _StreamingService() - agent = _Agent([{"event": "on_chat_model_stream", "data": {"chunk": _Chunk("x")}}]) + class _ContextAwareAgent: + async def astream_events(self, input_data: Any, **kwargs: Any): + del input_data + text = "ctx-ok" if kwargs.get("context") else "ctx-missing" + yield {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(text)}} + + agent = _ContextAwareAgent() result = StreamOutput() - _ = await _collect( + frames = await _collect( stream_agent_events( agent=agent, config={"configurable": {"thread_id": "t-2"}}, @@ -103,5 +108,8 @@ async def test_stream_agent_events_passes_runtime_context_to_agent() -> None: ) ) - assert agent.calls - assert agent.calls[0][1]["context"] == {"mentioned_document_ids": [1, 2]} + assert frames == [ + "text_start:text-1", + "text_delta:text-1:ctx-ok", + "text_end:text-1", + ] diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py index cf54fdab0..d9cd7951f 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_stream_chat.py @@ -8,7 +8,11 @@ from typing import Any import pytest from app.tasks.chat.streaming.orchestration import StreamExecutionInput -from app.tasks.chat.streaming.orchestration.orchestrator import stream_chat +from app.tasks.chat.streaming.orchestration.orchestrator import ( + stream_chat, + stream_regenerate, + stream_resume, +) pytestmark = pytest.mark.unit @@ -84,5 +88,53 @@ async def test_stream_chat_uses_orchestration_input_path() -> None: "text_delta:text-1:!", "text_end:text-1", ] - assert agent.calls - assert agent.calls[0][1]["version"] == "v2" + + +async def test_stream_resume_uses_orchestration_input_path() -> None: + service = _StreamingService() + agent = _Agent([{"event": "on_chat_model_stream", "data": {"chunk": _Chunk("r")}}]) + + frames = await _collect( + stream_resume( + chat_id=9, + search_space_id=1, + decisions=[], + orchestration_input=StreamExecutionInput( + agent=agent, + config={"configurable": {"thread_id": "thread-r"}}, + input_data={"messages": []}, + streaming_service=service, + ), + ) + ) + + assert frames == [ + "text_start:text-1", + "text_delta:text-1:r", + "text_end:text-1", + ] + + +async def test_stream_regenerate_uses_orchestration_input_path() -> None: + service = _StreamingService() + agent = _Agent([{"event": "on_chat_model_stream", "data": {"chunk": _Chunk("g")}}]) + + frames = await _collect( + stream_regenerate( + user_query="q", + search_space_id=1, + chat_id=2, + orchestration_input=StreamExecutionInput( + agent=agent, + config={"configurable": {"thread_id": "thread-g"}}, + input_data={"messages": []}, + streaming_service=service, + ), + ) + ) + + assert frames == [ + "text_start:text-1", + "text_delta:text-1:g", + "text_end:text-1", + ]