diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index f6f0c7f62..543524456 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -490,6 +490,12 @@ class Config:
     ENABLE_DESKTOP_LOCAL_FILESYSTEM = (
         os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE"
     )
+    # Streaming entrypoint switch. Keep this at the route layer so orchestrator
+    # code stays free of legacy fallback branching.
+    ENABLE_CHAT_STREAM_ORCHESTRATOR = (
+        os.getenv("SURFSENSE_ENABLE_CHAT_STREAM_ORCHESTRATOR", "TRUE").upper()
+        == "TRUE"
+    )
 
     @classmethod
     def is_self_hosted(cls) -> bool:
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index ad96654f5..7f035daef 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -71,7 +71,15 @@ from app.schemas.new_chat import (
     TokenUsageSummary,
     TurnStatusResponse,
 )
-from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
+from app.tasks.chat.stream_new_chat import (
+    stream_new_chat as legacy_stream_new_chat,
+    stream_resume_chat as legacy_stream_resume_chat,
+)
+from app.tasks.chat.streaming.orchestrator import (
+    stream_chat,
+    stream_regenerate,
+    stream_resume,
+)
 from app.users import current_active_user
 from app.utils.perf import get_perf_logger
 from app.utils.rbac import check_permission
@@ -90,6 +98,10 @@ TURN_CANCELLING_MAX_DELAY_MS = 1500
 router = APIRouter()
 
 
+def _use_streaming_orchestrator() -> bool:
+    return config.ENABLE_CHAT_STREAM_ORCHESTRATOR
+
+
 def _resolve_filesystem_selection(
     *,
     mode: str,
@@ -1770,7 +1782,11 @@ async def handle_new_chat(
             )
 
         return StreamingResponse(
-            stream_new_chat(
+            (
+                stream_chat
+                if _use_streaming_orchestrator()
+                else legacy_stream_new_chat
+            )(
                 user_query=request.user_query,
                 search_space_id=request.search_space_id,
                 chat_id=request.chat_id,
@@ -2255,7 +2271,12 @@ async def regenerate_response(
                 else None
             )
             try:
-                async for chunk in stream_new_chat(
+                regenerate_fn = (
+                    stream_regenerate
+                    if _use_streaming_orchestrator()
+                    else legacy_stream_new_chat
+                )
+                async for chunk in regenerate_fn(
                     user_query=str(user_query_to_use),
                     search_space_id=request.search_space_id,
                     chat_id=thread_id,
@@ -2387,7 +2408,11 @@ async def resume_chat(
         await session.close()
 
         return StreamingResponse(
-            stream_resume_chat(
+            (
+                stream_resume
+                if _use_streaming_orchestrator()
+                else legacy_stream_resume_chat
+            )(
                 chat_id=thread_id,
                 search_space_id=request.search_space_id,
                 decisions=decisions,
diff --git a/surfsense_backend/app/tasks/chat/streaming/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/orchestrator.py
--- a/surfsense_backend/app/tasks/chat/streaming/orchestrator.py
+++ b/surfsense_backend/app/tasks/chat/streaming/orchestrator.py
@@ -1,48 +1,154 @@
-"""Top-level chat streaming entrypoints (stubs until wired)."""
+"""Top-level chat streaming entrypoints.
+
+For now these orchestrator functions are thin compatibility wrappers around the
+current ``stream_new_chat`` / ``stream_resume_chat`` implementations. Routing
+calls through this module lets us cut over to the fully modular event relay in
+one place later without touching API routes again.
+
+Each wrapper iterates the wrapped generator inside ``contextlib.aclosing`` so
+that closing the wrapper early also closes the wrapped generator right away,
+instead of leaving its cleanup to garbage collection.
+"""
 
 from __future__ import annotations
 
 from collections.abc import AsyncGenerator
-from typing import Any
+from contextlib import aclosing
+from typing import Any, Literal
+
+from app.agents.new_chat.filesystem_selection import FilesystemSelection
+from app.db import ChatVisibility
+from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
 
 
 async def stream_chat(
     *,
-    request: Any,
-    user: Any,
-    db_session: Any,
-) -> AsyncGenerator[str, None]:  # pragma: no cover - orchestrator port in progress
-    del request, user, db_session
-    raise NotImplementedError(
-        "stream_chat: orchestrator not wired yet"
-    )
-    if False:  # pragma: no cover
-        yield ""
+    user_query: str,
+    search_space_id: int,
+    chat_id: int,
+    user_id: str | None = None,
+    llm_config_id: int = -1,
+    mentioned_document_ids: list[int] | None = None,
+    mentioned_surfsense_doc_ids: list[int] | None = None,
+    mentioned_documents: list[dict[str, Any]] | None = None,
+    checkpoint_id: str | None = None,
+    needs_history_bootstrap: bool = False,
+    thread_visibility: ChatVisibility | None = None,
+    current_user_display_name: str | None = None,
+    disabled_tools: list[str] | None = None,
+    filesystem_selection: FilesystemSelection | None = None,
+    request_id: str | None = None,
+    user_image_data_urls: list[str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Stream a new chat turn through the current production pipeline.
+
+    Every argument is forwarded unchanged to ``stream_new_chat`` and each chunk
+    it produces is re-yielded as-is.
+    """
+    async with aclosing(
+        stream_new_chat(
+            user_query=user_query,
+            search_space_id=search_space_id,
+            chat_id=chat_id,
+            user_id=user_id,
+            llm_config_id=llm_config_id,
+            mentioned_document_ids=mentioned_document_ids,
+            mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
+            mentioned_documents=mentioned_documents,
+            checkpoint_id=checkpoint_id,
+            needs_history_bootstrap=needs_history_bootstrap,
+            thread_visibility=thread_visibility,
+            current_user_display_name=current_user_display_name,
+            disabled_tools=disabled_tools,
+            filesystem_selection=filesystem_selection,
+            request_id=request_id,
+            user_image_data_urls=user_image_data_urls,
+        )
+    ) as chunks:
+        async for chunk in chunks:
+            yield chunk
 
 
 async def stream_resume(
     *,
-    request: Any,
-    user: Any,
-    db_session: Any,
-) -> AsyncGenerator[str, None]:  # pragma: no cover - orchestrator port in progress
-    del request, user, db_session
-    raise NotImplementedError(
-        "stream_resume: orchestrator not wired yet"
-    )
-    if False:  # pragma: no cover
-        yield ""
+    chat_id: int,
+    search_space_id: int,
+    decisions: list[dict],
+    user_id: str | None = None,
+    llm_config_id: int = -1,
+    thread_visibility: ChatVisibility | None = None,
+    filesystem_selection: FilesystemSelection | None = None,
+    request_id: str | None = None,
+    disabled_tools: list[str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Resume an interrupted chat turn through the current production pipeline.
+
+    Every argument is forwarded unchanged to ``stream_resume_chat`` and each
+    chunk it produces is re-yielded as-is.
+    """
+    async with aclosing(
+        stream_resume_chat(
+            chat_id=chat_id,
+            search_space_id=search_space_id,
+            decisions=decisions,
+            user_id=user_id,
+            llm_config_id=llm_config_id,
+            thread_visibility=thread_visibility,
+            filesystem_selection=filesystem_selection,
+            request_id=request_id,
+            disabled_tools=disabled_tools,
+        )
+    ) as chunks:
+        async for chunk in chunks:
+            yield chunk
 
 
 async def stream_regenerate(
     *,
-    request: Any,
-    user: Any,
-    db_session: Any,
-) -> AsyncGenerator[str, None]:  # pragma: no cover - orchestrator port in progress
-    del request, user, db_session
-    raise NotImplementedError(
-        "stream_regenerate: orchestrator not wired yet"
-    )
-    if False:  # pragma: no cover
-        yield ""
+    user_query: str,
+    search_space_id: int,
+    chat_id: int,
+    user_id: str | None = None,
+    llm_config_id: int = -1,
+    mentioned_document_ids: list[int] | None = None,
+    mentioned_surfsense_doc_ids: list[int] | None = None,
+    mentioned_documents: list[dict[str, Any]] | None = None,
+    checkpoint_id: str | None = None,
+    needs_history_bootstrap: bool = False,
+    thread_visibility: ChatVisibility | None = None,
+    current_user_display_name: str | None = None,
+    disabled_tools: list[str] | None = None,
+    filesystem_selection: FilesystemSelection | None = None,
+    request_id: str | None = None,
+    user_image_data_urls: list[str] | None = None,
+    flow: Literal["new", "regenerate"] = "regenerate",
+) -> AsyncGenerator[str, None]:
+    """Regenerate an assistant turn through the current production pipeline.
+
+    Every argument, including ``flow`` (``"regenerate"`` unless overridden), is
+    forwarded unchanged to ``stream_new_chat`` and each chunk it produces is
+    re-yielded as-is.
+    """
+    async with aclosing(
+        stream_new_chat(
+            user_query=user_query,
+            search_space_id=search_space_id,
+            chat_id=chat_id,
+            user_id=user_id,
+            llm_config_id=llm_config_id,
+            mentioned_document_ids=mentioned_document_ids,
+            mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
+            mentioned_documents=mentioned_documents,
+            checkpoint_id=checkpoint_id,
+            needs_history_bootstrap=needs_history_bootstrap,
+            thread_visibility=thread_visibility,
+            current_user_display_name=current_user_display_name,
+            disabled_tools=disabled_tools,
+            filesystem_selection=filesystem_selection,
+            request_id=request_id,
+            user_image_data_urls=user_image_data_urls,
+            flow=flow,
+        )
+    ) as chunks:
+        async for chunk in chunks:
+            yield chunk