From 5b45f78a168dcd6024103dee3c176514a756da52 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 4 Jun 2026 14:35:45 +0200 Subject: [PATCH] refactor(chat): delete legacy stream_new_chat monolith (cutover complete) The flows orchestrators (new_chat/resume_chat) are now the sole live path after the byte-for-byte differential proof, so the monolith and its monolith-vs-flows parity scaffolding are removed. - Repoint the last live importer (anonymous_chat_routes) to streaming.agent.event_loop.stream_agent_events + shared.stream_result.StreamResult (drop-in; the keyword-only fallback-commit params default to inert for anon). - Repoint e2e launcher patch targets to flows.shared.llm_bundle. - Repoint helper unit tests (chunk_parts, thinking-step ids, tool-input streaming) to their flows homes to preserve coverage. - Delete the monolith, the contract test, and the parity tests (parallel_refactor, stage_1, stage_2, orchestrator_frame) whose sole purpose was comparing against the now-removed monolith. Full suite green (2622 passed, 1 skipped); the two excluded live-app dirs (document_upload, composio) have a pre-existing, env-gated registration 404 unrelated to this change. --- .../app/routes/anonymous_chat_routes.py | 5 +- .../app/tasks/chat/stream_new_chat.py | 3050 ----------------- surfsense_backend/tests/e2e/run_backend.py | 4 +- surfsense_backend/tests/e2e/run_celery.py | 4 +- .../test_orchestrator_frame_parity.py | 457 --- .../test_parallel_refactor_parity.py | 584 ---- .../chat/streaming/test_stage_1_parity.py | 240 -- .../chat/streaming/test_stage_2_parity.py | 241 -- .../tasks/chat/test_extract_chunk_parts.py | 6 +- .../chat/test_thinking_step_id_uniqueness.py | 12 +- .../tasks/chat/test_tool_input_streaming.py | 12 +- .../unit/test_stream_new_chat_contract.py | 438 --- 12 files changed, 25 insertions(+), 5028 deletions(-) delete mode 100644 surfsense_backend/app/tasks/chat/stream_new_chat.py delete mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py delete mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py delete mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py delete mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py delete mode 100644 surfsense_backend/tests/unit/test_stream_new_chat_contract.py diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index 1a283ef29..bf71a0348 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -356,7 +356,8 @@ async def stream_anonymous_chat( from app.db import shielded_async_session from app.services.new_streaming_service import VercelStreamingService from app.services.token_tracking_service import start_turn - from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events + from app.tasks.chat.streaming.agent.event_loop import stream_agent_events + from app.tasks.chat.streaming.shared.stream_result import StreamResult accumulator = start_turn() streaming_service = VercelStreamingService() @@ -419,7 +420,7 @@ async def stream_anonymous_chat( stream_result = StreamResult() - async for sse in _stream_agent_events( + async for sse in stream_agent_events( agent=agent, config=langgraph_config, input_data=input_state, diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py deleted file mode 100644 index cec13204f..000000000 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ /dev/null @@ -1,3050 +0,0 @@ -""" -Streaming task for the new SurfSense deep agent chat. - -This module streams responses from the deep agent using the Vercel AI SDK -Data Stream Protocol (SSE format). - -Supports loading LLM configurations from: -- YAML files (negative IDs for global configs) -- NewLLMConfig database table (positive IDs for user-created configs with prompt settings) -""" - -import asyncio -import contextlib -import gc -import json -import logging -import sys -import time -from collections.abc import AsyncGenerator -from dataclasses import dataclass, field -from functools import partial -from typing import Any, Literal -from uuid import UUID - -import anyio -from langchain_core.messages import HumanMessage -from sqlalchemy.future import select - -from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent -from app.agents.shared.checkpointer import get_checkpointer -from app.agents.shared.context import SurfSenseContextSchema -from app.agents.shared.errors import BusyError -from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.shared.llm_config import ( - AgentConfig, - create_chat_litellm_from_agent_config, - create_chat_litellm_from_config, - load_agent_config, - load_global_llm_config_by_id, -) -from app.agents.shared.mention_resolver import resolve_mentions, substitute_in_text -from app.agents.shared.middleware.busy_mutex import ( - end_turn, - get_cancel_state, - is_cancel_requested, -) -from app.agents.shared.middleware.kb_persistence import ( - commit_staged_filesystem_state, -) -from app.db import ( - ChatVisibility, - NewChatMessage, - NewChatThread, - Report, - SearchSourceConnectorType, - async_session_maker, - shielded_async_session, -) -from app.observability import metrics as ot_metrics, otel as ot -from app.prompts import TITLE_GENERATION_PROMPT -from app.services.auto_model_pin_service import ( - mark_runtime_cooldown, - resolve_or_get_pinned_llm_config_id, -) -from app.services.chat_session_state_service import ( - clear_ai_responding, - set_ai_responding, -) -from app.services.connector_service import ConnectorService -from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.streaming.graph_stream.event_stream import stream_output -from app.tasks.chat.streaming.helpers.interrupt_inspector import ( - all_interrupt_values, -) -from app.utils.content_utils import bootstrap_history_from_db -from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap -from app.utils.user_message_multimodal import build_human_message_content - -_background_tasks: set[asyncio.Task] = set() -_perf_log = get_perf_logger() -logger = logging.getLogger(__name__) - -TURN_CANCELLING_INITIAL_DELAY_MS = 200 -TURN_CANCELLING_BACKOFF_FACTOR = 2 -TURN_CANCELLING_MAX_DELAY_MS = 1500 - - -def _resume_step_prefix(turn_id: str) -> str: - """Build the per-turn ``step_prefix`` for a resume invocation. - - Each ``_stream_agent_events`` call constructs a fresh - :class:`AgentEventRelayState` with ``thinking_step_counter=0``, so two - consecutive resume turns would otherwise both emit ``thinking-resume-1``, - ``-2`` etc. The frontend rehydrates ``currentThinkingSteps`` from the - immediate prior assistant message at the start of every resume — if the - new stream's IDs collide with the seeded ones, React renders sibling - Timeline rows with the same key. Salting with ``turn_id`` guarantees - disjoint IDs across resumes within one thread. - """ - return f"thinking-resume-{turn_id}" - - -def _compute_turn_cancelling_retry_delay(attempt: int) -> int: - if attempt < 1: - attempt = 1 - delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( - TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) - ) - return min(delay, TURN_CANCELLING_MAX_DELAY_MS) - - -def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: - """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. - - Returns a dict with three keys: - - * ``text`` — concatenated string content (empty string if the chunk - contributes none). - * ``reasoning`` — concatenated reasoning content (empty string if the - chunk contributes none). - * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk`` - dicts surfaced from either the typed-block list or the - ``tool_call_chunks`` attribute. - - Background - ---------- - ``AIMessageChunk.content`` can be: - - * a ``str`` (most providers), or - * a ``list`` of typed blocks ``{type: 'text' | 'reasoning' | - 'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for - Anthropic, Bedrock, and several reasoning configurations. - - Reasoning may also live under - ``chunk.additional_kwargs['reasoning_content']`` (some providers - surface it that way instead of as a typed block). Tool-call chunks - may live under ``chunk.tool_call_chunks`` even when ``content`` is a - plain string. - - Earlier versions only handled the ``isinstance(content, str)`` branch - and silently dropped reasoning blocks + tool-call chunks emitted by - LangChain ``AIMessageChunk``s. - """ - out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} - if chunk is None: - return out - - content = getattr(chunk, "content", None) - if isinstance(content, str): - if content: - out["text"] = content - elif isinstance(content, list): - text_parts: list[str] = [] - reasoning_parts: list[str] = [] - for block in content: - if not isinstance(block, dict): - continue - block_type = block.get("type") - if block_type == "text": - value = block.get("text") or block.get("content") or "" - if isinstance(value, str) and value: - text_parts.append(value) - elif block_type == "reasoning": - value = ( - block.get("reasoning") - or block.get("text") - or block.get("content") - or "" - ) - if isinstance(value, str) and value: - reasoning_parts.append(value) - elif block_type in ("tool_call_chunk", "tool_use"): - out["tool_call_chunks"].append(block) - if text_parts: - out["text"] = "".join(text_parts) - if reasoning_parts: - out["reasoning"] = "".join(reasoning_parts) - - additional = getattr(chunk, "additional_kwargs", None) or {} - if isinstance(additional, dict): - extra_reasoning = additional.get("reasoning_content") - if isinstance(extra_reasoning, str) and extra_reasoning: - existing = out["reasoning"] - out["reasoning"] = ( - (existing + extra_reasoning) if existing else extra_reasoning - ) - - extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) - if isinstance(extra_tool_chunks, list): - for tcc in extra_tool_chunks: - if isinstance(tcc, dict): - out["tool_call_chunks"].append(tcc) - - return out - - -def extract_todos_from_deepagents(command_output) -> dict: - """ - Extract todos from deepagents' TodoListMiddleware Command output. - - deepagents returns a Command object with: - - Command.update['todos'] = [{'content': '...', 'status': '...'}] - - Returns the todos directly (no transformation needed - UI matches deepagents format). - """ - todos_data = [] - if hasattr(command_output, "update"): - # It's a Command object from deepagents - update = command_output.update - todos_data = update.get("todos", []) - elif isinstance(command_output, dict): - # Already a dict - check if it has todos directly or in update - if "todos" in command_output: - todos_data = command_output.get("todos", []) - elif "update" in command_output and isinstance(command_output["update"], dict): - todos_data = command_output["update"].get("todos", []) - - return {"todos": todos_data} - - -@dataclass -class StreamResult: - accumulated_text: str = "" - is_interrupted: bool = False - sandbox_files: list[str] = field(default_factory=list) - request_id: str | None = None - turn_id: str = "" - filesystem_mode: str = "cloud" - client_platform: str = "web" - intent_detected: str = "chat_only" - intent_confidence: float = 0.0 - write_attempted: bool = False - write_succeeded: bool = False - verification_succeeded: bool = False - commit_gate_passed: bool = True - commit_gate_reason: str = "" - # Pre-allocated assistant ``new_chat_messages.id`` for this turn, - # captured by ``persist_assistant_shell`` right after the user row is - # persisted. ``None`` for the legacy / anonymous code paths that don't - # opt in to server-side ``ContentPart[]`` projection. - assistant_message_id: int | None = None - # In-memory mirror of the FE's assistant-ui ``ContentPartsState``, - # populated by the lifecycle methods called from ``_stream_agent_events`` - # at each ``streaming_service.format_*`` yield site. Snapshot in the - # streaming ``finally`` to produce the rich JSONB persisted by - # ``finalize_assistant_turn``. ``repr=False`` keeps the - # log-on-error path (``StreamResult`` is logged in some error - # branches) from dumping a potentially-large parts list. - content_builder: Any | None = field(default=None, repr=False) - - -def _safe_float(value: Any, default: float = 0.0) -> float: - try: - return float(value) - except (TypeError, ValueError): - return default - - -def _tool_output_to_text(tool_output: Any) -> str: - if isinstance(tool_output, dict): - if isinstance(tool_output.get("result"), str): - return tool_output["result"] - if isinstance(tool_output.get("error"), str): - return tool_output["error"] - return json.dumps(tool_output, ensure_ascii=False) - return str(tool_output) - - -def _tool_output_has_error(tool_output: Any) -> bool: - if isinstance(tool_output, dict): - if tool_output.get("error"): - return True - result = tool_output.get("result") - return bool( - isinstance(result, str) and result.strip().lower().startswith("error:") - ) - if isinstance(tool_output, str): - return tool_output.strip().lower().startswith("error:") - return False - - -def _extract_resolved_file_path( - *, tool_name: str, tool_output: Any, tool_input: Any | None = None -) -> str | None: - if isinstance(tool_output, dict): - path_value = tool_output.get("path") - if isinstance(path_value, str) and path_value.strip(): - return path_value.strip() - if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): - file_path = tool_input.get("file_path") - if isinstance(file_path, str) and file_path.strip(): - return file_path.strip() - return None - - -def _contract_enforcement_active(result: StreamResult) -> bool: - # Keep policy deterministic with no env-driven progression modes: - # enforce the file-operation contract only in desktop local-folder mode. - return result.filesystem_mode == "desktop_local_folder" - - -def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]: - if result.intent_detected != "file_write": - return True, "" - if not result.write_attempted: - return False, "no_write_attempt" - if not result.write_succeeded: - return False, "write_failed" - if not result.verification_succeeded: - return False, "verification_failed" - return True, "" - - -def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: - payload: dict[str, Any] = { - "stage": stage, - "request_id": result.request_id or "unknown", - "turn_id": result.turn_id or "unknown", - "chat_id": result.turn_id.split(":", 1)[0] - if ":" in result.turn_id - else "unknown", - "filesystem_mode": result.filesystem_mode, - "client_platform": result.client_platform, - "intent_detected": result.intent_detected, - "intent_confidence": result.intent_confidence, - "write_attempted": result.write_attempted, - "write_succeeded": result.write_succeeded, - "verification_succeeded": result.verification_succeeded, - "commit_gate_passed": result.commit_gate_passed, - "commit_gate_reason": result.commit_gate_reason or None, - } - payload.update(extra) - _perf_log.info( - "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False) - ) - - -def _log_chat_stream_error( - *, - flow: Literal["new", "resume", "regenerate"], - error_kind: str, - error_code: str | None, - severity: Literal["info", "warn", "error"], - is_expected: bool, - request_id: str | None, - thread_id: int | None, - search_space_id: int | None, - user_id: str | None, - message: str, - extra: dict[str, Any] | None = None, -) -> None: - payload: dict[str, Any] = { - "event": "chat_stream_error", - "flow": flow, - "error_kind": error_kind, - "error_code": error_code, - "severity": severity, - "is_expected": is_expected, - "request_id": request_id or "unknown", - "thread_id": thread_id, - "search_space_id": search_space_id, - "user_id": user_id, - "message": message, - } - if extra: - payload.update(extra) - - logger = logging.getLogger(__name__) - rendered = json.dumps(payload, ensure_ascii=False) - if severity == "error": - logger.error("[chat_stream_error] %s", rendered) - elif severity == "warn": - logger.warning("[chat_stream_error] %s", rendered) - else: - logger.info("[chat_stream_error] %s", rendered) - - -def _parse_error_payload(message: str) -> dict[str, Any] | None: - candidates = [message] - first_brace_idx = message.find("{") - if first_brace_idx >= 0: - candidates.append(message[first_brace_idx:]) - - for candidate in candidates: - try: - parsed = json.loads(candidate) - if isinstance(parsed, dict): - return parsed - except Exception: - continue - return None - - -def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: - if not isinstance(parsed, dict): - return None - candidates: list[Any] = [parsed.get("code")] - nested = parsed.get("error") - if isinstance(nested, dict): - candidates.append(nested.get("code")) - for value in candidates: - try: - if value is None: - continue - return int(value) - except Exception: - continue - return None - - -def _is_provider_rate_limited(exc: BaseException) -> bool: - """Best-effort detection for provider-side runtime throttling. - - Covers LiteLLM/OpenRouter shapes like: - - class name contains ``RateLimit`` - - nested payload ``{"error": {"code": 429}}`` - - nested payload ``{"error": {"type": "rate_limit_error"}}`` - """ - raw = str(exc) - lowered = raw.lower() - if "ratelimit" in type(exc).__name__.lower(): - return True - parsed = _parse_error_payload(raw) - provider_code = _extract_provider_error_code(parsed) - if provider_code == 429: - return True - - provider_error_type = "" - if parsed: - top_type = parsed.get("type") - if isinstance(top_type, str): - provider_error_type = top_type.lower() - nested = parsed.get("error") - if isinstance(nested, dict): - nested_type = nested.get("type") - if isinstance(nested_type, str): - provider_error_type = nested_type.lower() - if provider_error_type == "rate_limit_error": - return True - - return ( - "rate limited" in lowered - or "rate-limited" in lowered - or "temporarily rate-limited upstream" in lowered - ) - - -async def _build_main_agent_for_thread( - agent_factory: Any, - *, - llm: Any, - search_space_id: int, - db_session: Any, - connector_service: ConnectorService, - checkpointer: Any, - user_id: str | None, - thread_id: int | None, - agent_config: AgentConfig | None, - firecrawl_api_key: str | None, - thread_visibility: ChatVisibility | None, - filesystem_selection: FilesystemSelection | None, - disabled_tools: list[str] | None = None, - mentioned_document_ids: list[int] | None = None, -) -> Any: - """Single (re)build path so the agent factory cannot drift across the - initial build and mid-stream 429 recovery for one ``thread_id``: a - graph swap mid-turn would corrupt checkpointer state.""" - return await agent_factory( - llm=llm, - search_space_id=search_space_id, - db_session=db_session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=thread_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=thread_visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - ) - - -def _classify_stream_exception( - exc: Exception, - *, - flow_label: str, -) -> tuple[ - str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None -]: - raw = str(exc) - if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: - busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None - if busy_thread_id and is_cancel_requested(busy_thread_id): - cancel_state = get_cancel_state(busy_thread_id) - attempt = cancel_state[0] if cancel_state else 1 - retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) - retry_after_at = int(time.time() * 1000) + retry_after_ms - return ( - "thread_busy", - "TURN_CANCELLING", - "info", - True, - "A previous response is still stopping. Please try again in a moment.", - { - "retry_after_ms": retry_after_ms, - "retry_after_at": retry_after_at, - }, - ) - return ( - "thread_busy", - "THREAD_BUSY", - "warn", - True, - "Another response is still finishing for this thread. Please try again in a moment.", - None, - ) - - if _is_provider_rate_limited(exc): - return ( - "rate_limited", - "RATE_LIMITED", - "warn", - True, - "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", - None, - ) - - return ( - "server_error", - "SERVER_ERROR", - "error", - False, - f"Error during {flow_label}: {raw}", - None, - ) - - -def _emit_stream_terminal_error( - *, - streaming_service: VercelStreamingService, - flow: str, - request_id: str | None, - thread_id: int, - search_space_id: int, - user_id: str | None, - message: str, - error_kind: str = "server_error", - error_code: str = "SERVER_ERROR", - severity: Literal["info", "warn", "error"] = "error", - is_expected: bool = False, - extra: dict[str, Any] | None = None, -) -> str: - _log_chat_stream_error( - flow=flow, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - request_id=request_id, - thread_id=thread_id, - search_space_id=search_space_id, - user_id=user_id, - message=message, - extra=extra, - ) - return streaming_service.format_error(message, error_code=error_code, extra=extra) - - -def _legacy_match_lc_id( - pending_tool_call_chunks: list[dict[str, Any]], - tool_name: str, - run_id: str, - lc_tool_call_id_by_run: dict[str, str], -) -> str | None: - """Best-effort match a buffered ``tool_call_chunk`` to a tool name. - - Pure extract of the in-line match used at ``on_tool_start`` when the - chunk path didn't register an index for this call. Pops the next - id-bearing chunk whose ``name`` - matches ``tool_name`` (or any id-bearing chunk as a fallback) and - returns its id. Mutates ``pending_tool_call_chunks`` and - ``lc_tool_call_id_by_run`` in place. - """ - matched_idx: int | None = None - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("name") == tool_name and tcc.get("id"): - matched_idx = idx - break - if matched_idx is None: - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("id"): - matched_idx = idx - break - if matched_idx is None: - return None - matched = pending_tool_call_chunks.pop(matched_idx) - candidate = matched.get("id") - if isinstance(candidate, str) and candidate: - if run_id: - lc_tool_call_id_by_run[run_id] = candidate - return candidate - return None - - -async def _stream_agent_events( - agent: Any, - config: dict[str, Any], - input_data: Any, - streaming_service: VercelStreamingService, - result: StreamResult, - step_prefix: str = "thinking", - initial_step_id: str | None = None, - initial_step_title: str = "", - initial_step_items: list[str] | None = None, - *, - fallback_commit_search_space_id: int | None = None, - fallback_commit_created_by_id: str | None = None, - fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, - fallback_commit_thread_id: int | None = None, - runtime_context: Any = None, - content_builder: Any | None = None, -) -> AsyncGenerator[str, None]: - """Shared async generator that streams and formats astream_events from the agent. - - Yields SSE-formatted strings. After exhausting, inspect the ``result`` - object for accumulated_text and interrupt state. - - Args: - agent: The compiled LangGraph agent. - config: LangGraph config dict (must include configurable.thread_id). - input_data: The input to pass to agent.astream_events (dict or Command). - streaming_service: VercelStreamingService instance for formatting events. - result: Mutable StreamResult populated with accumulated_text / interrupt info. - step_prefix: Prefix for thinking step IDs (e.g. "thinking" or "thinking-resume"). - initial_step_id: If set, the helper inherits an already-active thinking step. - initial_step_title: Title of the inherited thinking step. - initial_step_items: Items of the inherited thinking step. - content_builder: Optional ``AssistantContentBuilder``. When set, every - ``streaming_service.format_*`` yield site also drives the matching - builder lifecycle method (``on_text_*``, ``on_reasoning_*``, - ``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the - in-memory ``ContentPart[]`` projection stays in lockstep with what - the FE renders live. Pure in-memory accumulation — no DB I/O — - consumed by the streaming ``finally`` to produce the rich JSONB - persisted via ``finalize_assistant_turn``. ``None`` (the default) - is used by the anonymous / legacy code paths and is a no-op. - - Yields: - SSE-formatted strings for each event. - """ - async for sse in stream_output( - agent=agent, - config=config, - input_data=input_data, - streaming_service=streaming_service, - result=result, - step_prefix=step_prefix, - initial_step_id=initial_step_id, - initial_step_title=initial_step_title, - initial_step_items=initial_step_items, - content_builder=content_builder, - runtime_context=runtime_context, - ): - yield sse - - accumulated_text = result.accumulated_text - - state = await agent.aget_state(config) - state_values = getattr(state, "values", {}) or {} - - # Safety net: if astream_events was cancelled before - # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work - # (dirty_paths / staged_dirs / pending_moves / pending_deletes / - # pending_dir_deletes) will still be in the checkpointed state. Run - # the SAME shared commit helper here so the turn's writes don't get - # lost on client disconnect, then push the delta back into the graph - # using `as_node=...` so reducers fire as if the after_agent hook - # produced it. - if ( - fallback_commit_filesystem_mode == FilesystemMode.CLOUD - and fallback_commit_search_space_id is not None - and ( - (state_values.get("dirty_paths") or []) - or (state_values.get("staged_dirs") or []) - or (state_values.get("pending_moves") or []) - or (state_values.get("pending_deletes") or []) - or (state_values.get("pending_dir_deletes") or []) - ) - ): - try: - delta = await commit_staged_filesystem_state( - state_values, - search_space_id=fallback_commit_search_space_id, - created_by_id=fallback_commit_created_by_id, - filesystem_mode=fallback_commit_filesystem_mode, - thread_id=fallback_commit_thread_id, - dispatch_events=False, - ) - if delta: - await agent.aupdate_state( - config, - delta, - as_node="KnowledgeBasePersistenceMiddleware.after_agent", - ) - except Exception as exc: - _perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc) - - contract_state = state_values.get("file_operation_contract") or {} - contract_turn_id = contract_state.get("turn_id") - current_turn_id = config.get("configurable", {}).get("turn_id", "") - intent_value = contract_state.get("intent") - if ( - isinstance(intent_value, str) - and intent_value in ("chat_only", "file_write", "file_read") - and contract_turn_id == current_turn_id - ): - result.intent_detected = intent_value - if ( - isinstance(intent_value, str) - and intent_value - in ( - "chat_only", - "file_write", - "file_read", - ) - and contract_turn_id != current_turn_id - ): - # Ignore stale intent contracts from previous turns/checkpoints. - result.intent_detected = "chat_only" - result.intent_confidence = ( - _safe_float(contract_state.get("confidence"), default=0.0) - if contract_turn_id == current_turn_id - else 0.0 - ) - - if result.intent_detected == "file_write": - result.commit_gate_passed, result.commit_gate_reason = ( - _evaluate_file_contract_outcome(result) - ) - if not result.commit_gate_passed and _contract_enforcement_active(result): - gate_notice = ( - "I could not complete the requested file write because no successful " - "write_file/edit_file operation was confirmed." - ) - gate_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(gate_text_id) - if content_builder is not None: - content_builder.on_text_start(gate_text_id) - yield streaming_service.format_text_delta(gate_text_id, gate_notice) - if content_builder is not None: - content_builder.on_text_delta(gate_text_id, gate_notice) - yield streaming_service.format_text_end(gate_text_id) - if content_builder is not None: - content_builder.on_text_end(gate_text_id) - yield streaming_service.format_terminal_info(gate_notice, "error") - accumulated_text = gate_notice - else: - result.commit_gate_passed = True - result.commit_gate_reason = "" - - result.accumulated_text = accumulated_text - _log_file_contract("turn_outcome", result) - - pending_values = all_interrupt_values(state) - if pending_values: - result.is_interrupted = True - # One frame per paused subagent so each parallel HITL renders its own - # approval card on the wire. Order matches ``state.interrupts``, which - # the resume slicer in ``checkpointed_subagent_middleware.resume_routing`` - # consumes in the same order — keeping emit and resume in lock-step. - for interrupt_value in pending_values: - yield streaming_service.format_interrupt_request(interrupt_value) - - -async def stream_new_chat( - user_query: str, - search_space_id: int, - chat_id: int, - user_id: str | None = None, - llm_config_id: int = -1, - mentioned_document_ids: list[int] | None = None, - mentioned_folder_ids: list[int] | None = None, - mentioned_connector_ids: list[int] | None = None, - mentioned_connectors: list[dict[str, Any]] | None = None, - mentioned_documents: list[dict[str, Any]] | None = None, - checkpoint_id: str | None = None, - needs_history_bootstrap: bool = False, - thread_visibility: ChatVisibility | None = None, - current_user_display_name: str | None = None, - disabled_tools: list[str] | None = None, - filesystem_selection: FilesystemSelection | None = None, - request_id: str | None = None, - user_image_data_urls: list[str] | None = None, - flow: Literal["new", "regenerate"] = "new", -) -> AsyncGenerator[str, None]: - """ - Stream chat responses from the new SurfSense deep agent. - - This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming. - The chat_id is used as LangGraph's thread_id for memory/checkpointing. - - The function creates and manages its own database session to guarantee proper - cleanup even when Starlette's middleware cancels the task on client disconnect. - - Args: - user_query: The user's query - search_space_id: The search space ID - chat_id: The chat ID (used as LangGraph thread_id for memory) - user_id: The current user's UUID string (for memory tools and session state) - llm_config_id: The LLM configuration ID (default: -1 for first global config) - needs_history_bootstrap: If True, load message history from DB (for cloned chats) - mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat - mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode) - checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations) - - Yields: - str: SSE formatted response strings - """ - streaming_service = VercelStreamingService() - stream_result = StreamResult() - _t_total = time.perf_counter() - fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" - fs_platform = ( - filesystem_selection.client_platform.value if filesystem_selection else "web" - ) - stream_result.request_id = request_id - stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" - stream_result.filesystem_mode = fs_mode - stream_result.client_platform = fs_platform - chat_agent_mode = "unknown" - chat_outcome = "success" - chat_error_category: str | None = None - chat_span_cm = ot.chat_request_span( - chat_id=chat_id, - search_space_id=search_space_id, - flow=flow, - request_id=request_id, - turn_id=stream_result.turn_id, - filesystem_mode=fs_mode, - client_platform=fs_platform, - agent_mode=chat_agent_mode, - ) - chat_span = chat_span_cm.__enter__() - _log_file_contract("turn_start", stream_result) - _perf_log.info( - "[stream_new_chat] filesystem_mode=%s client_platform=%s", - fs_mode, - fs_platform, - ) - log_system_snapshot("stream_new_chat_START") - - from app.services.token_tracking_service import start_turn - - accumulator = start_turn() - - # Premium credit (USD micro-units) tracking state. Stores the - # amount reserved up front so we can release it on cancellation - # and finalize-debit the actual provider cost reported by LiteLLM. - _premium_reserved_micros = 0 - _premium_request_id: str | None = None - - # ``BusyError`` fires before the lock is acquired; the ``finally`` must - # not release the in-flight caller's lock. - _busy_error_raised = False - - _emit_stream_error = partial( - _emit_stream_terminal_error, - streaming_service=streaming_service, - flow=flow, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - session = async_session_maker() - try: - # Mark AI as responding to this user for live collaboration - if user_id: - await set_ai_responding(session, chat_id, UUID(user_id)) - # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) - agent_config: AgentConfig | None = None - requested_llm_config_id = llm_config_id - - async def _load_llm_bundle( - config_id: int, - ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, - ) - if not loaded_agent_config: - return ( - None, - None, - f"Failed to load NewLLMConfig with id {config_id}", - ) - return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, - None, - ) - - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), - None, - ) - - _t0 = time.perf_counter() - # Image-bearing turns force the Auto-pin resolver to filter the - # candidate pool to vision-capable cfgs (and force-repin a - # text-only existing pin). For explicit selections this flag is - # a no-op — the resolver returns the user's chosen id unchanged. - _requires_image_input = bool(user_image_data_urls) - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=llm_config_id, - requires_image_input=_requires_image_input, - ) - ).resolved_llm_config_id - ot.add_event( - "model.pin.resolved", - { - "pin.requested_id": requested_llm_config_id, - "pin.resolved_id": llm_config_id, - "pin.requires_image_input": _requires_image_input, - }, - ) - except ValueError as pin_error: - # Auto-pin's "no vision-capable cfg" path raises a ValueError - # whose message we map to the friendly image-input SSE error - # so the user sees the same message regardless of whether - # the gate fired in Auto-mode or in the agent_config check - # below. - error_code = ( - "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" - if _requires_image_input and "vision-capable" in str(pin_error) - else "SERVER_ERROR" - ) - error_kind = ( - "user_error" - if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" - else "server_error" - ) - if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT": - ot.add_event( - "quota.denied", - { - "quota.code": error_code, - }, - ) - yield _emit_stream_error( - message=str(pin_error), - error_kind=error_kind, - error_code=error_code, - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _perf_log.info( - "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", - time.perf_counter() - _t0, - llm_config_id, - ) - - # Capability safety net: a turn carrying user-uploaded images - # cannot be routed to a chat config that LiteLLM's authoritative - # model map *explicitly* marks as text-only (``supports_vision`` - # set to False). The check is intentionally narrow — it only - # fires when LiteLLM is *certain* the model can't accept image - # input. Unknown / unmapped / vision-capable models pass - # through. Without this guard a known-text-only model would 404 - # at the provider with ``"No endpoints found that support image - # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk; - # failing here lets us return a friendly message that tells the - # user what to change. - if user_image_data_urls and agent_config is not None: - from app.services.provider_capabilities import ( - is_known_text_only_chat_model, - ) - - agent_litellm_params = agent_config.litellm_params or {} - agent_base_model = ( - agent_litellm_params.get("base_model") - if isinstance(agent_litellm_params, dict) - else None - ) - if is_known_text_only_chat_model( - provider=agent_config.provider, - model_name=agent_config.model_name, - base_model=agent_base_model, - custom_provider=agent_config.custom_provider, - ): - model_label = ( - agent_config.config_name or agent_config.model_name or "model" - ) - ot.add_event( - "quota.denied", - { - "quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", - }, - ) - yield _emit_stream_error( - message=( - f"The selected model ({model_label}) does not support " - "image input. Switch to a vision-capable model " - "(e.g. GPT-4o, Claude, Gemini) or remove the image " - "attachment and try again." - ), - error_kind="user_error", - error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", - ) - yield streaming_service.format_done() - return - - # Premium quota reservation for pinned premium model only. - _needs_premium_quota = ( - agent_config is not None and user_id and agent_config.is_premium - ) - if _needs_premium_quota: - import uuid as _uuid - - from app.services.token_quota_service import ( - TokenQuotaService, - estimate_call_reserve_micros, - ) - - _premium_request_id = _uuid.uuid4().hex[:16] - _agent_litellm_params = agent_config.litellm_params or {} - _agent_base_model = ( - _agent_litellm_params.get("base_model") or agent_config.model_name or "" - ) - reserve_amount_micros = estimate_call_reserve_micros( - base_model=_agent_base_model, - quota_reserve_tokens=agent_config.quota_reserve_tokens, - ) - async with shielded_async_session() as quota_session: - quota_result = await TokenQuotaService.premium_reserve( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_premium_request_id, - reserve_micros=reserve_amount_micros, - ) - _premium_reserved_micros = reserve_amount_micros - if not quota_result.allowed: - ot.add_event( - "quota.denied", - { - "quota.code": "PREMIUM_QUOTA_EXHAUSTED", - }, - ) - if requested_llm_config_id == 0: - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - force_repin_free=True, - requires_image_input=_requires_image_input, - ) - ).resolved_llm_config_id - ot.add_event( - "model.repin", - { - "repin.reason": "premium_quota_exhausted", - "repin.to_config_id": llm_config_id, - }, - ) - except ValueError as pin_error: - yield _emit_stream_error( - message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _premium_request_id = None - _premium_reserved_micros = 0 - _log_chat_stream_error( - flow=flow, - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Premium quota exhausted on pinned model; auto-fallback switched to a free model" - ), - extra={ - "fallback_config_id": llm_config_id, - "auto_fallback": True, - }, - ) - else: - yield _emit_stream_error( - message=( - "Buy more tokens to continue with this model, or switch to a free model" - ), - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - extra={ - "resolved_config_id": llm_config_id, - "auto_fallback": False, - }, - ) - yield streaming_service.format_done() - return - - if not llm: - yield _emit_stream_error( - message="Failed to create LLM instance", - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - # Create connector service - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_new_chat] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - # Get the PostgreSQL checkpointer for persistent conversation memory - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - chat_agent_mode = "multi" - with contextlib.suppress(Exception): - chat_span.set_attribute("agent.mode", chat_agent_mode) - - _t0 = time.perf_counter() - agent_factory = create_multi_agent_chat_deep_agent - # Build the agent inline. Provider 429s surface through the - # in-stream recovery loop below (``_is_provider_rate_limited``), - # which repins the thread to an eligible alternative config and - # rebuilds the agent before the user sees any output. - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - ) - _perf_log.info( - "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 - ) - - # Build input with message history - langchain_messages = [] - - _t0 = time.perf_counter() - # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet) - if needs_history_bootstrap: - langchain_messages = await bootstrap_history_from_db( - session, chat_id, thread_visibility=visibility - ) - - thread_result = await session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) - ) - thread = thread_result.scalars().first() - if thread: - thread.needs_history_bootstrap = False - await session.commit() - - # Mentioned KB documents are now handled by KnowledgeBaseSearchMiddleware - # which merges them into the scoped filesystem with full document - # structure. Only report context is inlined here. - - # Fetch the most recent report(s) in this thread so the LLM can - # easily find report_id for versioning decisions, instead of - # having to dig through conversation history. - recent_reports_result = await session.execute( - select(Report) - .filter( - Report.thread_id == chat_id, - Report.content.isnot(None), # exclude failed reports - ) - .order_by(Report.id.desc()) - .limit(3) - ) - recent_reports = list(recent_reports_result.scalars().all()) - - # Resolve @-mention chips to canonical virtual paths and rewrite - # the user-typed text so the LLM sees ``\`/documents/...\``` instead - # of bare ``@title``. The substitution lands in ``agent_user_query`` - # ONLY — the original ``user_query`` (with ``@title`` tokens) flows - # untouched into ``persist_user_turn`` below so chip rendering on - # reload still works (``UserTextPart`` → ``parseMentionSegments`` - # matches ``@title``, not ``\`/documents/...\```). It also feeds - # the human-readable surfaces — SSE "Processing X" status, auto - # thread title, memory seed — which all want what the user typed. - # See ``persistence._build_user_content``. - # - # Cloud mode only: local-folder mode keeps the legacy - # ``@title`` text path; mention support there is a follow-up - # task because the path scheme (mount-rooted) and the picker - # UI both need separate work. - agent_user_query = user_query - accepted_folder_ids: list[int] = [] - if fs_mode == FilesystemMode.CLOUD.value and ( - mentioned_document_ids or mentioned_folder_ids or mentioned_documents - ): - from app.schemas.new_chat import ( - MentionedDocumentInfo as _MentionedDocumentInfo, - ) - - chip_objs: list[_MentionedDocumentInfo] | None = None - if mentioned_documents: - chip_objs = [] - for raw in mentioned_documents: - if isinstance(raw, _MentionedDocumentInfo): - chip_objs.append(raw) - continue - try: - chip_objs.append(_MentionedDocumentInfo.model_validate(raw)) - except Exception: - logger.debug( - "stream_new_chat: dropping malformed mention chip %r", - raw, - ) - - resolved = await resolve_mentions( - session, - search_space_id=search_space_id, - mentioned_documents=chip_objs, - mentioned_document_ids=mentioned_document_ids, - mentioned_folder_ids=mentioned_folder_ids, - ) - agent_user_query = substitute_in_text(user_query, resolved.token_to_path) - accepted_folder_ids = resolved.mentioned_folder_ids - - # Format the user query with context (reports only). - # Uses ``agent_user_query`` so the LLM sees backtick-wrapped paths - # instead of bare ``@title`` tokens. - final_query = agent_user_query - context_parts = [] - - if mentioned_connectors: - connector_lines = [] - for connector in mentioned_connectors: - if not isinstance(connector, dict): - continue - connector_id = connector.get("id") - connector_type = connector.get("connector_type") or connector.get( - "document_type" - ) - account_name = connector.get("account_name") or connector.get("title") - if connector_id is None or connector_type is None: - continue - connector_lines.append( - f' - connector_id={connector_id}, connector_type="{connector_type}", ' - f'account_name="{account_name or ""}"' - ) - if connector_lines: - context_parts.append( - "\n" - "The user selected these exact connector accounts with @. " - "These entries are selection metadata, not retrieved connector content. " - "When a connector-backed tool needs an account, use the matching " - "connector_id from this list if the tool supports connector_id:\n" - + "\n".join(connector_lines) - + "\n" - ) - - # Surface report IDs prominently so the LLM doesn't have to - # retrieve them from old tool responses in conversation history. - if recent_reports: - report_lines = [] - for r in recent_reports: - report_lines.append( - f' - report_id={r.id}, title="{r.title}", ' - f'style="{r.report_style or "detailed"}"' - ) - reports_listing = "\n".join(report_lines) - context_parts.append( - "\n" - "Previously generated reports in this conversation:\n" - f"{reports_listing}\n\n" - "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of " - "these reports, set parent_report_id to the relevant report_id above.\n" - "If the user wants a completely NEW report on a different topic, " - "leave parent_report_id unset.\n" - "" - ) - - if context_parts: - context = "\n\n".join(context_parts) - final_query = f"{context}\n\n{agent_user_query}" - - if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name: - final_query = f"**[{current_user_display_name}]:** {final_query}" - - # if messages: - # # Convert frontend messages to LangChain format - # for msg in messages: - # if msg.role == "user": - # langchain_messages.append(HumanMessage(content=msg.content)) - # elif msg.role == "assistant": - # langchain_messages.append(AIMessage(content=msg.content)) - # else: - human_content = build_human_message_content( - final_query, list(user_image_data_urls or ()) - ) - langchain_messages.append(HumanMessage(content=human_content)) - - input_state = { - # Lets not pass this message atm because we are using the checkpointer to manage the conversation history - # We will use this to simulate group chat functionality in the future - "messages": langchain_messages, - "search_space_id": search_space_id, - "request_id": request_id or "unknown", - "turn_id": stream_result.turn_id, - } - - _perf_log.info( - "[stream_new_chat] History bootstrap + doc/report queries in %.3fs", - time.perf_counter() - _t0, - ) - - # All pre-streaming DB reads are done. Commit to release the - # transaction and its ACCESS SHARE locks so we don't block DDL - # (e.g. migrations) for the entire duration of LLM streaming. - # Tools that need DB access during streaming will start their own - # short-lived transactions (or use isolated sessions). - await session.commit() - - # Detach heavy ORM objects (documents with chunks, reports, etc.) - # from the session identity map now that we've extracted the data - # we need. This prevents them from accumulating in memory for the - # entire duration of LLM streaming (which can be several minutes). - session.expunge_all() - - _perf_log.info( - "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", - time.perf_counter() - _t_total, - chat_id, - ) - - # Configure LangGraph with thread_id for memory - # If checkpoint_id is provided, fork from that checkpoint (for edit/reload) - configurable = {"thread_id": str(chat_id)} - configurable["request_id"] = request_id or "unknown" - configurable["turn_id"] = stream_result.turn_id - if checkpoint_id: - configurable["checkpoint_id"] = checkpoint_id - - config = { - "configurable": configurable, - # Effectively uncapped, matching the agent-level - # ``with_config`` default in ``chat_deepagent.create_agent`` - # and the unbounded ``while(true)`` loop used by OpenCode's - # ``session/processor.ts``. Real circuit-breakers live in - # middleware: ``DoomLoopMiddleware`` (sliding-window tool - # signature check), plus ``enable_tool_call_limit`` / - # ``enable_model_call_limit`` when those flags are set. The - # original LangGraph default of 25 (and our previous 80 - # bump) hit users on legitimate multi-tool plans. - "recursion_limit": 10_000, - } - - # Start the message stream - yield streaming_service.format_message_start() - yield streaming_service.format_start_step() - - # Surface the per-turn correlation id at the very start of the - # stream so the frontend can stamp it onto the in-flight - # assistant message and replay it via ``appendMessage`` - # for durable storage. Tool/action-log events DO carry it later, - # but pure-text turns never produce action-log events; this - # event guarantees the frontend learns the turn id regardless. - yield streaming_service.format_data( - "turn-info", - {"chat_turn_id": stream_result.turn_id}, - ) - yield streaming_service.format_data("turn-status", {"status": "busy"}) - - # Persist the user-side row for this turn before any expensive - # work runs. Closes the "ghost-thread" abuse vector - # (authenticated client hits POST /new_chat then never calls - # /messages — empty new_chat_messages, free LLM completion). - # Idempotent against the unique index in migration 141 so the - # legacy frontend appendMessage call is a no-op on the second - # writer. Hard failure aborts the turn so we never produce a - # title or assistant row that isn't anchored to a persisted - # user message. - from app.tasks.chat.content_builder import AssistantContentBuilder - from app.tasks.chat.persistence import ( - persist_assistant_shell, - persist_user_turn, - ) - - user_message_id = await persist_user_turn( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - user_query=user_query, - user_image_data_urls=user_image_data_urls, - mentioned_documents=mentioned_documents, - ) - if user_message_id is None: - yield _emit_stream_error( - message=( - "We couldn't save your message. Please try again in a moment." - ), - error_kind="server_error", - error_code="MESSAGE_PERSIST_FAILED", - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Emit canonical user message id BEFORE any LLM streaming so the - # FE can rename its optimistic ``msg-user-XXX`` placeholder to - # ``msg-{user_message_id}`` and unlock features gated on a real - # DB id (comments, edit-from-this-message). See B4 in - # ``sse-based_message_id_handshake`` plan. - yield streaming_service.format_data( - "user-message-id", - {"message_id": user_message_id, "turn_id": stream_result.turn_id}, - ) - - # Pre-write the assistant row for this turn so we have a stable - # ``message_id`` to anchor mid-stream metadata (token_usage, - # future agent_action_log.message_id correlation) and a - # write-once UPDATE target at finalize time. Idempotent against - # the (thread_id, turn_id, ASSISTANT) partial unique index from - # migration 141 — if the legacy frontend appendMessage races - # this, we recover the existing row's id. - assistant_message_id = await persist_assistant_shell( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - ) - if assistant_message_id is None: - # Genuine DB failure — abort the turn rather than stream - # into a void. The user row is already persisted so the - # legacy "ghost-thread" gate isn't reopened. - yield _emit_stream_error( - message=( - "We couldn't initialize the assistant message. Please try again." - ), - error_kind="server_error", - error_code="MESSAGE_PERSIST_FAILED", - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Emit canonical assistant message id BEFORE any LLM streaming - # so the FE can rename its optimistic ``msg-assistant-XXX`` - # placeholder to ``msg-{assistant_message_id}`` and bind - # ``tokenUsageStore`` / ``pendingInterrupt`` to the real id - # immediately. See B4 in ``sse-based_message_id_handshake`` - # plan. - yield streaming_service.format_data( - "assistant-message-id", - {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, - ) - - stream_result.assistant_message_id = assistant_message_id - stream_result.content_builder = AssistantContentBuilder() - - # Initial thinking step - analyzing the request - initial_title = "Understanding your request" - action_verb = "Processing" - - processing_parts = [] - if user_query.strip(): - query_text = user_query[:80] + ("..." if len(user_query) > 80 else "") - processing_parts.append(query_text) - elif user_image_data_urls: - processing_parts.append(f"[{len(user_image_data_urls)} image(s)]") - else: - processing_parts.append("(message)") - - initial_items = [f"{action_verb}: {' '.join(processing_parts)}"] - initial_step_id = "thinking-1" - - # Drive the builder for this initial thinking step too — the - # ``_emit_thinking_step`` helper lives inside ``_stream_agent_events`` - # so it isn't in scope here, but the FE folds this step into - # the same singleton ``data-thinking-steps`` part as everything - # the agent stream emits later. Mirror that fold server-side. - if stream_result.content_builder is not None: - stream_result.content_builder.on_thinking_step( - initial_step_id, initial_title, "in_progress", initial_items - ) - yield streaming_service.format_thinking_step( - step_id=initial_step_id, - title=initial_title, - status="in_progress", - items=initial_items, - ) - - # These ORM objects can be large. They're only needed to build context - # strings already copied into final_query / langchain_messages — - # release them before streaming. - del recent_reports - del langchain_messages, final_query - - # Check if this is the first assistant response so we can generate - # a title in parallel with the agent stream (better UX than waiting - # until after the full response). - # Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because - # this is now a hot path executed on every turn, and COUNT scales - # with thread length (server-side persistence can grow rows - # quickly under power users). - # - # IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already - # inserted THIS turn's assistant row. We must therefore exclude - # it from the probe — otherwise the gate fires on every turn - # except the very first, and title generation never runs for new - # threads. Excluding by primary key (``id != assistant_message_id``) - # is bulletproof regardless of ``turn_id`` shape (legacy NULLs, - # resume turns, etc.). - first_assistant_probe = await session.execute( - select(NewChatMessage.id) - .filter( - NewChatMessage.thread_id == chat_id, - NewChatMessage.role == "assistant", - NewChatMessage.id != assistant_message_id, - ) - .limit(1) - ) - is_first_response = first_assistant_probe.scalars().first() is None - - title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None - # Gate title generation on a persisted user message so a stream - # that fails before persistence (we abort above) can never leave - # behind a thread with a generated title and no anchoring rows. - if is_first_response and user_message_id is not None: - - async def _generate_title() -> tuple[str | None, dict | None]: - """Generate a short title via litellm.acompletion. - - Returns (title, usage_dict). Usage is extracted directly from - the response object because litellm fires its async callback - via fire-and-forget ``create_task``, so the - ``TokenTrackingCallback`` would run too late. We also blank - the accumulator in this child-task context so the late callback - doesn't double-count. - """ - try: - from litellm import acompletion - - from app.services.llm_router_service import LLMRouterService - from app.services.provider_api_base import resolve_api_base - from app.services.token_tracking_service import _turn_accumulator - - _turn_accumulator.set(None) - - title_seed = user_query.strip() or ( - f"[{len(user_image_data_urls or [])} image(s)]" - if user_image_data_urls - else "" - ) - prompt = TITLE_GENERATION_PROMPT.replace( - "{user_query}", title_seed[:500] or "(message)" - ) - messages = [{"role": "user", "content": prompt}] - - if getattr(llm, "model", None) == "auto": - router = LLMRouterService.get_router() - response = await router.acompletion( - model="auto", messages=messages - ) - else: - # Apply the same ``api_base`` cascade chat / vision / - # image-gen call sites use so we never inherit - # ``litellm.api_base`` (commonly set by - # ``AZURE_OPENAI_ENDPOINT``) when the chat config - # itself ships an empty ``api_base``. Without this - # the title-gen on an OpenRouter chat config would - # 404 against the inherited Azure endpoint — see - # ``provider_api_base`` docstring for the same - # bug repro on the image-gen / vision paths. - raw_model = getattr(llm, "model", "") or "" - provider_prefix = ( - raw_model.split("/", 1)[0] if "/" in raw_model else None - ) - provider_value = ( - agent_config.provider if agent_config is not None else None - ) - title_api_base = resolve_api_base( - provider=provider_value, - provider_prefix=provider_prefix, - config_api_base=getattr(llm, "api_base", None), - ) - response = await acompletion( - model=raw_model, - messages=messages, - api_key=getattr(llm, "api_key", None), - api_base=title_api_base, - ) - - usage_info = None - usage = getattr(response, "usage", None) - if usage: - raw_model = getattr(llm, "model", "") or "" - model_name = ( - raw_model.split("/", 1)[-1] - if "/" in raw_model - else (raw_model or response.model or "unknown") - ) - usage_info = { - "model": model_name, - "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, - "completion_tokens": getattr(usage, "completion_tokens", 0) - or 0, - "total_tokens": getattr(usage, "total_tokens", 0) or 0, - } - - raw_title = response.choices[0].message.content.strip() - if raw_title and len(raw_title) <= 100: - return raw_title.strip("\"'"), usage_info - return None, usage_info - except Exception: - logging.getLogger(__name__).exception( - "[TitleGen] _generate_title failed" - ) - return None, None - - title_task = asyncio.create_task(_generate_title()) - - title_emitted = False - - # Build the per-invocation runtime context (Phase 1.5). - # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware`` - # via ``runtime.context.mentioned_document_ids`` instead of its - # ``__init__`` closure — that way the same compiled-agent instance - # can serve multiple turns with different mention lists. - runtime_context = SurfSenseContextSchema( - search_space_id=search_space_id, - mentioned_document_ids=list(mentioned_document_ids or []), - mentioned_folder_ids=list( - accepted_folder_ids or mentioned_folder_ids or [] - ), - mentioned_connector_ids=list(mentioned_connector_ids or []), - mentioned_connectors=list(mentioned_connectors or []), - request_id=request_id, - turn_id=stream_result.turn_id, - ) - - _t_stream_start = time.perf_counter() - _first_event_logged = False - runtime_rate_limit_recovered = False - while True: - try: - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - initial_step_title=initial_title, - initial_step_items=initial_items, - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - runtime_context=runtime_context, - content_builder=stream_result.content_builder, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background - # task finishes. - if ( - title_task is not None - and title_task.done() - and not title_emitted - ): - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter( - NewChatThread.id == chat_id - ) - ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title - ) - title_emitted = True - break - except Exception as stream_exc: - can_runtime_recover = ( - not runtime_rate_limit_recovered - and requested_llm_config_id == 0 - and llm_config_id < 0 - and not _first_event_logged - and _is_provider_rate_limited(stream_exc) - ) - if not can_runtime_recover: - raise - - runtime_rate_limit_recovered = True - previous_config_id = llm_config_id - # The failed attempt may still hold the per-thread busy mutex - # (middleware teardown can lag behind raised provider errors). - # Force release before we retry within the same request. - end_turn(str(chat_id)) - mark_runtime_cooldown( - previous_config_id, - reason="provider_rate_limited", - ) - - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - exclude_config_ids={previous_config_id}, - requires_image_input=_requires_image_input, - ) - ).resolved_llm_config_id - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - raise stream_exc - - # Title generation uses the initial llm object. After a runtime - # repin we keep the stream focused on response recovery and skip - # title generation for this turn. - if title_task is not None and not title_task.done(): - title_task.cancel() - title_task = None - - _t0 = time.perf_counter() - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - ) - _perf_log.info( - "[stream_new_chat] Runtime rate-limit recovery repinned " - "config_id=%s -> %s and rebuilt agent in %.3fs", - previous_config_id, - llm_config_id, - time.perf_counter() - _t0, - ) - ot.add_event( - "chat.rate_limit.recovered", - { - "recovery.reason": "provider_rate_limited", - "recovery.previous_config_id": previous_config_id, - "recovery.fallback_config_id": llm_config_id, - }, - ) - _log_chat_stream_error( - flow=flow, - error_kind="rate_limited", - error_code="RATE_LIMITED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Auto-pinned model hit runtime rate limit; switched to " - "another eligible model and retried." - ), - extra={ - "auto_runtime_recover": True, - "previous_config_id": previous_config_id, - "fallback_config_id": llm_config_id, - }, - ) - continue - - _perf_log.info( - "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", - time.perf_counter() - _t_stream_start, - chat_id, - ) - log_system_snapshot("stream_new_chat_END") - - if stream_result.is_interrupted: - ot.add_event( - "chat.interrupted", - { - "chat.flow": flow, - }, - ) - if title_task is not None and not title_task.done(): - title_task.cancel() - - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # If the title task didn't finish during streaming, await it now - if title_task is not None and not title_emitted: - generated_title, title_usage = await title_task - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) - ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title - ) - - # Finalize premium credit debit with the actual provider cost - # reported by LiteLLM, summed across every call in the turn. - # Mirrors the pre-cost behaviour of "premium turn → all calls - # count" so free sub-agent calls during a premium turn still - # contribute to the bill (they're $0 in practice anyway). - if _premium_request_id and user_id: - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_finalize( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_premium_request_id, - actual_micros=accumulator.total_cost_micros, - reserved_micros=_premium_reserved_micros, - ) - _premium_request_id = None - _premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to finalize premium quota for user %s", - user_id, - exc_info=True, - ) - - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - # Finish the step and message - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except Exception as e: - # Handle any errors - import traceback - - # ``BusyError`` fires before the agent acquires the lock; the - # cleanup path must skip lock release to avoid freeing the - # in-flight caller's lock. Classification is handled below. - if isinstance(e, BusyError): - _busy_error_raised = True - - ( - error_kind, - error_code, - severity, - is_expected, - user_message, - error_extra, - ) = _classify_stream_exception(e, flow_label="chat") - chat_outcome = error_code or error_kind or "error" - chat_error_category = ot_metrics.categorize_exception(e) - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - chat_span.set_attribute("error.category", chat_error_category) - ot.record_error(chat_span, e) - error_message = f"Error during chat: {e!s}" - print(f"[stream_new_chat] {error_message}") - print(f"[stream_new_chat] Exception type: {type(e).__name__}") - print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") - if error_code == "TURN_CANCELLING": - status_payload: dict[str, Any] = {"status": "cancelling"} - if error_extra: - status_payload.update(error_extra) - yield streaming_service.format_data("turn-status", status_payload) - else: - yield streaming_service.format_data("turn-status", {"status": "busy"}) - - yield _emit_stream_error( - message=user_message, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - extra=error_extra, - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - finally: - # Shield the ENTIRE async cleanup from anyio cancel-scope - # cancellation. Starlette's BaseHTTPMiddleware uses anyio task - # groups; on client disconnect, it cancels the scope with - # level-triggered cancellation — every unshielded `await` inside - # the cancelled scope raises CancelledError immediately. Without - # this shield the very first `await` (session.rollback) would - # raise CancelledError, `except Exception` wouldn't catch it - # (CancelledError is a BaseException), and the rest of the - # finally block — including session.close() — would never run. - with anyio.CancelScope(shield=True): - # Authoritative fallback cleanup for lock/cancel state. Middleware - # teardown can be skipped on some client-abort paths. - end_turn(str(chat_id)) - - # Release premium reservation if not finalized - if _premium_request_id and _premium_reserved_micros > 0 and user_id: - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_release( - db_session=quota_session, - user_id=UUID(user_id), - reserved_micros=_premium_reserved_micros, - ) - _premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to release premium quota for user %s", user_id - ) - - try: - await session.rollback() - await clear_ai_responding(session, chat_id) - except Exception: - try: - async with shielded_async_session() as fresh_session: - await clear_ai_responding(fresh_session, chat_id) - except Exception: - logging.getLogger(__name__).warning( - "Failed to clear AI responding state for thread %s", chat_id - ) - - with contextlib.suppress(Exception): - session.expunge_all() - - with contextlib.suppress(Exception): - await session.close() - - # Server-side assistant-message + token_usage finalization. - # Runs after the main session has been closed (uses its own - # shielded session) so we don't fight the same DB connection. - # Idempotent against the legacy frontend appendMessage: - # * the assistant row was already INSERTed by - # ``persist_assistant_shell`` above, so this just UPDATEs - # it with the rich ContentPart[] from the builder. - # * token_usage uses INSERT ... ON CONFLICT DO NOTHING - # against migration 142's partial unique index, so a - # racing append_message recovery branch can never - # double-write. - # ``mark_interrupted`` closes any open text/reasoning blocks - # and flips running tool-calls (no result) to state=aborted - # so the persisted JSONB reflects a coherent end-state even - # on client disconnect. - # Never raises (best-effort, logs only). - if ( - stream_result - and stream_result.turn_id - and stream_result.assistant_message_id - ): - from app.tasks.chat.persistence import finalize_assistant_turn - - builder_stats: dict[str, int] | None = None - if stream_result.content_builder is not None: - stream_result.content_builder.mark_interrupted() - # Snapshot stats BEFORE deepcopy in ``snapshot()`` so - # the perf log records the actual finalised payload - # (post-mark_interrupted), not the live-mutating - # builder state. - builder_stats = stream_result.content_builder.stats() - content_payload = stream_result.content_builder.snapshot() - else: - # Defensive fallback — we always set the builder - # alongside ``assistant_message_id`` above, so this - # branch only fires if a future refactor ever - # decouples them. Persist whatever accumulated - # text we captured so the row at least renders. - content_payload = [ - { - "type": "text", - "text": stream_result.accumulated_text or "", - } - ] - - if builder_stats is not None: - _perf_log.info( - "[stream_new_chat] finalize_payload chat_id=%s " - "message_id=%s parts=%d bytes=%d text=%d " - "reasoning=%d tool_calls=%d " - "tool_calls_completed=%d tool_calls_aborted=%d " - "thinking_step_parts=%d step_separators=%d", - chat_id, - stream_result.assistant_message_id, - builder_stats["parts"], - builder_stats["bytes"], - builder_stats["text"], - builder_stats["reasoning"], - builder_stats["tool_calls"], - builder_stats["tool_calls_completed"], - builder_stats["tool_calls_aborted"], - builder_stats["thinking_step_parts"], - builder_stats["step_separators"], - ) - - await finalize_assistant_turn( - message_id=stream_result.assistant_message_id, - chat_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - turn_id=stream_result.turn_id, - content=content_payload, - accumulator=accumulator, - ) - - # Persist any sandbox-produced files to local storage so they - # remain downloadable after the Daytona sandbox auto-deletes. - if stream_result and stream_result.sandbox_files: - with contextlib.suppress(Exception): - from app.agents.shared.sandbox import ( - is_sandbox_enabled, - persist_and_delete_sandbox, - ) - - if is_sandbox_enabled(): - with anyio.CancelScope(shield=True): - await persist_and_delete_sandbox( - chat_id, stream_result.sandbox_files - ) - - # ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout. - # Skip on ``BusyError`` (caller never acquired the lock). - if not _busy_error_raised: - with contextlib.suppress(Exception): - end_turn(str(chat_id)) - _perf_log.info( - "[stream_new_chat] end_turn cleanup (chat_id=%s)", - chat_id, - ) - - # Break circular refs held by the agent graph, tools, and LLM - # wrappers so the GC can reclaim them in a single pass. - agent = llm = connector_service = None - input_state = stream_result = None - session = None - - collected = gc.collect(0) + gc.collect(1) + gc.collect(2) - if collected: - _perf_log.info( - "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)", - collected, - chat_id, - ) - trim_native_heap() - log_system_snapshot("stream_new_chat_END") - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - ot_metrics.record_chat_request_duration( - (time.perf_counter() - _t_total) * 1000, - flow=flow, - outcome=chat_outcome, - agent_mode=chat_agent_mode, - ) - ot_metrics.record_chat_request_outcome( - flow=flow, - outcome=chat_outcome, - agent_mode=chat_agent_mode, - error_category=chat_error_category, - ) - chat_span_cm.__exit__(*sys.exc_info()) - - -async def stream_resume_chat( - chat_id: int, - search_space_id: int, - decisions: list[dict], - user_id: str | None = None, - llm_config_id: int = -1, - thread_visibility: ChatVisibility | None = None, - filesystem_selection: FilesystemSelection | None = None, - request_id: str | None = None, - disabled_tools: list[str] | None = None, -) -> AsyncGenerator[str, None]: - streaming_service = VercelStreamingService() - stream_result = StreamResult() - _t_total = time.perf_counter() - fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" - fs_platform = ( - filesystem_selection.client_platform.value if filesystem_selection else "web" - ) - stream_result.request_id = request_id - stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" - stream_result.filesystem_mode = fs_mode - stream_result.client_platform = fs_platform - chat_agent_mode = "unknown" - chat_outcome = "success" - chat_error_category: str | None = None - chat_span_cm = ot.chat_request_span( - chat_id=chat_id, - search_space_id=search_space_id, - flow="resume", - request_id=request_id, - turn_id=stream_result.turn_id, - filesystem_mode=fs_mode, - client_platform=fs_platform, - agent_mode=chat_agent_mode, - ) - chat_span = chat_span_cm.__enter__() - _log_file_contract("turn_start", stream_result) - _perf_log.info( - "[stream_resume] filesystem_mode=%s client_platform=%s", - fs_mode, - fs_platform, - ) - from app.services.token_tracking_service import start_turn - - accumulator = start_turn() - - # Skip the finally release on ``BusyError`` (caller never acquired the lock). - _busy_error_raised = False - - _emit_stream_error = partial( - _emit_stream_terminal_error, - streaming_service=streaming_service, - flow="resume", - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - session = async_session_maker() - try: - if user_id: - await set_ai_responding(session, chat_id, UUID(user_id)) - - agent_config: AgentConfig | None = None - requested_llm_config_id = llm_config_id - - async def _load_llm_bundle( - config_id: int, - ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, - ) - if not loaded_agent_config: - return ( - None, - None, - f"Failed to load NewLLMConfig with id {config_id}", - ) - return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, - None, - ) - - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), - None, - ) - - _t0 = time.perf_counter() - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=llm_config_id, - ) - ).resolved_llm_config_id - ot.add_event( - "model.pin.resolved", - { - "pin.requested_id": requested_llm_config_id, - "pin.resolved_id": llm_config_id, - "pin.requires_image_input": False, - }, - ) - except ValueError as pin_error: - yield _emit_stream_error( - message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _perf_log.info( - "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 - ) - - # Premium credit reservation (same logic as stream_new_chat). - _resume_premium_reserved_micros = 0 - _resume_premium_request_id: str | None = None - _resume_needs_premium = ( - agent_config is not None and user_id and agent_config.is_premium - ) - if _resume_needs_premium: - import uuid as _uuid - - from app.services.token_quota_service import ( - TokenQuotaService, - estimate_call_reserve_micros, - ) - - _resume_premium_request_id = _uuid.uuid4().hex[:16] - _resume_litellm_params = agent_config.litellm_params or {} - _resume_base_model = ( - _resume_litellm_params.get("base_model") - or agent_config.model_name - or "" - ) - reserve_amount_micros = estimate_call_reserve_micros( - base_model=_resume_base_model, - quota_reserve_tokens=agent_config.quota_reserve_tokens, - ) - async with shielded_async_session() as quota_session: - quota_result = await TokenQuotaService.premium_reserve( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_resume_premium_request_id, - reserve_micros=reserve_amount_micros, - ) - _resume_premium_reserved_micros = reserve_amount_micros - if not quota_result.allowed: - ot.add_event( - "quota.denied", - { - "quota.code": "PREMIUM_QUOTA_EXHAUSTED", - }, - ) - if requested_llm_config_id == 0: - try: - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - force_repin_free=True, - ) - ).resolved_llm_config_id - ot.add_event( - "model.repin", - { - "repin.reason": "premium_quota_exhausted", - "repin.to_config_id": llm_config_id, - }, - ) - except ValueError as pin_error: - yield _emit_stream_error( - message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - yield _emit_stream_error( - message=llm_load_error, - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - _resume_premium_request_id = None - _resume_premium_reserved_micros = 0 - _log_chat_stream_error( - flow="resume", - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Premium quota exhausted on pinned model; auto-fallback switched to a free model" - ), - extra={ - "fallback_config_id": llm_config_id, - "auto_fallback": True, - }, - ) - else: - yield _emit_stream_error( - message=( - "Buy more tokens to continue with this model, or switch to a free model" - ), - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - extra={ - "resolved_config_id": llm_config_id, - "auto_fallback": False, - }, - ) - yield streaming_service.format_done() - return - - if not llm: - yield _emit_stream_error( - message="Failed to create LLM instance", - error_kind="server_error", - error_code="SERVER_ERROR", - ) - yield streaming_service.format_done() - return - - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_resume] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - chat_agent_mode = "multi" - with contextlib.suppress(Exception): - chat_span.set_attribute("agent.mode", chat_agent_mode) - _t0 = time.perf_counter() - agent_factory = create_multi_agent_chat_deep_agent - # Build the agent inline. Provider 429s are handled by the - # in-stream recovery loop, which repins to an eligible - # alternative config and rebuilds the agent before the user sees - # any output. - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - ) - _perf_log.info( - "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 - ) - - # Release the transaction before streaming (same rationale as stream_new_chat). - await session.commit() - session.expunge_all() - - _perf_log.info( - "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)", - time.perf_counter() - _t_total, - chat_id, - ) - - from langgraph.types import Command - - from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( - build_lg_resume_map, - collect_pending_tool_calls, - slice_decisions_by_tool_call, - ) - - # Each pending interrupt is stamped with its originating ``tool_call_id`` - # (see ``checkpointed_subagent_middleware.propagation``) so we can route - # a flat ``decisions`` list back to the right paused subagent. - parent_state = await agent.aget_state( - {"configurable": {"thread_id": str(chat_id)}} - ) - pending = collect_pending_tool_calls(parent_state) - _perf_log.info( - "[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d", - chat_id, - len(decisions), - len(pending), - ) - routed_resume_value = slice_decisions_by_tool_call(decisions, pending) - # Langgraph rejects scalar ``Command(resume=...)`` when multiple - # interrupts are pending (parallel HITL); the mapped form works - # for the single-pause case too, so we always use it. - lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value) - - config = { - "configurable": { - "thread_id": str(chat_id), - "request_id": request_id or "unknown", - "turn_id": stream_result.turn_id, - # Per-``tool_call_id`` resume slices read by - # ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel - # siblings each pop their own entry, so they never race. - "surfsense_resume_value": routed_resume_value, - }, - # See ``stream_new_chat`` above for rationale: effectively - # uncapped to mirror the agent default and OpenCode's - # session loop. Doom-loop / call-limit middleware enforce - # the real ceiling. - "recursion_limit": 10_000, - } - - yield streaming_service.format_message_start() - yield streaming_service.format_start_step() - # Same rationale as ``stream_new_chat``: emit the turn id so - # resumed streams can be persisted with their correlation id - # intact. - yield streaming_service.format_data( - "turn-info", - {"chat_turn_id": stream_result.turn_id}, - ) - yield streaming_service.format_data("turn-status", {"status": "busy"}) - - # Pre-write a fresh assistant row for this resume turn. The - # original (interrupted) ``stream_new_chat`` invocation already - # persisted its own assistant row anchored to a different - # ``turn_id``; resume allocates a new ``turn_id`` (above) so we - # need a separate row keyed on the same ``(thread_id, turn_id, - # ASSISTANT)`` invariant. Idempotent against migration 141's - # partial unique index — recovers existing id on retry. - from app.tasks.chat.content_builder import AssistantContentBuilder - from app.tasks.chat.persistence import persist_assistant_shell - - assistant_message_id = await persist_assistant_shell( - chat_id=chat_id, - user_id=user_id, - turn_id=stream_result.turn_id, - ) - if assistant_message_id is None: - yield _emit_stream_error( - message=( - "We couldn't initialize the assistant message. Please try again." - ), - error_kind="server_error", - error_code="MESSAGE_PERSIST_FAILED", - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Emit canonical assistant message id BEFORE any LLM streaming - # so the FE can rename ``pendingInterrupt.assistantMsgId`` to - # ``msg-{assistant_message_id}`` immediately. Resume does NOT - # emit ``data-user-message-id`` because the user row is from - # the original interrupted turn (different ``turn_id``) and is - # never re-persisted here. See B5 in the - # ``sse-based_message_id_handshake`` plan. - yield streaming_service.format_data( - "assistant-message-id", - {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, - ) - - stream_result.assistant_message_id = assistant_message_id - stream_result.content_builder = AssistantContentBuilder() - - # Resume path doesn't carry new ``mentioned_document_ids`` — - # those are seeded in the original turn. We still pass a - # context so future middleware extensions (Phase 2) can rely on - # ``runtime.context`` always being populated. - runtime_context = SurfSenseContextSchema( - search_space_id=search_space_id, - request_id=request_id, - turn_id=stream_result.turn_id, - ) - - _t_stream_start = time.perf_counter() - _first_event_logged = False - runtime_rate_limit_recovered = False - while True: - try: - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume=lg_resume_map), - streaming_service=streaming_service, - result=stream_result, - step_prefix=_resume_step_prefix(stream_result.turn_id), - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - runtime_context=runtime_context, - content_builder=stream_result.content_builder, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - break - except Exception as stream_exc: - can_runtime_recover = ( - not runtime_rate_limit_recovered - and requested_llm_config_id == 0 - and llm_config_id < 0 - and not _first_event_logged - and _is_provider_rate_limited(stream_exc) - ) - if not can_runtime_recover: - raise - - runtime_rate_limit_recovered = True - previous_config_id = llm_config_id - # Ensure the same-request recovery retry does not trip the - # BusyMutex lock retained by the failed attempt. - end_turn(str(chat_id)) - mark_runtime_cooldown( - previous_config_id, - reason="provider_rate_limited", - ) - llm_config_id = ( - await resolve_or_get_pinned_llm_config_id( - session, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - selected_llm_config_id=0, - exclude_config_ids={previous_config_id}, - ) - ).resolved_llm_config_id - - llm, agent_config, llm_load_error = await _load_llm_bundle( - llm_config_id - ) - if llm_load_error: - raise stream_exc - - _t0 = time.perf_counter() - agent = await _build_main_agent_for_thread( - agent_factory, - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - disabled_tools=disabled_tools, - ) - _perf_log.info( - "[stream_resume] Runtime rate-limit recovery repinned " - "config_id=%s -> %s and rebuilt agent in %.3fs", - previous_config_id, - llm_config_id, - time.perf_counter() - _t0, - ) - ot.add_event( - "chat.rate_limit.recovered", - { - "recovery.reason": "provider_rate_limited", - "recovery.previous_config_id": previous_config_id, - "recovery.fallback_config_id": llm_config_id, - }, - ) - _log_chat_stream_error( - flow="resume", - error_kind="rate_limited", - error_code="RATE_LIMITED", - severity="info", - is_expected=True, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=( - "Auto-pinned model hit runtime rate limit; switched to " - "another eligible model and retried." - ), - extra={ - "auto_runtime_recover": True, - "previous_config_id": previous_config_id, - "fallback_config_id": llm_config_id, - }, - ) - continue - _perf_log.info( - "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", - time.perf_counter() - _t_stream_start, - chat_id, - ) - if stream_result.is_interrupted: - ot.add_event( - "chat.interrupted", - { - "chat.flow": "resume", - }, - ) - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - return - - # Finalize premium credit debit for resume path with the actual - # provider cost reported by LiteLLM (sum of cost across all - # calls in the turn). - if _resume_premium_request_id and user_id: - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_finalize( - db_session=quota_session, - user_id=UUID(user_id), - request_id=_resume_premium_request_id, - actual_micros=accumulator.total_cost_micros, - reserved_micros=_resume_premium_reserved_micros, - ) - _resume_premium_request_id = None - _resume_premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to finalize premium quota for user %s (resume)", - user_id, - exc_info=True, - ) - - usage_summary = accumulator.per_message_summary() - _perf_log.info( - "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s", - len(accumulator.calls), - accumulator.grand_total, - accumulator.total_cost_micros, - usage_summary, - ) - if usage_summary: - yield streaming_service.format_data( - "token-usage", - { - "usage": usage_summary, - "prompt_tokens": accumulator.total_prompt_tokens, - "completion_tokens": accumulator.total_completion_tokens, - "total_tokens": accumulator.grand_total, - "cost_micros": accumulator.total_cost_micros, - "call_details": accumulator.serialized_calls(), - }, - ) - - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - except Exception as e: - import traceback - - # ``BusyError`` fires before the agent acquires the lock; the - # cleanup path must skip lock release to avoid freeing the - # in-flight caller's lock. Classification is handled below. - if isinstance(e, BusyError): - _busy_error_raised = True - - ( - error_kind, - error_code, - severity, - is_expected, - user_message, - error_extra, - ) = _classify_stream_exception(e, flow_label="resume") - chat_outcome = error_code or error_kind or "error" - chat_error_category = ot_metrics.categorize_exception(e) - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - chat_span.set_attribute("error.category", chat_error_category) - ot.record_error(chat_span, e) - error_message = f"Error during resume: {e!s}" - print(f"[stream_resume_chat] {error_message}") - print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") - if error_code == "TURN_CANCELLING": - status_payload: dict[str, Any] = {"status": "cancelling"} - if error_extra: - status_payload.update(error_extra) - yield streaming_service.format_data("turn-status", status_payload) - else: - yield streaming_service.format_data("turn-status", {"status": "busy"}) - yield _emit_stream_error( - message=user_message, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - extra=error_extra, - ) - yield streaming_service.format_data("turn-status", {"status": "idle"}) - yield streaming_service.format_finish_step() - yield streaming_service.format_finish() - yield streaming_service.format_done() - - finally: - with anyio.CancelScope(shield=True): - # Authoritative fallback cleanup for lock/cancel state. Middleware - # teardown can be skipped on some client-abort paths. - end_turn(str(chat_id)) - - # Release premium reservation if not finalized - if ( - _resume_premium_request_id - and _resume_premium_reserved_micros > 0 - and user_id - ): - try: - from app.services.token_quota_service import TokenQuotaService - - async with shielded_async_session() as quota_session: - await TokenQuotaService.premium_release( - db_session=quota_session, - user_id=UUID(user_id), - reserved_micros=_resume_premium_reserved_micros, - ) - _resume_premium_reserved_micros = 0 - except Exception: - logging.getLogger(__name__).warning( - "Failed to release premium quota for user %s (resume)", user_id - ) - - try: - await session.rollback() - await clear_ai_responding(session, chat_id) - except Exception: - try: - async with shielded_async_session() as fresh_session: - await clear_ai_responding(fresh_session, chat_id) - except Exception: - logging.getLogger(__name__).warning( - "Failed to clear AI responding state for thread %s", chat_id - ) - - with contextlib.suppress(Exception): - session.expunge_all() - - with contextlib.suppress(Exception): - await session.close() - - # Server-side assistant-message + token_usage finalization for - # the resume flow. The original user message was persisted by - # the original (interrupted) ``stream_new_chat`` invocation; - # the resume's own ``persist_assistant_shell`` write lives at - # the new ``turn_id`` above. This finalize updates that row - # with the rich ContentPart[] from the builder and writes - # token_usage idempotently via migration 142's partial - # unique index. Best-effort, never raises. - if ( - stream_result - and stream_result.turn_id - and stream_result.assistant_message_id - ): - from app.tasks.chat.persistence import finalize_assistant_turn - - builder_stats: dict[str, int] | None = None - if stream_result.content_builder is not None: - stream_result.content_builder.mark_interrupted() - builder_stats = stream_result.content_builder.stats() - content_payload = stream_result.content_builder.snapshot() - else: - content_payload = [ - { - "type": "text", - "text": stream_result.accumulated_text or "", - } - ] - - if builder_stats is not None: - _perf_log.info( - "[stream_resume] finalize_payload chat_id=%s " - "message_id=%s parts=%d bytes=%d text=%d " - "reasoning=%d tool_calls=%d " - "tool_calls_completed=%d tool_calls_aborted=%d " - "thinking_step_parts=%d step_separators=%d", - chat_id, - stream_result.assistant_message_id, - builder_stats["parts"], - builder_stats["bytes"], - builder_stats["text"], - builder_stats["reasoning"], - builder_stats["tool_calls"], - builder_stats["tool_calls_completed"], - builder_stats["tool_calls_aborted"], - builder_stats["thinking_step_parts"], - builder_stats["step_separators"], - ) - - await finalize_assistant_turn( - message_id=stream_result.assistant_message_id, - chat_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - turn_id=stream_result.turn_id, - content=content_payload, - accumulator=accumulator, - ) - - # Release the lock from the original interrupted turn or any - # re-interrupt/bailout. Skip on ``BusyError`` (lock not held here). - if not _busy_error_raised: - with contextlib.suppress(Exception): - end_turn(str(chat_id)) - _perf_log.info( - "[stream_resume] end_turn cleanup (chat_id=%s)", - chat_id, - ) - - agent = llm = connector_service = None - stream_result = None - session = None - - collected = gc.collect(0) + gc.collect(1) + gc.collect(2) - if collected: - _perf_log.info( - "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)", - collected, - chat_id, - ) - trim_native_heap() - log_system_snapshot("stream_resume_chat_END") - with contextlib.suppress(Exception): - chat_span.set_attribute("chat.outcome", chat_outcome) - ot_metrics.record_chat_request_duration( - (time.perf_counter() - _t_total) * 1000, - flow="resume", - outcome=chat_outcome, - agent_mode=chat_agent_mode, - ) - ot_metrics.record_chat_request_outcome( - flow="resume", - outcome=chat_outcome, - agent_mode=chat_agent_mode, - error_category=chat_error_category, - ) - chat_span_cm.__exit__(*sys.exc_info()) diff --git a/surfsense_backend/tests/e2e/run_backend.py b/surfsense_backend/tests/e2e/run_backend.py index 2567cc7a4..c05783790 100644 --- a/surfsense_backend/tests/e2e/run_backend.py +++ b/surfsense_backend/tests/e2e/run_backend.py @@ -247,11 +247,11 @@ def _patch_llm_bindings() -> None: fake_create_chat_litellm_from_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config", fake_create_chat_litellm_from_agent_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config", fake_create_chat_litellm_from_config, ), ] diff --git a/surfsense_backend/tests/e2e/run_celery.py b/surfsense_backend/tests/e2e/run_celery.py index 9e7576a51..1a77bf45a 100644 --- a/surfsense_backend/tests/e2e/run_celery.py +++ b/surfsense_backend/tests/e2e/run_celery.py @@ -220,11 +220,11 @@ def _patch_llm_bindings() -> None: fake_create_chat_litellm_from_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config", fake_create_chat_litellm_from_agent_config, ), ( - "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", + "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config", fake_create_chat_litellm_from_config, ), ] diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py deleted file mode 100644 index fc8280012..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py +++ /dev/null @@ -1,457 +0,0 @@ -"""Byte-for-byte frame parity: legacy monolith vs refactored flows orchestrators. - -The agent-content portion of the stream (`text-*`, tool cards, thinking-step -updates) flows through **shared** code in both implementations -(`stream_output` -> `EventRelay.relay` -> handlers), so it cannot diverge. The -only independently-written part is the *orchestrator glue*: the initial frames, -persistence-handshake frames, error/terminal branches, and final frames. - -This module drives BOTH ``stream_new_chat`` implementations (legacy -``app.tasks.chat.stream_new_chat`` and the refactored -``app.tasks.chat.streaming.flows``) through the deterministic glue paths and -asserts the emitted SSE frame sequences are **byte-for-byte identical**. These -are the paths where divergence could hide; the agent-streaming portion is shared -and is covered separately. - -Determinism is enforced by: - * freezing ``time.time`` (so ``turn_id = f"{chat_id}:{ms}"`` is stable), - * a deterministic ``uuid`` sequence for the streaming-service id generators, - * stubbing every DB/LLM/agent seam (LLM resolution, persistence, connector, - checkpointer, session) to fixed values. - -Cutover gate: when these are green, the live callers can be flipped to the -flows orchestrators. -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any -from unittest.mock import AsyncMock, MagicMock - -import pytest - -import app.services.new_streaming_service as _nss -from app.tasks.chat.stream_new_chat import ( - stream_new_chat as old_stream_new_chat, - stream_resume_chat as old_stream_resume_chat, -) -from app.tasks.chat.streaming.flows import ( - stream_new_chat as new_stream_new_chat, - stream_resume_chat as new_stream_resume_chat, -) - -pytestmark = pytest.mark.unit - -_FIXED_EPOCH = 1_700_000_000.0 # -> turn_id ":1700000000000" - - -# --------------------------------------------------------------------------- # -# Deterministic uuid for the streaming-service id generators -# --------------------------------------------------------------------------- # - - -class _SeqUUID: - """Drop-in for the ``uuid`` module used by ``new_streaming_service``. - - Only ``uuid4().hex`` is consumed by the id generators. We hand out a - monotonic, zero-padded hex so two runs that emit the same number of ids in - the same order produce identical bytes. - """ - - def __init__(self) -> None: - self._n = 0 - - def reset(self) -> None: - self._n = 0 - - def uuid4(self) -> SimpleNamespace: - self._n += 1 - return SimpleNamespace(hex=f"{self._n:032x}") - - -_SEQ = _SeqUUID() - - -# --------------------------------------------------------------------------- # -# Fake session: the orchestrator owns ``async_session_maker()``; for the glue -# paths every real consumer is stubbed, so a no-op session suffices. -# --------------------------------------------------------------------------- # - - -class _FakeResult: - """Empty-everything SQLAlchemy ``Result`` stand-in for pre-stream reads.""" - - def scalars(self) -> "_FakeResult": - return self - - def first(self) -> None: - return None - - def all(self) -> list[Any]: - return [] - - def one_or_none(self) -> None: - return None - - def scalar_one_or_none(self) -> None: - return None - - def scalar(self) -> None: - return None - - def fetchall(self) -> list[Any]: - return [] - - def __iter__(self): - return iter(()) - - -class _FakeSession: - async def commit(self) -> None: # pragma: no cover - trivial - return None - - async def rollback(self) -> None: # pragma: no cover - trivial - return None - - async def close(self) -> None: # pragma: no cover - trivial - return None - - def expunge_all(self) -> None: # pragma: no cover - trivial - return None - - def add(self, *a: Any, **k: Any) -> None: # pragma: no cover - trivial - return None - - async def flush(self, *a: Any, **k: Any) -> None: # pragma: no cover - return None - - async def execute(self, *a: Any, **k: Any) -> _FakeResult: - return _FakeResult() - - -class _FakeConnectorService: - def __init__(self, *a: Any, **k: Any) -> None: - pass - - async def get_connector_by_type(self, *a: Any, **k: Any) -> None: - return None - - -def _patch(monkeypatch: pytest.MonkeyPatch, target: str, value: Any) -> None: - """``setattr`` that tolerates a missing attr (binding may be local-import).""" - monkeypatch.setattr(target, value, raising=False) - - -def _apply_common( - monkeypatch: pytest.MonkeyPatch, - *, - pin_raises: ValueError | None = None, - resolved_id: int = -1, - llm_load_ok: bool = True, - persist_user_id: int | None = 101, - persist_assistant_id: int | None = 102, -) -> None: - """Patch every glue seam in BOTH implementations to deterministic values.""" - # Time -> stable turn_id and any retry_after_at. - monkeypatch.setattr("time.time", lambda: _FIXED_EPOCH) - - # Deterministic streaming-service ids. - monkeypatch.setattr(_nss, "uuid", _SEQ) - - fake_model = MagicMock(name="scripted_llm") - - # --- session --- - for tgt in ( - "app.tasks.chat.stream_new_chat.async_session_maker", - "app.tasks.chat.streaming.flows.new_chat.orchestrator.async_session_maker", - "app.tasks.chat.streaming.flows.resume_chat.orchestrator.async_session_maker", - ): - _patch(monkeypatch, tgt, _FakeSession) - - # --- connector service --- - for tgt in ( - "app.tasks.chat.stream_new_chat.ConnectorService", - "app.tasks.chat.streaming.flows.shared.pre_stream_setup.ConnectorService", - ): - _patch(monkeypatch, tgt, _FakeConnectorService) - - # --- checkpointer --- - for tgt in ( - "app.tasks.chat.stream_new_chat.get_checkpointer", - "app.tasks.chat.streaming.flows.shared.pre_stream_setup.get_checkpointer", - ): - _patch(monkeypatch, tgt, AsyncMock(return_value=MagicMock(name="checkpointer"))) - - # --- agent factory (built but never streamed on glue paths) --- - # Resume routing awaits ``agent.aget_state`` before persist, so the fake - # agent exposes async state accessors returning an empty (no-interrupt) - # snapshot. ``astream_events`` is never reached on glue paths. - fake_agent = MagicMock(name="agent") - fake_agent.aget_state = AsyncMock( - return_value=SimpleNamespace(values={}, tasks=[], interrupts=[], next=()) - ) - fake_agent.aupdate_state = AsyncMock(return_value=None) - agent_factory = AsyncMock(return_value=fake_agent) - for tgt in ( - "app.tasks.chat.stream_new_chat.create_multi_agent_chat_deep_agent", - "app.tasks.chat.streaming.flows.new_chat.orchestrator.create_multi_agent_chat_deep_agent", - "app.tasks.chat.streaming.flows.resume_chat.orchestrator.create_multi_agent_chat_deep_agent", - ): - _patch(monkeypatch, tgt, agent_factory) - - # --- LLM resolution (auto-pin) --- - if pin_raises is not None: - async def _resolver(*a: Any, **k: Any): - raise pin_raises - else: - async def _resolver(*a: Any, **k: Any): - return SimpleNamespace(resolved_llm_config_id=resolved_id) - - _patch(monkeypatch, "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id", _resolver) - _patch(monkeypatch, "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id", _resolver) - _patch( - monkeypatch, - "app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id", - _resolver, - ) - - # --- LLM bundle --- - sentinel_cfg = object() if llm_load_ok else None - _patch(monkeypatch, "app.tasks.chat.stream_new_chat.load_global_llm_config_by_id", lambda cid: sentinel_cfg) - _patch( - monkeypatch, - "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id", - lambda cid: sentinel_cfg, - ) - _patch(monkeypatch, "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", lambda cfg: fake_model) - _patch( - monkeypatch, - "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config", - lambda cfg: fake_model, - ) - # agent_config := None keeps premium + capability gates inert and identical. - from app.agents.shared.llm_config import AgentConfig - - monkeypatch.setattr(AgentConfig, "from_yaml_config", staticmethod(lambda cfg: None)) - - # --- persistence --- - async def _persist_user(*a: Any, **k: Any): - return persist_user_id - - async def _persist_assistant(*a: Any, **k: Any): - return persist_assistant_id - - async def _finalize(*a: Any, **k: Any): - return None - - for mod in ( - "app.tasks.chat.persistence", - "app.tasks.chat.streaming.flows.new_chat.persistence_spawn", - ): - _patch(monkeypatch, f"{mod}.persist_user_turn", _persist_user) - _patch(monkeypatch, f"{mod}.persist_assistant_shell", _persist_assistant) - # Resume binds ``persist_assistant_shell`` in its own assistant_shell module. - _patch( - monkeypatch, - "app.tasks.chat.streaming.flows.resume_chat.assistant_shell.persist_assistant_shell", - _persist_assistant, - ) - _patch(monkeypatch, "app.tasks.chat.persistence.finalize_assistant_turn", _finalize) - - # --- collaboration flags --- - async def _noop(*a: Any, **k: Any): - return None - - for tgt in ( - "app.tasks.chat.stream_new_chat.set_ai_responding", - "app.tasks.chat.stream_new_chat.clear_ai_responding", - "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.set_ai_responding", - "app.services.chat_session_state_service.set_ai_responding", - "app.services.chat_session_state_service.clear_ai_responding", - ): - _patch(monkeypatch, tgt, _noop) - - -async def _collect(genfunc: Any, **kwargs: Any) -> list[str]: - frames: list[str] = [] - async for frame in genfunc(**kwargs): - frames.append(frame) - return frames - - -async def _run_both(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]: - """Drive both NEW-chat implementations on identical inputs.""" - _SEQ.reset() - old = await _collect(old_stream_new_chat, **kwargs) - _SEQ.reset() - new = await _collect(new_stream_new_chat, **kwargs) - return old, new - - -async def _run_both_resume(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]: - """Drive both RESUME-chat implementations on identical inputs.""" - _SEQ.reset() - old = await _collect(old_stream_resume_chat, **kwargs) - _SEQ.reset() - new = await _collect(new_stream_resume_chat, **kwargs) - return old, new - - -def _assert_parity(old: list[str], new: list[str]) -> None: - """Byte-for-byte equality with a readable first-divergence message.""" - for i, (a, b) in enumerate(zip(old, new, strict=False)): - assert a == b, f"frame[{i}] differs:\n old={a!r}\n new={b!r}" - assert len(old) == len(new), ( - f"frame count differs: old={len(old)} new={len(new)}\n" - f" old tail={old[len(new):]!r}\n new tail={new[len(old):]!r}" - ) - assert old[-1].strip() == "data: [DONE]" - - -# --------------------------------------------------------------------------- # -# NEW-chat scenarios -# --------------------------------------------------------------------------- # - -_NEW_KW = dict(user_query="hi", search_space_id=1, chat_id=42, user_id=None) - - -@pytest.mark.asyncio -async def test_auto_pin_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None: - """Auto-pin raises -> identical ``[error, DONE]`` from both.""" - _apply_common(monkeypatch, pin_raises=ValueError("no eligible config")) - old, new = await _run_both(dict(_NEW_KW)) - _assert_parity(old, new) - assert len(old) == 2 - assert '"errorCode": "SERVER_ERROR"' in old[0] - - -@pytest.mark.asyncio -async def test_llm_load_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None: - """LLM bundle load fails -> identical ``[error, DONE]`` from both.""" - _apply_common(monkeypatch, llm_load_ok=False) - old, new = await _run_both(dict(_NEW_KW)) - _assert_parity(old, new) - assert len(old) == 2 - assert '"errorCode": "SERVER_ERROR"' in old[0] - - -@pytest.mark.asyncio -async def test_persist_user_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None: - """User-turn persist returns None. - - Exercises the full initial-frame ordering (start, start-step, turn-info, - turn-status busy), the MESSAGE_PERSIST_FAILED error, and final frames. - """ - _apply_common(monkeypatch, persist_user_id=None) - old, new = await _run_both(dict(_NEW_KW)) - _assert_parity(old, new) - assert '"type": "start"' in old[0] - assert '"chat_turn_id": "42:1700000000000"' in old[2] - assert any('"errorCode": "MESSAGE_PERSIST_FAILED"' in f for f in old) - assert any('"type": "finish"' in f for f in old) - - -@pytest.mark.asyncio -async def test_persist_assistant_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None: - """Assistant-shell persist returns None. - - Adds the ``data-user-message-id`` handshake frame ahead of the error. - """ - _apply_common(monkeypatch, persist_user_id=101, persist_assistant_id=None) - old, new = await _run_both(dict(_NEW_KW)) - _assert_parity(old, new) - assert any('"data-user-message-id"' in f and '"message_id": 101' in f for f in old) - assert any('"errorCode": "MESSAGE_PERSIST_FAILED"' in f for f in old) - - -@pytest.mark.asyncio -async def test_prestream_exception_parity(monkeypatch: pytest.MonkeyPatch) -> None: - """A pre-stream failure routes both through the top-level ``except`` path. - - Resolver returns a non-int so ``turn_id`` math / downstream use raises after - the span opens but before initial frames: both must emit the identical - ``busy -> error -> idle -> finish-step -> finish -> DONE`` terminal sequence. - """ - - async def _bad_resolver(*a: Any, **k: Any): - raise RuntimeError("boom in pre-stream") - - _apply_common(monkeypatch) - # Override the resolver with a non-ValueError so the classified early-error - # branches don't catch it -> top-level except path. - for tgt in ( - "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id", - "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id", - "app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id", - ): - _patch(monkeypatch, tgt, _bad_resolver) - old, new = await _run_both(dict(_NEW_KW)) - _assert_parity(old, new) - assert any('"type": "error"' in f for f in old) - - -# --------------------------------------------------------------------------- # -# RESUME-chat scenarios (no title-generation path -> fully deterministic) -# --------------------------------------------------------------------------- # - -_RESUME_KW = dict(chat_id=42, search_space_id=1, decisions=[], user_id=None) - - -async def _collect_resume_old() -> list[str]: - _SEQ.reset() - return await _collect(old_stream_resume_chat, **dict(_RESUME_KW)) - - -# NOTE: KNOWN, INTENTIONAL DIVERGENCE (flows fixes a latent monolith bug). -# -# In ``stream_resume_chat`` the monolith defines ``_resume_premium_request_id`` -# (line ~2363) AFTER the auto-pin / LLM-load early-return points (~2346 / ~2356). -# Its ``finally`` block (line ~2918) reads that variable, so a resume turn whose -# auto-pin raises or whose LLM bundle fails to load crashes with -# ``UnboundLocalError`` instead of emitting a clean terminal-error frame. The -# refactored flows orchestrator does NOT have this bug — it emits the proper -# ``[error, DONE]`` sequence. We assert the divergence explicitly so the cutover -# is a documented behavior IMPROVEMENT rather than a silent change. - - -@pytest.mark.asyncio -async def test_resume_auto_pin_failure_flows_fixes_monolith_crash( - monkeypatch: pytest.MonkeyPatch, -) -> None: - _apply_common(monkeypatch, pin_raises=ValueError("no eligible config")) - # Monolith: latent UnboundLocalError in the finally clause. - with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"): - await _collect_resume_old() - # Flows: clean terminal error. - _SEQ.reset() - new = await _collect(new_stream_resume_chat, **dict(_RESUME_KW)) - assert len(new) == 2 - assert new[-1].strip() == "data: [DONE]" - assert '"type": "error"' in new[0] - - -@pytest.mark.asyncio -async def test_resume_llm_load_failure_flows_fixes_monolith_crash( - monkeypatch: pytest.MonkeyPatch, -) -> None: - _apply_common(monkeypatch, llm_load_ok=False) - with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"): - await _collect_resume_old() - _SEQ.reset() - new = await _collect(new_stream_resume_chat, **dict(_RESUME_KW)) - assert len(new) == 2 - assert new[-1].strip() == "data: [DONE]" - assert '"type": "error"' in new[0] - - -@pytest.mark.asyncio -async def test_resume_persist_assistant_failure_parity( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Resume emits NO user-message-id frame; only the assistant handshake path.""" - _apply_common(monkeypatch, persist_assistant_id=None) - old, new = await _run_both_resume(dict(_RESUME_KW)) - _assert_parity(old, new) - assert not any('"data-user-message-id"' in f for f in old) - assert any('"chat_turn_id": "42:1700000000000"' in f for f in old) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py deleted file mode 100644 index 3a9a834f9..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py +++ /dev/null @@ -1,584 +0,0 @@ -"""Parity gate for the parallel refactor of ``stream_new_chat.py``. - -The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with -the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over -atomically. This file pins externally-observable behaviour at module -boundaries so a divergence between the two trees fails loudly *before* the -cutover. - -What we verify: - - 1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from - the new tree have the same call signature as the originals. - 2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the - same outputs as the inline code in the legacy file for representative - inputs (initial thinking step, image-capability gate, runtime context, - SSE frame sequences, token-usage frame shape, persistence guards). - 3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` / - ``can_recover_provider_rate_limit`` exist and are addressable. - -Delete this file along with ``stream_new_chat.py`` once the cutover is done -(see the parent refactor plan). -""" - -from __future__ import annotations - -import asyncio -import inspect -from typing import Any -from unittest.mock import AsyncMock, patch - -import pytest - -from app.agents.shared.context import SurfSenseContextSchema -from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_new_chat import ( - stream_new_chat as old_stream_new_chat, - stream_resume_chat as old_stream_resume_chat, -) -from app.tasks.chat.streaming.flows import ( - stream_new_chat as new_stream_new_chat, - stream_resume_chat as new_stream_resume_chat, -) -from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import ( - build_initial_thinking_step, -) -from app.tasks.chat.streaming.flows.new_chat.llm_capability import ( - check_image_input_capability, -) -from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import ( - await_persist_task, - spawn_persist_assistant_shell_task, - spawn_persist_user_task, - spawn_set_ai_responding_bg, -) -from app.tasks.chat.streaming.flows.new_chat.runtime_context import ( - build_new_chat_runtime_context, -) -from app.tasks.chat.streaming.flows.resume_chat.runtime_context import ( - build_resume_chat_runtime_context, -) -from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame -from app.tasks.chat.streaming.flows.shared.first_frames import ( - iter_final_frames, - iter_initial_frames, -) -from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle -from app.tasks.chat.streaming.flows.shared.premium_quota import ( - PremiumReservation, - needs_premium_quota, -) -from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( - can_recover_provider_rate_limit, -) - -pytestmark = pytest.mark.unit - - -# --------------------------------------------------------------------- signature - - -def _normalize_annotation(ann: Any) -> str: - """Compare-friendly form for an annotation. - - The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import - annotations``, so its annotations are evaluated at import time and come - back as type objects / typing generics. The new tree DOES use it, so its - annotations are PEP-563 strings. - - Both reprs describe the same types — strip the module prefixes / typing - namespace + the ```` wrapper so we compare the canonical - declared form. - """ - if ann is inspect.Signature.empty: - return "" - raw = ann if isinstance(ann, str) else repr(ann) - cleaned = ( - raw.replace("typing.", "") - .replace("collections.abc.", "") - .replace("app.db.", "") - .replace("app.agents.shared.filesystem_selection.", "") - .replace("app.agents.shared.context.", "") - ) - # Unwrap ```` → ``int`` (legacy-side type objects). - if cleaned.startswith(""): - cleaned = cleaned[len("")] - return cleaned - - -def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]: - return [ - (p.name, p.default, _normalize_annotation(p.annotation)) - for p in sig.parameters.values() - ] - - -def test_stream_new_chat_signature_matches_legacy() -> None: - old = inspect.signature(old_stream_new_chat) - new = inspect.signature(new_stream_new_chat) - assert _normalize_sig(new) == _normalize_sig(old) - assert _normalize_annotation(new.return_annotation) == _normalize_annotation( - old.return_annotation - ) - - -def test_stream_resume_chat_signature_matches_legacy() -> None: - old = inspect.signature(old_stream_resume_chat) - new = inspect.signature(new_stream_resume_chat) - assert _normalize_sig(new) == _normalize_sig(old) - assert _normalize_annotation(new.return_annotation) == _normalize_annotation( - old.return_annotation - ) - - -def test_orchestrators_are_async_generator_functions() -> None: - assert inspect.isasyncgenfunction(new_stream_new_chat) - assert inspect.isasyncgenfunction(new_stream_resume_chat) - - -# ------------------------------------------------------------ initial thinking - - -@pytest.mark.parametrize( - "user_query, image_urls, expected_title, expected_action", - [ - ("hello world", None, "Understanding your request", "Processing"), - ( - "", - ["data:image/png;base64,AAA"], - "Understanding your request", - "Processing", - ), - ("", None, "Understanding your request", "Processing"), - ], -) -def test_initial_thinking_step_branches( - user_query: str, - image_urls: list[str] | None, - expected_title: str, - expected_action: str, -) -> None: - step = build_initial_thinking_step( - user_query=user_query, - user_image_data_urls=image_urls, - ) - assert step.step_id == "thinking-1" - assert step.title == expected_title - assert len(step.items) == 1 - assert step.items[0].startswith(f"{expected_action}: ") - - -def test_initial_thinking_step_truncates_long_query() -> None: - long_query = "x" * 200 - step = build_initial_thinking_step( - user_query=long_query, - user_image_data_urls=None, - ) - # 80-char truncation + ellipsis, sandwiched after "Processing: ". - assert "..." in step.items[0] - item = step.items[0] - payload = item[len("Processing: ") :] - assert payload.startswith("x" * 80) and payload.endswith("...") - - -# ------------------------------------------------------------ capability gate - - -def test_image_capability_passes_without_images() -> None: - assert ( - check_image_input_capability(user_image_data_urls=None, agent_config=None) - is None - ) - - -def test_image_capability_passes_when_capability_unknown() -> None: - """Unknown / unmapped models are not blocked — only models LiteLLM has - *explicitly* marked text-only trip the gate.""" - - class _AgentConfig: - provider = "openrouter" - model_name = "unknown-mystery-model" - custom_provider = None - config_name = "Unknown" - litellm_params: dict[str, Any] = {} - - with patch( - "app.services.provider_capabilities.is_known_text_only_chat_model", - return_value=False, - ): - assert ( - check_image_input_capability( - user_image_data_urls=["data:image/png;base64,AAA"], - agent_config=_AgentConfig(), # type: ignore[arg-type] - ) - is None - ) - - -def test_image_capability_blocks_known_text_only_models() -> None: - class _AgentConfig: - provider = "openai" - model_name = "gpt-3.5-turbo" - custom_provider = None - config_name = "GPT-3.5" - litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"} - - with patch( - "app.services.provider_capabilities.is_known_text_only_chat_model", - return_value=True, - ): - result = check_image_input_capability( - user_image_data_urls=["data:image/png;base64,AAA"], - agent_config=_AgentConfig(), # type: ignore[arg-type] - ) - assert result is not None - message, error_code = result - assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" - assert "GPT-3.5" in message - - -# ---------------------------------------------------------------- runtime ctx - - -def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None: - """Post-resolve accepted folder ids win over the raw requested ids.""" - ctx = build_new_chat_runtime_context( - search_space_id=7, - mentioned_document_ids=[1, 2], - accepted_folder_ids=[10], - mentioned_folder_ids=[20, 30], - mentioned_connector_ids=None, - mentioned_connectors=None, - request_id="req", - turn_id="t1", - ) - assert isinstance(ctx, SurfSenseContextSchema) - assert ctx.search_space_id == 7 - assert list(ctx.mentioned_document_ids) == [1, 2] - assert list(ctx.mentioned_folder_ids) == [10] - assert ctx.request_id == "req" - assert ctx.turn_id == "t1" - - -def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None: - """With no accepted ids, the raw requested folder ids flow through.""" - ctx = build_new_chat_runtime_context( - search_space_id=7, - mentioned_document_ids=None, - accepted_folder_ids=[], - mentioned_folder_ids=[20, 30], - mentioned_connector_ids=None, - mentioned_connectors=None, - request_id=None, - turn_id="t2", - ) - assert list(ctx.mentioned_folder_ids) == [20, 30] - - -def test_new_chat_runtime_context_propagates_connector_mentions() -> None: - """@-selected connector ids/accounts ride onto the runtime context schema. - - Parity with the legacy ``stream_new_chat`` runtime context, which set both - ``mentioned_connector_ids`` and ``mentioned_connectors`` on the schema. - """ - connectors = [{"id": 5, "connector_type": "SLACK_CONNECTOR", "title": "acme"}] - ctx = build_new_chat_runtime_context( - search_space_id=7, - mentioned_document_ids=None, - accepted_folder_ids=[], - mentioned_folder_ids=None, - mentioned_connector_ids=[5], - mentioned_connectors=connectors, - request_id=None, - turn_id="t3", - ) - assert list(ctx.mentioned_connector_ids) == [5] - assert list(ctx.mentioned_connectors) == connectors - - -def test_resume_chat_runtime_context_empty_mention_lists() -> None: - ctx = build_resume_chat_runtime_context( - search_space_id=42, request_id="req-r", turn_id="t-r" - ) - assert ctx.search_space_id == 42 - assert ctx.request_id == "req-r" - assert ctx.turn_id == "t-r" - - -# ---------------------------------------------------------------- SSE frames - - -def test_iter_initial_frames_emits_canonical_sequence() -> None: - svc = VercelStreamingService() - frames = list(iter_initial_frames(svc, turn_id="42:1700000000000")) - # Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy). - assert len(frames) == 4 - assert "42:1700000000000" in frames[2] - assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3] - - -def test_iter_final_frames_emits_idle_then_finish_done() -> None: - svc = VercelStreamingService() - frames = list(iter_final_frames(svc)) - assert len(frames) == 4 - assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0] - - -# ----------------------------------------------------------- token usage frame - - -class _FakeAccumulator: - """Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads.""" - - def __init__(self, summary: Any = None) -> None: - self._summary = summary - self.calls = [1, 2, 3] - self.grand_total = 100 - self.total_cost_micros = 50_000 - self.total_prompt_tokens = 60 - self.total_completion_tokens = 40 - - def per_message_summary(self) -> Any: - return self._summary - - def serialized_calls(self) -> list[Any]: - return list(self.calls) - - -def test_token_usage_frame_skipped_when_no_summary() -> None: - svc = VercelStreamingService() - frames = list( - iter_token_usage_frame( - svc, - accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type] - log_label="parity-empty", - ) - ) - assert frames == [] - - -def test_token_usage_frame_emitted_when_summary_present() -> None: - svc = VercelStreamingService() - frames = list( - iter_token_usage_frame( - svc, - accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type] - log_label="parity-populated", - ) - ) - assert len(frames) == 1 - # Field shape on the wire is fixed by the FE; assert each surfaces. - payload = frames[0] - for key in ( - '"prompt_tokens":60', - '"completion_tokens":40', - '"total_tokens":100', - '"cost_micros":50000', - ): - assert key in payload.replace(" ", "") - - -# ------------------------------------------------------------------ llm_bundle - - -def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None: - async def _run() -> tuple[Any, Any, str | None]: - with ( - patch( - "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id", - return_value=None, - ), - ): - return await load_llm_bundle( - session=AsyncMock(), # type: ignore[arg-type] - config_id=-1, - search_space_id=7, - ) - - llm, agent_config, error = asyncio.run(_run()) - assert llm is None - assert agent_config is None - assert error is not None and "id -1" in error - - -def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None: - async def _run() -> tuple[Any, Any, str | None]: - with ( - patch( - "app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config", - new=AsyncMock(return_value=None), - ), - ): - return await load_llm_bundle( - session=AsyncMock(), # type: ignore[arg-type] - config_id=12, - search_space_id=7, - ) - - llm, agent_config, error = asyncio.run(_run()) - assert llm is None - assert agent_config is None - assert error is not None and "id 12" in error - - -# ----------------------------------------------------------------- premium quota - - -def test_needs_premium_quota_requires_user_and_premium_flag() -> None: - class _AgentConfig: - is_premium = True - - class _NonPremium: - is_premium = False - - assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type] - assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type] - assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type] - assert needs_premium_quota(None, "user-1") is False - - -def test_premium_reservation_dataclass_shape() -> None: - # Sanity: the dataclass exists and carries the fields the orchestrator uses. - r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True) - assert r.request_id == "abc" - assert r.reserved_micros == 100 - assert r.allowed is True - - -# ----------------------------------------------------------- rate-limit guard - - -@pytest.mark.parametrize( - "first_event_seen, recovered, requested_id, current_id, expected", - [ - (False, False, 0, -1, True), - # Already recovered: no second pass. - (False, True, 0, -1, False), - # User explicitly picked a config: don't silently switch. - (False, False, 5, -1, False), - # Already on a database-backed (positive) id. - (False, False, 0, 7, False), - # User has already seen output: silent rebuild not possible. - (True, False, 0, -1, False), - ], -) -def test_can_recover_provider_rate_limit_truth_table( - first_event_seen: bool, - recovered: bool, - requested_id: int, - current_id: int, - expected: bool, -) -> None: - # Use a known rate-limit-shaped exception so the helper's last condition - # is satisfied; the guard only short-circuits to False when one of the - # *other* preconditions fails. - exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}') - assert ( - can_recover_provider_rate_limit( - exc, - first_event_seen=first_event_seen, - runtime_rate_limit_recovered=recovered, - requested_llm_config_id=requested_id, - current_llm_config_id=current_id, - ) - is expected - ) - - -def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None: - assert ( - can_recover_provider_rate_limit( - ValueError("not a rate limit"), - first_event_seen=False, - runtime_rate_limit_recovered=False, - requested_llm_config_id=0, - current_llm_config_id=-1, - ) - is False - ) - - -# --------------------------------------------------------- persistence spawn - - -def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None: - async def _run() -> set[asyncio.Task]: - background: set[asyncio.Task] = set() - spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background) - return background - - bg = asyncio.run(_run()) - assert bg == set() - - -def test_spawn_persist_user_task_registers_and_self_unregisters() -> None: - async def _run() -> tuple[int, int]: - background: set[asyncio.Task] = set() - with patch( - "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn", - new=AsyncMock(return_value=99), - ): - task = spawn_persist_user_task( - chat_id=1, - user_id="u", - turn_id="t", - user_query="hi", - user_image_data_urls=None, - mentioned_documents=None, - background_tasks=background, - ) - size_before_await = len(background) - result = await asyncio.shield(task) - # Give the done-callback one event-loop tick to run. - await asyncio.sleep(0) - return size_before_await, result # type: ignore[return-value] - - size_before, result = asyncio.run(_run()) - assert size_before == 1 - assert result == 99 - - -def test_spawn_persist_assistant_shell_task_registers() -> None: - async def _run() -> int | None: - background: set[asyncio.Task] = set() - with patch( - "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell", - new=AsyncMock(return_value=42), - ): - task = spawn_persist_assistant_shell_task( - chat_id=1, - user_id="u", - turn_id="t", - background_tasks=background, - ) - return await asyncio.shield(task) - - assert asyncio.run(_run()) == 42 - - -def test_await_persist_task_returns_none_on_failure() -> None: - async def _run() -> int | None: - async def _boom() -> int: - raise RuntimeError("DB down") - - task = asyncio.create_task(_boom()) - return await await_persist_task( - task, - chat_id=1, - turn_id="t", - log_label="parity-failure", - ) - - assert asyncio.run(_run()) is None - - -def test_await_persist_task_returns_none_for_none_input() -> None: - async def _run() -> int | None: - return await await_persist_task( - None, - chat_id=1, - turn_id="t", - log_label="parity-none", - ) - - assert asyncio.run(_run()) is None diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py deleted file mode 100644 index a4bd1d56c..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Pin Stage 1 extractions as faithful copies of the old helpers. - -Extractions under ``app.tasks.chat.streaming`` are compared to -``app.tasks.chat.stream_new_chat`` helpers. -For each Stage 1 extraction we assert the new function returns the same -output as the old one for a representative input set. The moment the -two diverge - intentionally or otherwise - this file fails loudly so -the divergence is reviewed rather than shipped silently. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from typing import Any - -import pytest - -from app.agents.shared.errors import BusyError -from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel -from app.tasks.chat.stream_new_chat import ( - _classify_stream_exception as old_classify, - _emit_stream_terminal_error as old_emit_terminal_error, - _extract_chunk_parts as old_extract_chunk_parts, - _extract_resolved_file_path as old_extract_resolved_file_path, - _tool_output_has_error as old_tool_output_has_error, - _tool_output_to_text as old_tool_output_to_text, -) -from app.tasks.chat.streaming.errors.classifier import ( - classify_stream_exception as new_classify, -) -from app.tasks.chat.streaming.errors.emitter import ( - emit_stream_terminal_error as new_emit_terminal_error, -) -from app.tasks.chat.streaming.helpers.chunk_parts import ( - extract_chunk_parts as new_extract_chunk_parts, -) -from app.tasks.chat.streaming.helpers.tool_output import ( - extract_resolved_file_path as new_extract_resolved_file_path, - tool_output_has_error as new_tool_output_has_error, - tool_output_to_text as new_tool_output_to_text, -) - -pytestmark = pytest.mark.unit - - -# ---------------------------------------------------------------- chunk parts - - -@dataclass -class _Chunk: - content: Any = "" - additional_kwargs: dict[str, Any] = field(default_factory=dict) - tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) - - -_CHUNK_CASES: list[Any] = [ - None, - _Chunk(content=""), - _Chunk(content="hello"), - _Chunk(content=42), # invalid type, defensively coerced to empty - _Chunk( - content=[ - {"type": "text", "text": "Hello "}, - {"type": "text", "text": "world"}, - ] - ), - _Chunk( - content=[ - {"type": "reasoning", "reasoning": "hmm "}, - {"type": "reasoning", "text": "still"}, - {"type": "text", "text": "answer"}, - ] - ), - _Chunk( - content=[ - {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"}, - {"type": "tool_use", "id": "c2", "name": "y"}, - {"type": "image_url", "url": "ignored"}, - ] - ), - _Chunk( - content="visible", - additional_kwargs={"reasoning_content": "private"}, - ), - _Chunk( - tool_call_chunks=[ - {"id": None, "name": None, "args": '{"a":1}', "index": 0}, - {"id": "c", "name": "n", "args": "}", "index": 0}, - ] - ), - _Chunk( - content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}], - tool_call_chunks=[{"id": "from-attr", "name": "y"}], - ), -] - - -@pytest.mark.parametrize("chunk", _CHUNK_CASES) -def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: - assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) - - -# ----------------------------------------------------------- error classifier - - -def _classify_cases() -> list[Exception]: - """Inputs that the FE depends on being mapped to specific error codes.""" - return [ - Exception("totally generic error"), - Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'), - Exception( - 'OpenrouterException - {"error":{"message":"Provider returned error",' - '"code":429}}' - ), - BusyError(request_id="thread-busy-parity"), - Exception("Thread is busy with another request"), - ] - - -@pytest.mark.parametrize("exc", _classify_cases()) -def test_classify_stream_exception_matches_old_implementation( - exc: Exception, -) -> None: - new = new_classify(exc, flow_label="parity-test") - old = old_classify(exc, flow_label="parity-test") - # Strip the wall-clock retry timestamp before comparing — both - # implementations call ``time.time()`` independently and the call - # order is enough to differ by 1 ms in practice. Every other field - # in the tuple must match exactly. - new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5] - old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5] - if isinstance(new_extra, dict) and isinstance(old_extra, dict): - new_extra.pop("retry_after_at", None) - old_extra.pop("retry_after_at", None) - assert new[:5] == old[:5] - assert new_extra == old_extra - - -def test_classify_turn_cancelling_branch_parity() -> None: - """The TURN_CANCELLING branch reads cancel state for the busy thread id; - both implementations must agree on retry-window semantics, not just the - plain THREAD_BUSY code.""" - thread_id = "parity-cancelling-thread" - reset_cancel(thread_id) - request_cancel(thread_id) - exc = BusyError(request_id=thread_id) - new = new_classify(exc, flow_label="parity-test") - old = old_classify(exc, flow_label="parity-test") - assert new[0] == old[0] == "thread_busy" - assert new[1] == old[1] == "TURN_CANCELLING" - assert isinstance(new[5], dict) and isinstance(old[5], dict) - assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"] - - -# ------------------------------------------------------------ terminal emitter - - -class _FakeStreamingService: - """Duck-types ``format_error`` for both old and new emitters.""" - - def __init__(self) -> None: - self.calls: list[dict[str, Any]] = [] - - def format_error( - self, message: str, *, error_code: str, extra: dict[str, Any] | None = None - ) -> str: - self.calls.append( - {"message": message, "error_code": error_code, "extra": extra} - ) - return f'data: {{"type":"error","errorText":"{message}"}}\n\n' - - -def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None: - """The new emitter must produce the same SSE frame and log the same - structured payload as the old one for the same arguments.""" - args: dict[str, Any] = { - "flow": "new", - "request_id": "req-parity", - "thread_id": 7, - "search_space_id": 9, - "user_id": "user-parity", - "message": "boom", - "error_kind": "server_error", - "error_code": "SERVER_ERROR", - "severity": "error", - "is_expected": False, - "extra": {"foo": "bar"}, - } - - new_svc = _FakeStreamingService() - old_svc = _FakeStreamingService() - - with caplog.at_level(logging.ERROR): - new_frame = new_emit_terminal_error(streaming_service=new_svc, **args) - old_frame = old_emit_terminal_error(streaming_service=old_svc, **args) - - assert new_frame == old_frame - assert new_svc.calls == old_svc.calls - chat_error_records = [ - r for r in caplog.records if "[chat_stream_error]" in r.message - ] - # One log line per emit call (two emits -> two records). - assert len(chat_error_records) == 2 - - -# ---------------------------------------------------------------- tool output - - -def test_tool_output_helpers_match_old_implementation() -> None: - samples: list[Any] = [ - {"result": "ok"}, - {"error": "bad"}, - {"result": "Error: x"}, - "Error: plain", - "fine", - {"nested": {"a": 1}}, - ] - for s in samples: - assert new_tool_output_to_text(s) == old_tool_output_to_text(s) - assert new_tool_output_has_error(s) == old_tool_output_has_error(s) - - assert new_extract_resolved_file_path( - tool_name="write_file", - tool_output={"path": " /tmp/x "}, - tool_input=None, - ) == old_extract_resolved_file_path( - tool_name="write_file", - tool_output={"path": " /tmp/x "}, - tool_input=None, - ) - assert new_extract_resolved_file_path( - tool_name="write_file", - tool_output={}, - tool_input={"file_path": " /fallback "}, - ) == old_extract_resolved_file_path( - tool_name="write_file", - tool_output={}, - tool_input={"file_path": " /fallback "}, - ) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py deleted file mode 100644 index 3ee1ab622..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events).""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match -from app.tasks.chat.streaming.handlers.custom_events import ( - handle_action_log, - handle_action_log_updated, - handle_document_created, - handle_report_progress, -) -from app.tasks.chat.streaming.helpers.tool_call_matching import ( - match_buffered_langchain_tool_call_id as new_legacy_match, -) -from app.tasks.chat.streaming.relay.state import AgentEventRelayState -from app.tasks.chat.streaming.relay.thinking_step_completion import ( - complete_active_thinking_step, -) -from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame - -pytestmark = pytest.mark.unit - - -def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [dict(x) for x in raw] - - -def test_legacy_tool_call_match_matches_old_implementation() -> None: - cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [ - ( - [ - {"name": "write_file", "id": "lc-a"}, - {"name": "other", "id": "lc-b"}, - ], - "write_file", - "run-1", - {}, - ), - ( - [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}], - "write_file", - "run-2", - {}, - ), - ([{"name": "no_id"}], "write_file", "run-3", {}), - ] - for chunks_template, tool_name, run_id, lc_map_seed in cases: - old_chunks = _copy_chunk_buffer(chunks_template) - new_chunks = _copy_chunk_buffer(chunks_template) - old_map = dict(lc_map_seed) - new_map = dict(lc_map_seed) - old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map) - new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map) - assert new_out == old_out - assert new_chunks == old_chunks - assert new_map == old_map - - -def test_emit_thinking_step_frame_invokes_builder_before_service() -> None: - order: list[str] = [] - builder = MagicMock() - - def on_ts(*args: Any, **kwargs: Any) -> None: - order.append("builder") - - builder.on_thinking_step.side_effect = on_ts - - svc = MagicMock() - - def fmt(**kwargs: Any) -> str: - order.append("service") - return "frame" - - svc.format_thinking_step.side_effect = fmt - - out = emit_thinking_step_frame( - streaming_service=svc, - content_builder=builder, - step_id="thinking-1", - title="Working", - status="in_progress", - items=["a"], - ) - assert out == "frame" - assert order == ["builder", "service"] - builder.on_thinking_step.assert_called_once() - svc.format_thinking_step.assert_called_once() - - -def test_emit_thinking_step_frame_skips_builder_when_none() -> None: - svc = MagicMock(return_value="x") - svc.format_thinking_step.return_value = "frame" - assert ( - emit_thinking_step_frame( - streaming_service=svc, - content_builder=None, - step_id="s", - title="t", - ) - == "frame" - ) - svc.format_thinking_step.assert_called_once() - - -def test_complete_active_thinking_step_mirrors_closure_semantics() -> None: - svc = MagicMock() - svc.format_thinking_step.return_value = "done-frame" - completed: set[str] = set() - relay_state = AgentEventRelayState.for_invocation() - - frame, new_id = complete_active_thinking_step( - state=relay_state, - streaming_service=svc, - content_builder=None, - last_active_step_id="thinking-1", - last_active_step_title="T", - last_active_step_items=["x"], - completed_step_ids=completed, - ) - assert frame == "done-frame" - assert new_id is None - assert "thinking-1" in completed - - frame2, id2 = complete_active_thinking_step( - state=relay_state, - streaming_service=svc, - content_builder=None, - last_active_step_id="thinking-1", - last_active_step_title="T", - last_active_step_items=[], - completed_step_ids=completed, - ) - assert frame2 is None - assert id2 == "thinking-1" - - -def test_agent_event_relay_state_factory_matches_counter_rule() -> None: - s0 = AgentEventRelayState.for_invocation() - assert s0.thinking_step_counter == 0 - assert s0.last_active_step_id is None - - s1 = AgentEventRelayState.for_invocation( - initial_step_id="thinking-resume-1", - initial_step_title="Inherited", - initial_step_items=["Topic: X"], - ) - assert s1.thinking_step_counter == 1 - assert s1.last_active_step_id == "thinking-resume-1" - assert s1.next_thinking_step_id("thinking") == "thinking-2" - - -@pytest.mark.parametrize( - ("phase", "message", "start_items", "expected_tail"), - [ - ( - "revising_section", - "progress line", - ["Topic: Foo", "Modifying bar", "stale..."], - ["Topic: Foo", "Modifying bar", "progress line"], - ), - ( - "other", - "phase msg", - ["Topic: Foo", "old line"], - ["Topic: Foo", "phase msg"], - ), - ], -) -def test_report_progress_items_match_reference( - phase: str, - message: str, - start_items: list[str], - expected_tail: list[str], -) -> None: - svc = MagicMock() - svc.format_thinking_step.return_value = "sse" - - items = list(start_items) - frame, new_items = handle_report_progress( - {"message": message, "phase": phase}, - last_active_step_id="step-1", - last_active_step_title="Report", - last_active_step_items=items, - streaming_service=svc, - content_builder=None, - ) - assert frame == "sse" - assert new_items == expected_tail - kwargs = svc.format_thinking_step.call_args.kwargs - assert kwargs["items"] == expected_tail - - -def test_report_progress_noop_when_missing_message_or_step() -> None: - svc = MagicMock() - items = ["Topic: A"] - f1, i1 = handle_report_progress( - {"message": "", "phase": "x"}, - last_active_step_id="s", - last_active_step_title="t", - last_active_step_items=items, - streaming_service=svc, - content_builder=None, - ) - assert f1 is None and i1 is items - - f2, i2 = handle_report_progress( - {"message": "m", "phase": "x"}, - last_active_step_id=None, - last_active_step_title="t", - last_active_step_items=items, - streaming_service=svc, - content_builder=None, - ) - assert f2 is None and i2 is items - - -def test_document_action_handlers_match_format_data_guards() -> None: - svc = MagicMock() - svc.format_data.return_value = "data-frame" - - assert handle_document_created({}, streaming_service=svc) is None - assert handle_document_created({"id": 0}, streaming_service=svc) is None - handle_document_created({"id": 42, "title": "x"}, streaming_service=svc) - svc.format_data.assert_called_with( - "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}} - ) - - svc.reset_mock() - assert handle_action_log({"id": None}, streaming_service=svc) is None - handle_action_log({"id": 1}, streaming_service=svc) - svc.format_data.assert_called_once_with("action-log", {"id": 1}) - - svc.reset_mock() - assert handle_action_log_updated({"id": None}, streaming_service=svc) is None - handle_action_log_updated({"id": 2}, streaming_service=svc) - svc.format_data.assert_called_once_with("action-log-updated", {"id": 2}) diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py index 1263a5fe1..0f154f1dc 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -1,4 +1,4 @@ -"""Unit tests for ``stream_new_chat._extract_chunk_parts``. +"""Unit tests for ``streaming.helpers.chunk_parts.extract_chunk_parts``. Earlier versions only handled ``isinstance(chunk.content, str)`` and silently dropped every other shape (Anthropic typed-block lists, @@ -14,7 +14,9 @@ from typing import Any import pytest -from app.tasks.chat.stream_new_chat import _extract_chunk_parts +from app.tasks.chat.streaming.helpers.chunk_parts import ( + extract_chunk_parts as _extract_chunk_parts, +) @dataclass diff --git a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py index 50f7b8070..7f8285e98 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py @@ -7,7 +7,7 @@ constructs a fresh :class:`AgentEventRelayState` with ``thinking_step_counter=0``), React renders sibling timeline rows with the same key — the warning the user reported in production. -The contract this module pins: each ``_stream_agent_events`` invocation must +The contract this module pins: each ``stream_agent_events`` invocation must receive a ``step_prefix`` that is unique within the thread (we salt with the per-turn ``turn_id``), so the resulting step IDs across consecutive turns are always disjoint. @@ -23,10 +23,12 @@ from typing import Any import pytest from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_new_chat import ( - StreamResult, - _resume_step_prefix, - _stream_agent_events, +from app.tasks.chat.streaming.agent.event_loop import ( + stream_agent_events as _stream_agent_events, +) +from app.tasks.chat.streaming.shared.stream_result import StreamResult +from app.tasks.chat.streaming.shared.utils import ( + resume_step_prefix as _resume_step_prefix, ) pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index ada32d168..f3f28eb1c 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -1,6 +1,6 @@ """Unit tests for live tool-call argument streaming. -Pins the wire format that ``_stream_agent_events`` emits: +Pins the wire format that ``stream_agent_events`` emits: ``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``, keyed consistently with LangChain ``tool_call.id`` when the model streams indexed chunks. @@ -20,11 +20,13 @@ from typing import Any import pytest from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_new_chat import ( - StreamResult, - _legacy_match_lc_id, - _stream_agent_events, +from app.tasks.chat.streaming.agent.event_loop import ( + stream_agent_events as _stream_agent_events, ) +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id as _legacy_match_lc_id, +) +from app.tasks.chat.streaming.shared.stream_result import StreamResult pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py deleted file mode 100644 index 8ff576e2d..000000000 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ /dev/null @@ -1,438 +0,0 @@ -import inspect -import json -import logging -import re -from pathlib import Path - -import pytest - -import app.tasks.chat.stream_new_chat as stream_new_chat_module -from app.agents.shared.errors import BusyError -from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel -from app.tasks.chat.stream_new_chat import ( - StreamResult, - _classify_stream_exception, - _contract_enforcement_active, - _evaluate_file_contract_outcome, - _extract_resolved_file_path, - _log_chat_stream_error, - _tool_output_has_error, -) - -pytestmark = pytest.mark.unit - - -def test_tool_output_error_detection(): - assert _tool_output_has_error("Error: failed to write file") - assert _tool_output_has_error({"error": "boom"}) - assert _tool_output_has_error({"result": "Error: disk is full"}) - assert not _tool_output_has_error({"result": "Updated file /notes.md"}) - - -def test_extract_resolved_file_path_prefers_structured_path(): - assert ( - _extract_resolved_file_path( - tool_name="write_file", - tool_output={"status": "completed", "path": "/docs/note.md"}, - tool_input=None, - ) - == "/docs/note.md" - ) - - -def test_extract_resolved_file_path_falls_back_to_tool_input(): - assert ( - _extract_resolved_file_path( - tool_name="edit_file", - tool_output={"status": "completed", "result": "updated"}, - tool_input={"file_path": "/docs/edited.md"}, - ) - == "/docs/edited.md" - ) - - -def test_extract_resolved_file_path_does_not_parse_result_text(): - assert ( - _extract_resolved_file_path( - tool_name="write_file", - tool_output={"result": "Updated file /docs/from-text.md"}, - tool_input=None, - ) - is None - ) - - -def test_file_write_contract_outcome_reasons(): - result = StreamResult(intent_detected="file_write") - passed, reason = _evaluate_file_contract_outcome(result) - assert not passed - assert reason == "no_write_attempt" - - result.write_attempted = True - passed, reason = _evaluate_file_contract_outcome(result) - assert not passed - assert reason == "write_failed" - - result.write_succeeded = True - passed, reason = _evaluate_file_contract_outcome(result) - assert not passed - assert reason == "verification_failed" - - result.verification_succeeded = True - passed, reason = _evaluate_file_contract_outcome(result) - assert passed - assert reason == "" - - -def test_contract_enforcement_local_only(): - result = StreamResult(filesystem_mode="desktop_local_folder") - assert _contract_enforcement_active(result) - - result.filesystem_mode = "cloud" - assert not _contract_enforcement_active(result) - - -def _extract_chat_stream_payload(record_message: str) -> dict: - prefix = "[chat_stream_error] " - assert record_message.startswith(prefix) - return json.loads(record_message[len(prefix) :]) - - -def test_unified_chat_stream_error_log_schema(caplog): - with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): - _log_chat_stream_error( - flow="new", - error_kind="server_error", - error_code="SERVER_ERROR", - severity="warn", - is_expected=False, - request_id="req-123", - thread_id=101, - search_space_id=202, - user_id="user-1", - message="Error during chat: boom", - ) - - record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) - payload = _extract_chat_stream_payload(record.message) - - required_keys = { - "event", - "flow", - "error_kind", - "error_code", - "severity", - "is_expected", - "request_id", - "thread_id", - "search_space_id", - "user_id", - "message", - } - assert required_keys.issubset(payload.keys()) - assert payload["event"] == "chat_stream_error" - assert payload["flow"] == "new" - assert payload["error_code"] == "SERVER_ERROR" - - -def test_premium_quota_uses_unified_chat_stream_log_shape(caplog): - with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): - _log_chat_stream_error( - flow="resume", - error_kind="premium_quota_exhausted", - error_code="PREMIUM_QUOTA_EXHAUSTED", - severity="info", - is_expected=True, - request_id="req-premium", - thread_id=303, - search_space_id=404, - user_id="user-2", - message="Buy more tokens to continue with this model, or switch to a free model", - extra={"auto_fallback": False}, - ) - - record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) - payload = _extract_chat_stream_payload(record.message) - assert payload["event"] == "chat_stream_error" - assert payload["error_kind"] == "premium_quota_exhausted" - assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED" - assert payload["flow"] == "resume" - assert payload["is_expected"] is True - assert payload["auto_fallback"] is False - - -def test_stream_error_emission_keeps_machine_error_codes(): - source = inspect.getsource(stream_new_chat_module) - format_error_calls = re.findall(r"format_error\(", source) - emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) - - # All stream paths should route through one shared terminal error emitter. - assert len(format_error_calls) == 1 - assert { - "PREMIUM_QUOTA_EXHAUSTED", - "SERVER_ERROR", - }.issubset(emitted_error_codes) - assert 'flow: Literal["new", "regenerate"] = "new"' in source - assert "_emit_stream_terminal_error" in source - assert "flow=flow" in source - assert 'flow="resume"' in source - - -def test_stream_exception_classifies_rate_limited(): - exc = Exception( - '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' - ) - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "rate_limited" - assert code == "RATE_LIMITED" - assert severity == "warn" - assert is_expected is True - assert "temporarily rate-limited" in user_message - assert extra is None - - -def test_stream_exception_classifies_openrouter_429_payload(): - exc = Exception( - 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' - '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' - ) - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "rate_limited" - assert code == "RATE_LIMITED" - assert severity == "warn" - assert is_expected is True - assert "temporarily rate-limited" in user_message - assert extra is None - - -def test_stream_exception_classifies_thread_busy(): - exc = BusyError(request_id="thread-123") - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "thread_busy" - assert code == "THREAD_BUSY" - assert severity == "warn" - assert is_expected is True - assert "still finishing for this thread" in user_message - assert extra is None - - -def test_stream_exception_classifies_thread_busy_from_message(): - exc = Exception("Thread is busy with another request") - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "thread_busy" - assert code == "THREAD_BUSY" - assert severity == "warn" - assert is_expected is True - assert "still finishing for this thread" in user_message - assert extra is None - - -def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): - thread_id = "thread-cancelling-1" - reset_cancel(thread_id) - request_cancel(thread_id) - exc = BusyError(request_id=thread_id) - kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( - exc, flow_label="chat" - ) - assert kind == "thread_busy" - assert code == "TURN_CANCELLING" - assert severity == "info" - assert is_expected is True - assert "stopping" in user_message - assert isinstance(extra, dict) - assert "retry_after_ms" in extra - - -def test_premium_classification_is_error_code_driven(): - classifier_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/chat-error-classifier.ts" - ) - source = classifier_path.read_text(encoding="utf-8") - - assert "PREMIUM_KEYWORDS" not in source - assert "RATE_LIMIT_KEYWORDS" not in source - assert "normalized.includes(" not in source - assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source - - -def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" - ) - source = page_path.read_text(encoding="utf-8") - - assert "onPreAcceptFailure?: () => Promise;" in source - assert "if (!accepted) {" in source - assert "await onPreAcceptFailure?.();" in source - assert "await onAcceptedStreamError?.();" in source - assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source - assert "setMessageDocumentsMap((prev) => {" in source - - -def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): - user_message_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/components/assistant-ui/user-message.tsx" - ) - source = user_message_path.read_text(encoding="utf-8") - - assert "Not sent. Edit and retry." not in source - assert "failed_pre_accept" not in source - - -def test_network_send_failures_use_unified_retry_toast_message(): - classifier_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/chat-error-classifier.ts" - ) - classifier_source = classifier_path.read_text(encoding="utf-8") - request_errors_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/chat-request-errors.ts" - ) - request_errors_source = request_errors_path.read_text(encoding="utf-8") - - assert '"send_failed_pre_accept"' in classifier_source - assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source - assert 'errorCode === "TURN_CANCELLING"' in classifier_source - assert "if (withCode.code) return withCode.code;" in classifier_source - assert 'userMessage: "Message not sent. Please retry."' in classifier_source - assert 'userMessage: "Connection issue. Please try again."' in classifier_source - assert "const passthroughCodes = new Set([" in request_errors_source - assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source - assert '"THREAD_BUSY"' in request_errors_source - assert '"TURN_CANCELLING"' in request_errors_source - assert '"AUTH_EXPIRED"' in request_errors_source - assert '"UNAUTHORIZED"' in request_errors_source - assert '"RATE_LIMITED"' in request_errors_source - assert '"NETWORK_ERROR"' in request_errors_source - assert '"STREAM_PARSE_ERROR"' in request_errors_source - assert '"TOOL_EXECUTION_ERROR"' in request_errors_source - assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source - assert '"SERVER_ERROR"' in request_errors_source - assert "passthroughCodes.has(existingCode)" in request_errors_source - assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source - assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source - assert "Failed to start chat. Please try again." not in classifier_source - - -def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" - ) - source = page_path.read_text(encoding="utf-8") - - # Each flow tracks accepted boundary and passes it into shared terminal handling. - # The acceptance boundary is still meaningful post-refactor: it gates - # local-state cleanup (onPreAcceptFailure path) and lets the shared - # terminal handler distinguish pre-accept aborts from in-stream errors. - assert "let newAccepted = false;" in source - assert "let resumeAccepted = false;" in source - assert "let regenerateAccepted = false;" in source - assert "accepted: newAccepted," in source - assert "accepted: resumeAccepted," in source - assert "accepted: regenerateAccepted," in source - - # NOTE: The FE-side persistence guards previously asserted here - # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;", - # "if (newAccepted && !userPersisted) {") have been intentionally - # removed by the SSE-based message-id handshake refactor. Persistence - # is now server-authoritative: persist_user_turn / persist_assistant_shell - # run inside stream_new_chat / stream_resume_chat unconditionally and - # the FE consumes data-user-message-id / data-assistant-message-id - # SSE events to learn the canonical primary keys. There is therefore - # no FE call-site to guard, and the shared terminal handler relies - # purely on the `accepted` field above (forwarded to onAbort / - # onAcceptedStreamError) to drive UI cleanup. See - # tests/integration/chat/test_message_id_sse.py for the new - # cross-tier ID coherence guarantees. - - # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent - # of the persistence refactor and must still exist on every - # start-stream fetch. - assert "const fetchWithTurnCancellingRetry = useCallback(" in source - assert "computeFallbackTurnCancellingRetryDelay" in source - assert 'withMeta.errorCode === "TURN_CANCELLING"' in source - assert 'withMeta.errorCode === "THREAD_BUSY"' in source - assert "await fetchWithTurnCancellingRetry(() =>" in source - - -def test_cancel_active_turn_route_contract_exists(): - routes_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/routes/new_chat_routes.py" - ) - source = routes_path.read_text(encoding="utf-8") - - assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source - assert "response_model=CancelActiveTurnResponse" in source - assert 'status="cancelling",' in source - assert 'error_code="TURN_CANCELLING",' in source - assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source - assert "retry_after_at=" in source - assert 'status="idle",' in source - assert 'error_code="NO_ACTIVE_TURN",' in source - - -def test_turn_status_route_contract_exists(): - routes_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/routes/new_chat_routes.py" - ) - source = routes_path.read_text(encoding="utf-8") - - assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source - assert "response_model=TurnStatusResponse" in source - assert "_build_turn_status_payload(thread_id)" in source - assert "Permission.CHATS_READ.value" in source - assert "_raise_if_thread_busy_for_start(" in source - - -def test_turn_cancelling_retry_policy_contract_exists(): - routes_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/routes/new_chat_routes.py" - ) - source = routes_path.read_text(encoding="utf-8") - - assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source - assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source - assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source - assert "def _compute_turn_cancelling_retry_delay(" in source - assert "retry-after-ms" in source - assert '"Retry-After"' in source - assert '"errorCode": "TURN_CANCELLING"' in source - - -def test_turn_status_sse_contract_exists(): - stream_source = ( - Path(__file__).resolve().parents[3] - / "surfsense_backend/app/tasks/chat/stream_new_chat.py" - ).read_text(encoding="utf-8") - state_source = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/streaming-state.ts" - ).read_text(encoding="utf-8") - pipeline_source = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/lib/chat/stream-pipeline.ts" - ).read_text(encoding="utf-8") - - assert '"turn-status"' in stream_source - assert '"status": "busy"' in stream_source - assert '"status": "idle"' in stream_source - assert 'type: "data-turn-status"' in state_source - assert 'case "data-turn-status":' in pipeline_source - assert "end_turn(str(chat_id))" in stream_source