diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py
index 1a283ef29..bf71a0348 100644
--- a/surfsense_backend/app/routes/anonymous_chat_routes.py
+++ b/surfsense_backend/app/routes/anonymous_chat_routes.py
@@ -356,7 +356,8 @@ async def stream_anonymous_chat(
from app.db import shielded_async_session
from app.services.new_streaming_service import VercelStreamingService
from app.services.token_tracking_service import start_turn
- from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events
+ from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
+ from app.tasks.chat.streaming.shared.stream_result import StreamResult
accumulator = start_turn()
streaming_service = VercelStreamingService()
@@ -419,7 +420,7 @@ async def stream_anonymous_chat(
stream_result = StreamResult()
- async for sse in _stream_agent_events(
+ async for sse in stream_agent_events(
agent=agent,
config=langgraph_config,
input_data=input_state,
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
deleted file mode 100644
index cec13204f..000000000
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ /dev/null
@@ -1,3050 +0,0 @@
-"""
-Streaming task for the new SurfSense deep agent chat.
-
-This module streams responses from the deep agent using the Vercel AI SDK
-Data Stream Protocol (SSE format).
-
-Supports loading LLM configurations from:
-- YAML files (negative IDs for global configs)
-- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
-"""
-
-import asyncio
-import contextlib
-import gc
-import json
-import logging
-import sys
-import time
-from collections.abc import AsyncGenerator
-from dataclasses import dataclass, field
-from functools import partial
-from typing import Any, Literal
-from uuid import UUID
-
-import anyio
-from langchain_core.messages import HumanMessage
-from sqlalchemy.future import select
-
-from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
-from app.agents.shared.checkpointer import get_checkpointer
-from app.agents.shared.context import SurfSenseContextSchema
-from app.agents.shared.errors import BusyError
-from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
-from app.agents.shared.llm_config import (
- AgentConfig,
- create_chat_litellm_from_agent_config,
- create_chat_litellm_from_config,
- load_agent_config,
- load_global_llm_config_by_id,
-)
-from app.agents.shared.mention_resolver import resolve_mentions, substitute_in_text
-from app.agents.shared.middleware.busy_mutex import (
- end_turn,
- get_cancel_state,
- is_cancel_requested,
-)
-from app.agents.shared.middleware.kb_persistence import (
- commit_staged_filesystem_state,
-)
-from app.db import (
- ChatVisibility,
- NewChatMessage,
- NewChatThread,
- Report,
- SearchSourceConnectorType,
- async_session_maker,
- shielded_async_session,
-)
-from app.observability import metrics as ot_metrics, otel as ot
-from app.prompts import TITLE_GENERATION_PROMPT
-from app.services.auto_model_pin_service import (
- mark_runtime_cooldown,
- resolve_or_get_pinned_llm_config_id,
-)
-from app.services.chat_session_state_service import (
- clear_ai_responding,
- set_ai_responding,
-)
-from app.services.connector_service import ConnectorService
-from app.services.new_streaming_service import VercelStreamingService
-from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
-from app.tasks.chat.streaming.helpers.interrupt_inspector import (
- all_interrupt_values,
-)
-from app.utils.content_utils import bootstrap_history_from_db
-from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
-from app.utils.user_message_multimodal import build_human_message_content
-
-_background_tasks: set[asyncio.Task] = set()
-_perf_log = get_perf_logger()
-logger = logging.getLogger(__name__)
-
-TURN_CANCELLING_INITIAL_DELAY_MS = 200
-TURN_CANCELLING_BACKOFF_FACTOR = 2
-TURN_CANCELLING_MAX_DELAY_MS = 1500
-
-
-def _resume_step_prefix(turn_id: str) -> str:
- """Build the per-turn ``step_prefix`` for a resume invocation.
-
- Each ``_stream_agent_events`` call constructs a fresh
- :class:`AgentEventRelayState` with ``thinking_step_counter=0``, so two
- consecutive resume turns would otherwise both emit ``thinking-resume-1``,
- ``-2`` etc. The frontend rehydrates ``currentThinkingSteps`` from the
- immediate prior assistant message at the start of every resume — if the
- new stream's IDs collide with the seeded ones, React renders sibling
- Timeline rows with the same key. Salting with ``turn_id`` guarantees
- disjoint IDs across resumes within one thread.
- """
- return f"thinking-resume-{turn_id}"
-
-
-def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
- if attempt < 1:
- attempt = 1
- delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
- TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
- )
- return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
-
-
-def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
- """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
-
- Returns a dict with three keys:
-
- * ``text`` — concatenated string content (empty string if the chunk
- contributes none).
- * ``reasoning`` — concatenated reasoning content (empty string if the
- chunk contributes none).
- * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk``
- dicts surfaced from either the typed-block list or the
- ``tool_call_chunks`` attribute.
-
- Background
- ----------
- ``AIMessageChunk.content`` can be:
-
- * a ``str`` (most providers), or
- * a ``list`` of typed blocks ``{type: 'text' | 'reasoning' |
- 'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for
- Anthropic, Bedrock, and several reasoning configurations.
-
- Reasoning may also live under
- ``chunk.additional_kwargs['reasoning_content']`` (some providers
- surface it that way instead of as a typed block). Tool-call chunks
- may live under ``chunk.tool_call_chunks`` even when ``content`` is a
- plain string.
-
- Earlier versions only handled the ``isinstance(content, str)`` branch
- and silently dropped reasoning blocks + tool-call chunks emitted by
- LangChain ``AIMessageChunk``s.
- """
- out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
- if chunk is None:
- return out
-
- content = getattr(chunk, "content", None)
- if isinstance(content, str):
- if content:
- out["text"] = content
- elif isinstance(content, list):
- text_parts: list[str] = []
- reasoning_parts: list[str] = []
- for block in content:
- if not isinstance(block, dict):
- continue
- block_type = block.get("type")
- if block_type == "text":
- value = block.get("text") or block.get("content") or ""
- if isinstance(value, str) and value:
- text_parts.append(value)
- elif block_type == "reasoning":
- value = (
- block.get("reasoning")
- or block.get("text")
- or block.get("content")
- or ""
- )
- if isinstance(value, str) and value:
- reasoning_parts.append(value)
- elif block_type in ("tool_call_chunk", "tool_use"):
- out["tool_call_chunks"].append(block)
- if text_parts:
- out["text"] = "".join(text_parts)
- if reasoning_parts:
- out["reasoning"] = "".join(reasoning_parts)
-
- additional = getattr(chunk, "additional_kwargs", None) or {}
- if isinstance(additional, dict):
- extra_reasoning = additional.get("reasoning_content")
- if isinstance(extra_reasoning, str) and extra_reasoning:
- existing = out["reasoning"]
- out["reasoning"] = (
- (existing + extra_reasoning) if existing else extra_reasoning
- )
-
- extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
- if isinstance(extra_tool_chunks, list):
- for tcc in extra_tool_chunks:
- if isinstance(tcc, dict):
- out["tool_call_chunks"].append(tcc)
-
- return out
-
-
-def extract_todos_from_deepagents(command_output) -> dict:
- """
- Extract todos from deepagents' TodoListMiddleware Command output.
-
- deepagents returns a Command object with:
- - Command.update['todos'] = [{'content': '...', 'status': '...'}]
-
- Returns the todos directly (no transformation needed - UI matches deepagents format).
- """
- todos_data = []
- if hasattr(command_output, "update"):
- # It's a Command object from deepagents
- update = command_output.update
- todos_data = update.get("todos", [])
- elif isinstance(command_output, dict):
- # Already a dict - check if it has todos directly or in update
- if "todos" in command_output:
- todos_data = command_output.get("todos", [])
- elif "update" in command_output and isinstance(command_output["update"], dict):
- todos_data = command_output["update"].get("todos", [])
-
- return {"todos": todos_data}
-
-
-@dataclass
-class StreamResult:
- accumulated_text: str = ""
- is_interrupted: bool = False
- sandbox_files: list[str] = field(default_factory=list)
- request_id: str | None = None
- turn_id: str = ""
- filesystem_mode: str = "cloud"
- client_platform: str = "web"
- intent_detected: str = "chat_only"
- intent_confidence: float = 0.0
- write_attempted: bool = False
- write_succeeded: bool = False
- verification_succeeded: bool = False
- commit_gate_passed: bool = True
- commit_gate_reason: str = ""
- # Pre-allocated assistant ``new_chat_messages.id`` for this turn,
- # captured by ``persist_assistant_shell`` right after the user row is
- # persisted. ``None`` for the legacy / anonymous code paths that don't
- # opt in to server-side ``ContentPart[]`` projection.
- assistant_message_id: int | None = None
- # In-memory mirror of the FE's assistant-ui ``ContentPartsState``,
- # populated by the lifecycle methods called from ``_stream_agent_events``
- # at each ``streaming_service.format_*`` yield site. Snapshot in the
- # streaming ``finally`` to produce the rich JSONB persisted by
- # ``finalize_assistant_turn``. ``repr=False`` keeps the
- # log-on-error path (``StreamResult`` is logged in some error
- # branches) from dumping a potentially-large parts list.
- content_builder: Any | None = field(default=None, repr=False)
-
-
-def _safe_float(value: Any, default: float = 0.0) -> float:
- try:
- return float(value)
- except (TypeError, ValueError):
- return default
-
-
-def _tool_output_to_text(tool_output: Any) -> str:
- if isinstance(tool_output, dict):
- if isinstance(tool_output.get("result"), str):
- return tool_output["result"]
- if isinstance(tool_output.get("error"), str):
- return tool_output["error"]
- return json.dumps(tool_output, ensure_ascii=False)
- return str(tool_output)
-
-
-def _tool_output_has_error(tool_output: Any) -> bool:
- if isinstance(tool_output, dict):
- if tool_output.get("error"):
- return True
- result = tool_output.get("result")
- return bool(
- isinstance(result, str) and result.strip().lower().startswith("error:")
- )
- if isinstance(tool_output, str):
- return tool_output.strip().lower().startswith("error:")
- return False
-
-
-def _extract_resolved_file_path(
- *, tool_name: str, tool_output: Any, tool_input: Any | None = None
-) -> str | None:
- if isinstance(tool_output, dict):
- path_value = tool_output.get("path")
- if isinstance(path_value, str) and path_value.strip():
- return path_value.strip()
- if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict):
- file_path = tool_input.get("file_path")
- if isinstance(file_path, str) and file_path.strip():
- return file_path.strip()
- return None
-
-
-def _contract_enforcement_active(result: StreamResult) -> bool:
- # Keep policy deterministic with no env-driven progression modes:
- # enforce the file-operation contract only in desktop local-folder mode.
- return result.filesystem_mode == "desktop_local_folder"
-
-
-def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
- if result.intent_detected != "file_write":
- return True, ""
- if not result.write_attempted:
- return False, "no_write_attempt"
- if not result.write_succeeded:
- return False, "write_failed"
- if not result.verification_succeeded:
- return False, "verification_failed"
- return True, ""
-
-
-def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
- payload: dict[str, Any] = {
- "stage": stage,
- "request_id": result.request_id or "unknown",
- "turn_id": result.turn_id or "unknown",
- "chat_id": result.turn_id.split(":", 1)[0]
- if ":" in result.turn_id
- else "unknown",
- "filesystem_mode": result.filesystem_mode,
- "client_platform": result.client_platform,
- "intent_detected": result.intent_detected,
- "intent_confidence": result.intent_confidence,
- "write_attempted": result.write_attempted,
- "write_succeeded": result.write_succeeded,
- "verification_succeeded": result.verification_succeeded,
- "commit_gate_passed": result.commit_gate_passed,
- "commit_gate_reason": result.commit_gate_reason or None,
- }
- payload.update(extra)
- _perf_log.info(
- "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)
- )
-
-
-def _log_chat_stream_error(
- *,
- flow: Literal["new", "resume", "regenerate"],
- error_kind: str,
- error_code: str | None,
- severity: Literal["info", "warn", "error"],
- is_expected: bool,
- request_id: str | None,
- thread_id: int | None,
- search_space_id: int | None,
- user_id: str | None,
- message: str,
- extra: dict[str, Any] | None = None,
-) -> None:
- payload: dict[str, Any] = {
- "event": "chat_stream_error",
- "flow": flow,
- "error_kind": error_kind,
- "error_code": error_code,
- "severity": severity,
- "is_expected": is_expected,
- "request_id": request_id or "unknown",
- "thread_id": thread_id,
- "search_space_id": search_space_id,
- "user_id": user_id,
- "message": message,
- }
- if extra:
- payload.update(extra)
-
- logger = logging.getLogger(__name__)
- rendered = json.dumps(payload, ensure_ascii=False)
- if severity == "error":
- logger.error("[chat_stream_error] %s", rendered)
- elif severity == "warn":
- logger.warning("[chat_stream_error] %s", rendered)
- else:
- logger.info("[chat_stream_error] %s", rendered)
-
-
-def _parse_error_payload(message: str) -> dict[str, Any] | None:
- candidates = [message]
- first_brace_idx = message.find("{")
- if first_brace_idx >= 0:
- candidates.append(message[first_brace_idx:])
-
- for candidate in candidates:
- try:
- parsed = json.loads(candidate)
- if isinstance(parsed, dict):
- return parsed
- except Exception:
- continue
- return None
-
-
-def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
- if not isinstance(parsed, dict):
- return None
- candidates: list[Any] = [parsed.get("code")]
- nested = parsed.get("error")
- if isinstance(nested, dict):
- candidates.append(nested.get("code"))
- for value in candidates:
- try:
- if value is None:
- continue
- return int(value)
- except Exception:
- continue
- return None
-
-
-def _is_provider_rate_limited(exc: BaseException) -> bool:
- """Best-effort detection for provider-side runtime throttling.
-
- Covers LiteLLM/OpenRouter shapes like:
- - class name contains ``RateLimit``
- - nested payload ``{"error": {"code": 429}}``
- - nested payload ``{"error": {"type": "rate_limit_error"}}``
- """
- raw = str(exc)
- lowered = raw.lower()
- if "ratelimit" in type(exc).__name__.lower():
- return True
- parsed = _parse_error_payload(raw)
- provider_code = _extract_provider_error_code(parsed)
- if provider_code == 429:
- return True
-
- provider_error_type = ""
- if parsed:
- top_type = parsed.get("type")
- if isinstance(top_type, str):
- provider_error_type = top_type.lower()
- nested = parsed.get("error")
- if isinstance(nested, dict):
- nested_type = nested.get("type")
- if isinstance(nested_type, str):
- provider_error_type = nested_type.lower()
- if provider_error_type == "rate_limit_error":
- return True
-
- return (
- "rate limited" in lowered
- or "rate-limited" in lowered
- or "temporarily rate-limited upstream" in lowered
- )
-
-
-async def _build_main_agent_for_thread(
- agent_factory: Any,
- *,
- llm: Any,
- search_space_id: int,
- db_session: Any,
- connector_service: ConnectorService,
- checkpointer: Any,
- user_id: str | None,
- thread_id: int | None,
- agent_config: AgentConfig | None,
- firecrawl_api_key: str | None,
- thread_visibility: ChatVisibility | None,
- filesystem_selection: FilesystemSelection | None,
- disabled_tools: list[str] | None = None,
- mentioned_document_ids: list[int] | None = None,
-) -> Any:
- """Single (re)build path so the agent factory cannot drift across the
- initial build and mid-stream 429 recovery for one ``thread_id``: a
- graph swap mid-turn would corrupt checkpointer state."""
- return await agent_factory(
- llm=llm,
- search_space_id=search_space_id,
- db_session=db_session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=thread_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=thread_visibility,
- filesystem_selection=filesystem_selection,
- disabled_tools=disabled_tools,
- mentioned_document_ids=mentioned_document_ids,
- )
-
-
-def _classify_stream_exception(
- exc: Exception,
- *,
- flow_label: str,
-) -> tuple[
- str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
-]:
- raw = str(exc)
- if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
- busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
- if busy_thread_id and is_cancel_requested(busy_thread_id):
- cancel_state = get_cancel_state(busy_thread_id)
- attempt = cancel_state[0] if cancel_state else 1
- retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
- retry_after_at = int(time.time() * 1000) + retry_after_ms
- return (
- "thread_busy",
- "TURN_CANCELLING",
- "info",
- True,
- "A previous response is still stopping. Please try again in a moment.",
- {
- "retry_after_ms": retry_after_ms,
- "retry_after_at": retry_after_at,
- },
- )
- return (
- "thread_busy",
- "THREAD_BUSY",
- "warn",
- True,
- "Another response is still finishing for this thread. Please try again in a moment.",
- None,
- )
-
- if _is_provider_rate_limited(exc):
- return (
- "rate_limited",
- "RATE_LIMITED",
- "warn",
- True,
- "This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
- None,
- )
-
- return (
- "server_error",
- "SERVER_ERROR",
- "error",
- False,
- f"Error during {flow_label}: {raw}",
- None,
- )
-
-
-def _emit_stream_terminal_error(
- *,
- streaming_service: VercelStreamingService,
- flow: str,
- request_id: str | None,
- thread_id: int,
- search_space_id: int,
- user_id: str | None,
- message: str,
- error_kind: str = "server_error",
- error_code: str = "SERVER_ERROR",
- severity: Literal["info", "warn", "error"] = "error",
- is_expected: bool = False,
- extra: dict[str, Any] | None = None,
-) -> str:
- _log_chat_stream_error(
- flow=flow,
- error_kind=error_kind,
- error_code=error_code,
- severity=severity,
- is_expected=is_expected,
- request_id=request_id,
- thread_id=thread_id,
- search_space_id=search_space_id,
- user_id=user_id,
- message=message,
- extra=extra,
- )
- return streaming_service.format_error(message, error_code=error_code, extra=extra)
-
-
-def _legacy_match_lc_id(
- pending_tool_call_chunks: list[dict[str, Any]],
- tool_name: str,
- run_id: str,
- lc_tool_call_id_by_run: dict[str, str],
-) -> str | None:
- """Best-effort match a buffered ``tool_call_chunk`` to a tool name.
-
- Pure extract of the in-line match used at ``on_tool_start`` when the
- chunk path didn't register an index for this call. Pops the next
- id-bearing chunk whose ``name``
- matches ``tool_name`` (or any id-bearing chunk as a fallback) and
- returns its id. Mutates ``pending_tool_call_chunks`` and
- ``lc_tool_call_id_by_run`` in place.
- """
- matched_idx: int | None = None
- for idx, tcc in enumerate(pending_tool_call_chunks):
- if tcc.get("name") == tool_name and tcc.get("id"):
- matched_idx = idx
- break
- if matched_idx is None:
- for idx, tcc in enumerate(pending_tool_call_chunks):
- if tcc.get("id"):
- matched_idx = idx
- break
- if matched_idx is None:
- return None
- matched = pending_tool_call_chunks.pop(matched_idx)
- candidate = matched.get("id")
- if isinstance(candidate, str) and candidate:
- if run_id:
- lc_tool_call_id_by_run[run_id] = candidate
- return candidate
- return None
-
-
-async def _stream_agent_events(
- agent: Any,
- config: dict[str, Any],
- input_data: Any,
- streaming_service: VercelStreamingService,
- result: StreamResult,
- step_prefix: str = "thinking",
- initial_step_id: str | None = None,
- initial_step_title: str = "",
- initial_step_items: list[str] | None = None,
- *,
- fallback_commit_search_space_id: int | None = None,
- fallback_commit_created_by_id: str | None = None,
- fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
- fallback_commit_thread_id: int | None = None,
- runtime_context: Any = None,
- content_builder: Any | None = None,
-) -> AsyncGenerator[str, None]:
- """Shared async generator that streams and formats astream_events from the agent.
-
- Yields SSE-formatted strings. After exhausting, inspect the ``result``
- object for accumulated_text and interrupt state.
-
- Args:
- agent: The compiled LangGraph agent.
- config: LangGraph config dict (must include configurable.thread_id).
- input_data: The input to pass to agent.astream_events (dict or Command).
- streaming_service: VercelStreamingService instance for formatting events.
- result: Mutable StreamResult populated with accumulated_text / interrupt info.
- step_prefix: Prefix for thinking step IDs (e.g. "thinking" or "thinking-resume").
- initial_step_id: If set, the helper inherits an already-active thinking step.
- initial_step_title: Title of the inherited thinking step.
- initial_step_items: Items of the inherited thinking step.
- content_builder: Optional ``AssistantContentBuilder``. When set, every
- ``streaming_service.format_*`` yield site also drives the matching
- builder lifecycle method (``on_text_*``, ``on_reasoning_*``,
- ``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the
- in-memory ``ContentPart[]`` projection stays in lockstep with what
- the FE renders live. Pure in-memory accumulation — no DB I/O —
- consumed by the streaming ``finally`` to produce the rich JSONB
- persisted via ``finalize_assistant_turn``. ``None`` (the default)
- is used by the anonymous / legacy code paths and is a no-op.
-
- Yields:
- SSE-formatted strings for each event.
- """
- async for sse in stream_output(
- agent=agent,
- config=config,
- input_data=input_data,
- streaming_service=streaming_service,
- result=result,
- step_prefix=step_prefix,
- initial_step_id=initial_step_id,
- initial_step_title=initial_step_title,
- initial_step_items=initial_step_items,
- content_builder=content_builder,
- runtime_context=runtime_context,
- ):
- yield sse
-
- accumulated_text = result.accumulated_text
-
- state = await agent.aget_state(config)
- state_values = getattr(state, "values", {}) or {}
-
- # Safety net: if astream_events was cancelled before
- # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
- # (dirty_paths / staged_dirs / pending_moves / pending_deletes /
- # pending_dir_deletes) will still be in the checkpointed state. Run
- # the SAME shared commit helper here so the turn's writes don't get
- # lost on client disconnect, then push the delta back into the graph
- # using `as_node=...` so reducers fire as if the after_agent hook
- # produced it.
- if (
- fallback_commit_filesystem_mode == FilesystemMode.CLOUD
- and fallback_commit_search_space_id is not None
- and (
- (state_values.get("dirty_paths") or [])
- or (state_values.get("staged_dirs") or [])
- or (state_values.get("pending_moves") or [])
- or (state_values.get("pending_deletes") or [])
- or (state_values.get("pending_dir_deletes") or [])
- )
- ):
- try:
- delta = await commit_staged_filesystem_state(
- state_values,
- search_space_id=fallback_commit_search_space_id,
- created_by_id=fallback_commit_created_by_id,
- filesystem_mode=fallback_commit_filesystem_mode,
- thread_id=fallback_commit_thread_id,
- dispatch_events=False,
- )
- if delta:
- await agent.aupdate_state(
- config,
- delta,
- as_node="KnowledgeBasePersistenceMiddleware.after_agent",
- )
- except Exception as exc:
- _perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc)
-
- contract_state = state_values.get("file_operation_contract") or {}
- contract_turn_id = contract_state.get("turn_id")
- current_turn_id = config.get("configurable", {}).get("turn_id", "")
- intent_value = contract_state.get("intent")
- if (
- isinstance(intent_value, str)
- and intent_value in ("chat_only", "file_write", "file_read")
- and contract_turn_id == current_turn_id
- ):
- result.intent_detected = intent_value
- if (
- isinstance(intent_value, str)
- and intent_value
- in (
- "chat_only",
- "file_write",
- "file_read",
- )
- and contract_turn_id != current_turn_id
- ):
- # Ignore stale intent contracts from previous turns/checkpoints.
- result.intent_detected = "chat_only"
- result.intent_confidence = (
- _safe_float(contract_state.get("confidence"), default=0.0)
- if contract_turn_id == current_turn_id
- else 0.0
- )
-
- if result.intent_detected == "file_write":
- result.commit_gate_passed, result.commit_gate_reason = (
- _evaluate_file_contract_outcome(result)
- )
- if not result.commit_gate_passed and _contract_enforcement_active(result):
- gate_notice = (
- "I could not complete the requested file write because no successful "
- "write_file/edit_file operation was confirmed."
- )
- gate_text_id = streaming_service.generate_text_id()
- yield streaming_service.format_text_start(gate_text_id)
- if content_builder is not None:
- content_builder.on_text_start(gate_text_id)
- yield streaming_service.format_text_delta(gate_text_id, gate_notice)
- if content_builder is not None:
- content_builder.on_text_delta(gate_text_id, gate_notice)
- yield streaming_service.format_text_end(gate_text_id)
- if content_builder is not None:
- content_builder.on_text_end(gate_text_id)
- yield streaming_service.format_terminal_info(gate_notice, "error")
- accumulated_text = gate_notice
- else:
- result.commit_gate_passed = True
- result.commit_gate_reason = ""
-
- result.accumulated_text = accumulated_text
- _log_file_contract("turn_outcome", result)
-
- pending_values = all_interrupt_values(state)
- if pending_values:
- result.is_interrupted = True
- # One frame per paused subagent so each parallel HITL renders its own
- # approval card on the wire. Order matches ``state.interrupts``, which
- # the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
- # consumes in the same order — keeping emit and resume in lock-step.
- for interrupt_value in pending_values:
- yield streaming_service.format_interrupt_request(interrupt_value)
-
-
-async def stream_new_chat(
- user_query: str,
- search_space_id: int,
- chat_id: int,
- user_id: str | None = None,
- llm_config_id: int = -1,
- mentioned_document_ids: list[int] | None = None,
- mentioned_folder_ids: list[int] | None = None,
- mentioned_connector_ids: list[int] | None = None,
- mentioned_connectors: list[dict[str, Any]] | None = None,
- mentioned_documents: list[dict[str, Any]] | None = None,
- checkpoint_id: str | None = None,
- needs_history_bootstrap: bool = False,
- thread_visibility: ChatVisibility | None = None,
- current_user_display_name: str | None = None,
- disabled_tools: list[str] | None = None,
- filesystem_selection: FilesystemSelection | None = None,
- request_id: str | None = None,
- user_image_data_urls: list[str] | None = None,
- flow: Literal["new", "regenerate"] = "new",
-) -> AsyncGenerator[str, None]:
- """
- Stream chat responses from the new SurfSense deep agent.
-
- This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
- The chat_id is used as LangGraph's thread_id for memory/checkpointing.
-
- The function creates and manages its own database session to guarantee proper
- cleanup even when Starlette's middleware cancels the task on client disconnect.
-
- Args:
- user_query: The user's query
- search_space_id: The search space ID
- chat_id: The chat ID (used as LangGraph thread_id for memory)
- user_id: The current user's UUID string (for memory tools and session state)
- llm_config_id: The LLM configuration ID (default: -1 for first global config)
- needs_history_bootstrap: If True, load message history from DB (for cloned chats)
- mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
- mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode)
- checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations)
-
- Yields:
- str: SSE formatted response strings
- """
- streaming_service = VercelStreamingService()
- stream_result = StreamResult()
- _t_total = time.perf_counter()
- fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
- fs_platform = (
- filesystem_selection.client_platform.value if filesystem_selection else "web"
- )
- stream_result.request_id = request_id
- stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
- stream_result.filesystem_mode = fs_mode
- stream_result.client_platform = fs_platform
- chat_agent_mode = "unknown"
- chat_outcome = "success"
- chat_error_category: str | None = None
- chat_span_cm = ot.chat_request_span(
- chat_id=chat_id,
- search_space_id=search_space_id,
- flow=flow,
- request_id=request_id,
- turn_id=stream_result.turn_id,
- filesystem_mode=fs_mode,
- client_platform=fs_platform,
- agent_mode=chat_agent_mode,
- )
- chat_span = chat_span_cm.__enter__()
- _log_file_contract("turn_start", stream_result)
- _perf_log.info(
- "[stream_new_chat] filesystem_mode=%s client_platform=%s",
- fs_mode,
- fs_platform,
- )
- log_system_snapshot("stream_new_chat_START")
-
- from app.services.token_tracking_service import start_turn
-
- accumulator = start_turn()
-
- # Premium credit (USD micro-units) tracking state. Stores the
- # amount reserved up front so we can release it on cancellation
- # and finalize-debit the actual provider cost reported by LiteLLM.
- _premium_reserved_micros = 0
- _premium_request_id: str | None = None
-
- # ``BusyError`` fires before the lock is acquired; the ``finally`` must
- # not release the in-flight caller's lock.
- _busy_error_raised = False
-
- _emit_stream_error = partial(
- _emit_stream_terminal_error,
- streaming_service=streaming_service,
- flow=flow,
- request_id=request_id,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
-
- session = async_session_maker()
- try:
- # Mark AI as responding to this user for live collaboration
- if user_id:
- await set_ai_responding(session, chat_id, UUID(user_id))
- # Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
- agent_config: AgentConfig | None = None
- requested_llm_config_id = llm_config_id
-
- async def _load_llm_bundle(
- config_id: int,
- ) -> tuple[Any, AgentConfig | None, str | None]:
- if config_id >= 0:
- loaded_agent_config = await load_agent_config(
- session=session,
- config_id=config_id,
- search_space_id=search_space_id,
- )
- if not loaded_agent_config:
- return (
- None,
- None,
- f"Failed to load NewLLMConfig with id {config_id}",
- )
- return (
- create_chat_litellm_from_agent_config(loaded_agent_config),
- loaded_agent_config,
- None,
- )
-
- loaded_llm_config = load_global_llm_config_by_id(config_id)
- if not loaded_llm_config:
- return None, None, f"Failed to load LLM config with id {config_id}"
- return (
- create_chat_litellm_from_config(loaded_llm_config),
- AgentConfig.from_yaml_config(loaded_llm_config),
- None,
- )
-
- _t0 = time.perf_counter()
- # Image-bearing turns force the Auto-pin resolver to filter the
- # candidate pool to vision-capable cfgs (and force-repin a
- # text-only existing pin). For explicit selections this flag is
- # a no-op — the resolver returns the user's chosen id unchanged.
- _requires_image_input = bool(user_image_data_urls)
- try:
- llm_config_id = (
- await resolve_or_get_pinned_llm_config_id(
- session,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- selected_llm_config_id=llm_config_id,
- requires_image_input=_requires_image_input,
- )
- ).resolved_llm_config_id
- ot.add_event(
- "model.pin.resolved",
- {
- "pin.requested_id": requested_llm_config_id,
- "pin.resolved_id": llm_config_id,
- "pin.requires_image_input": _requires_image_input,
- },
- )
- except ValueError as pin_error:
- # Auto-pin's "no vision-capable cfg" path raises a ValueError
- # whose message we map to the friendly image-input SSE error
- # so the user sees the same message regardless of whether
- # the gate fired in Auto-mode or in the agent_config check
- # below.
- error_code = (
- "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
- if _requires_image_input and "vision-capable" in str(pin_error)
- else "SERVER_ERROR"
- )
- error_kind = (
- "user_error"
- if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
- else "server_error"
- )
- if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT":
- ot.add_event(
- "quota.denied",
- {
- "quota.code": error_code,
- },
- )
- yield _emit_stream_error(
- message=str(pin_error),
- error_kind=error_kind,
- error_code=error_code,
- )
- yield streaming_service.format_done()
- return
-
- llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
- if llm_load_error:
- yield _emit_stream_error(
- message=llm_load_error,
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
- _perf_log.info(
- "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
- time.perf_counter() - _t0,
- llm_config_id,
- )
-
- # Capability safety net: a turn carrying user-uploaded images
- # cannot be routed to a chat config that LiteLLM's authoritative
- # model map *explicitly* marks as text-only (``supports_vision``
- # set to False). The check is intentionally narrow — it only
- # fires when LiteLLM is *certain* the model can't accept image
- # input. Unknown / unmapped / vision-capable models pass
- # through. Without this guard a known-text-only model would 404
- # at the provider with ``"No endpoints found that support image
- # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
- # failing here lets us return a friendly message that tells the
- # user what to change.
- if user_image_data_urls and agent_config is not None:
- from app.services.provider_capabilities import (
- is_known_text_only_chat_model,
- )
-
- agent_litellm_params = agent_config.litellm_params or {}
- agent_base_model = (
- agent_litellm_params.get("base_model")
- if isinstance(agent_litellm_params, dict)
- else None
- )
- if is_known_text_only_chat_model(
- provider=agent_config.provider,
- model_name=agent_config.model_name,
- base_model=agent_base_model,
- custom_provider=agent_config.custom_provider,
- ):
- model_label = (
- agent_config.config_name or agent_config.model_name or "model"
- )
- ot.add_event(
- "quota.denied",
- {
- "quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
- },
- )
- yield _emit_stream_error(
- message=(
- f"The selected model ({model_label}) does not support "
- "image input. Switch to a vision-capable model "
- "(e.g. GPT-4o, Claude, Gemini) or remove the image "
- "attachment and try again."
- ),
- error_kind="user_error",
- error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
- )
- yield streaming_service.format_done()
- return
-
- # Premium quota reservation for pinned premium model only.
- _needs_premium_quota = (
- agent_config is not None and user_id and agent_config.is_premium
- )
- if _needs_premium_quota:
- import uuid as _uuid
-
- from app.services.token_quota_service import (
- TokenQuotaService,
- estimate_call_reserve_micros,
- )
-
- _premium_request_id = _uuid.uuid4().hex[:16]
- _agent_litellm_params = agent_config.litellm_params or {}
- _agent_base_model = (
- _agent_litellm_params.get("base_model") or agent_config.model_name or ""
- )
- reserve_amount_micros = estimate_call_reserve_micros(
- base_model=_agent_base_model,
- quota_reserve_tokens=agent_config.quota_reserve_tokens,
- )
- async with shielded_async_session() as quota_session:
- quota_result = await TokenQuotaService.premium_reserve(
- db_session=quota_session,
- user_id=UUID(user_id),
- request_id=_premium_request_id,
- reserve_micros=reserve_amount_micros,
- )
- _premium_reserved_micros = reserve_amount_micros
- if not quota_result.allowed:
- ot.add_event(
- "quota.denied",
- {
- "quota.code": "PREMIUM_QUOTA_EXHAUSTED",
- },
- )
- if requested_llm_config_id == 0:
- try:
- llm_config_id = (
- await resolve_or_get_pinned_llm_config_id(
- session,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- selected_llm_config_id=0,
- force_repin_free=True,
- requires_image_input=_requires_image_input,
- )
- ).resolved_llm_config_id
- ot.add_event(
- "model.repin",
- {
- "repin.reason": "premium_quota_exhausted",
- "repin.to_config_id": llm_config_id,
- },
- )
- except ValueError as pin_error:
- yield _emit_stream_error(
- message=str(pin_error),
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
-
- llm, agent_config, llm_load_error = await _load_llm_bundle(
- llm_config_id
- )
- if llm_load_error:
- yield _emit_stream_error(
- message=llm_load_error,
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
- _premium_request_id = None
- _premium_reserved_micros = 0
- _log_chat_stream_error(
- flow=flow,
- error_kind="premium_quota_exhausted",
- error_code="PREMIUM_QUOTA_EXHAUSTED",
- severity="info",
- is_expected=True,
- request_id=request_id,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- message=(
- "Premium quota exhausted on pinned model; auto-fallback switched to a free model"
- ),
- extra={
- "fallback_config_id": llm_config_id,
- "auto_fallback": True,
- },
- )
- else:
- yield _emit_stream_error(
- message=(
- "Buy more tokens to continue with this model, or switch to a free model"
- ),
- error_kind="premium_quota_exhausted",
- error_code="PREMIUM_QUOTA_EXHAUSTED",
- severity="info",
- is_expected=True,
- extra={
- "resolved_config_id": llm_config_id,
- "auto_fallback": False,
- },
- )
- yield streaming_service.format_done()
- return
-
- if not llm:
- yield _emit_stream_error(
- message="Failed to create LLM instance",
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
-
- # Create connector service
- _t0 = time.perf_counter()
- connector_service = ConnectorService(session, search_space_id=search_space_id)
-
- firecrawl_api_key = None
- webcrawler_connector = await connector_service.get_connector_by_type(
- SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
- )
- if webcrawler_connector and webcrawler_connector.config:
- firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
- _perf_log.info(
- "[stream_new_chat] Connector service + firecrawl key in %.3fs",
- time.perf_counter() - _t0,
- )
-
- # Get the PostgreSQL checkpointer for persistent conversation memory
- _t0 = time.perf_counter()
- checkpointer = await get_checkpointer()
- _perf_log.info(
- "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
- )
-
- visibility = thread_visibility or ChatVisibility.PRIVATE
- chat_agent_mode = "multi"
- with contextlib.suppress(Exception):
- chat_span.set_attribute("agent.mode", chat_agent_mode)
-
- _t0 = time.perf_counter()
- agent_factory = create_multi_agent_chat_deep_agent
- # Build the agent inline. Provider 429s surface through the
- # in-stream recovery loop below (``_is_provider_rate_limited``),
- # which repins the thread to an eligible alternative config and
- # rebuilds the agent before the user sees any output.
- agent = await _build_main_agent_for_thread(
- agent_factory,
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- disabled_tools=disabled_tools,
- mentioned_document_ids=mentioned_document_ids,
- )
- _perf_log.info(
- "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
- )
-
- # Build input with message history
- langchain_messages = []
-
- _t0 = time.perf_counter()
- # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
- if needs_history_bootstrap:
- langchain_messages = await bootstrap_history_from_db(
- session, chat_id, thread_visibility=visibility
- )
-
- thread_result = await session.execute(
- select(NewChatThread).filter(NewChatThread.id == chat_id)
- )
- thread = thread_result.scalars().first()
- if thread:
- thread.needs_history_bootstrap = False
- await session.commit()
-
- # Mentioned KB documents are now handled by KnowledgeBaseSearchMiddleware
- # which merges them into the scoped filesystem with full document
- # structure. Only report context is inlined here.
-
- # Fetch the most recent report(s) in this thread so the LLM can
- # easily find report_id for versioning decisions, instead of
- # having to dig through conversation history.
- recent_reports_result = await session.execute(
- select(Report)
- .filter(
- Report.thread_id == chat_id,
- Report.content.isnot(None), # exclude failed reports
- )
- .order_by(Report.id.desc())
- .limit(3)
- )
- recent_reports = list(recent_reports_result.scalars().all())
-
- # Resolve @-mention chips to canonical virtual paths and rewrite
- # the user-typed text so the LLM sees ``\`/documents/...\``` instead
- # of bare ``@title``. The substitution lands in ``agent_user_query``
- # ONLY — the original ``user_query`` (with ``@title`` tokens) flows
- # untouched into ``persist_user_turn`` below so chip rendering on
- # reload still works (``UserTextPart`` → ``parseMentionSegments``
- # matches ``@title``, not ``\`/documents/...\```). It also feeds
- # the human-readable surfaces — SSE "Processing X" status, auto
- # thread title, memory seed — which all want what the user typed.
- # See ``persistence._build_user_content``.
- #
- # Cloud mode only: local-folder mode keeps the legacy
- # ``@title`` text path; mention support there is a follow-up
- # task because the path scheme (mount-rooted) and the picker
- # UI both need separate work.
- agent_user_query = user_query
- accepted_folder_ids: list[int] = []
- if fs_mode == FilesystemMode.CLOUD.value and (
- mentioned_document_ids or mentioned_folder_ids or mentioned_documents
- ):
- from app.schemas.new_chat import (
- MentionedDocumentInfo as _MentionedDocumentInfo,
- )
-
- chip_objs: list[_MentionedDocumentInfo] | None = None
- if mentioned_documents:
- chip_objs = []
- for raw in mentioned_documents:
- if isinstance(raw, _MentionedDocumentInfo):
- chip_objs.append(raw)
- continue
- try:
- chip_objs.append(_MentionedDocumentInfo.model_validate(raw))
- except Exception:
- logger.debug(
- "stream_new_chat: dropping malformed mention chip %r",
- raw,
- )
-
- resolved = await resolve_mentions(
- session,
- search_space_id=search_space_id,
- mentioned_documents=chip_objs,
- mentioned_document_ids=mentioned_document_ids,
- mentioned_folder_ids=mentioned_folder_ids,
- )
- agent_user_query = substitute_in_text(user_query, resolved.token_to_path)
- accepted_folder_ids = resolved.mentioned_folder_ids
-
- # Format the user query with context (reports only).
- # Uses ``agent_user_query`` so the LLM sees backtick-wrapped paths
- # instead of bare ``@title`` tokens.
- final_query = agent_user_query
- context_parts = []
-
- if mentioned_connectors:
- connector_lines = []
- for connector in mentioned_connectors:
- if not isinstance(connector, dict):
- continue
- connector_id = connector.get("id")
- connector_type = connector.get("connector_type") or connector.get(
- "document_type"
- )
- account_name = connector.get("account_name") or connector.get("title")
- if connector_id is None or connector_type is None:
- continue
- connector_lines.append(
- f' - connector_id={connector_id}, connector_type="{connector_type}", '
- f'account_name="{account_name or ""}"'
- )
- if connector_lines:
- context_parts.append(
- "\n"
- "The user selected these exact connector accounts with @. "
- "These entries are selection metadata, not retrieved connector content. "
- "When a connector-backed tool needs an account, use the matching "
- "connector_id from this list if the tool supports connector_id:\n"
- + "\n".join(connector_lines)
- + "\n"
- )
-
- # Surface report IDs prominently so the LLM doesn't have to
- # retrieve them from old tool responses in conversation history.
- if recent_reports:
- report_lines = []
- for r in recent_reports:
- report_lines.append(
- f' - report_id={r.id}, title="{r.title}", '
- f'style="{r.report_style or "detailed"}"'
- )
- reports_listing = "\n".join(report_lines)
- context_parts.append(
- "\n"
- "Previously generated reports in this conversation:\n"
- f"{reports_listing}\n\n"
- "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
- "these reports, set parent_report_id to the relevant report_id above.\n"
- "If the user wants a completely NEW report on a different topic, "
- "leave parent_report_id unset.\n"
- ""
- )
-
- if context_parts:
- context = "\n\n".join(context_parts)
- final_query = f"{context}\n\n{agent_user_query}"
-
- if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
- final_query = f"**[{current_user_display_name}]:** {final_query}"
-
- # if messages:
- # # Convert frontend messages to LangChain format
- # for msg in messages:
- # if msg.role == "user":
- # langchain_messages.append(HumanMessage(content=msg.content))
- # elif msg.role == "assistant":
- # langchain_messages.append(AIMessage(content=msg.content))
- # else:
- human_content = build_human_message_content(
- final_query, list(user_image_data_urls or ())
- )
- langchain_messages.append(HumanMessage(content=human_content))
-
- input_state = {
- # Let's not pass this message at the moment because we are using the checkpointer to manage the conversation history.
- # We will use this to simulate group chat functionality in the future
- "messages": langchain_messages,
- "search_space_id": search_space_id,
- "request_id": request_id or "unknown",
- "turn_id": stream_result.turn_id,
- }
-
- _perf_log.info(
- "[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
- time.perf_counter() - _t0,
- )
-
- # All pre-streaming DB reads are done. Commit to release the
- # transaction and its ACCESS SHARE locks so we don't block DDL
- # (e.g. migrations) for the entire duration of LLM streaming.
- # Tools that need DB access during streaming will start their own
- # short-lived transactions (or use isolated sessions).
- await session.commit()
-
- # Detach heavy ORM objects (documents with chunks, reports, etc.)
- # from the session identity map now that we've extracted the data
- # we need. This prevents them from accumulating in memory for the
- # entire duration of LLM streaming (which can be several minutes).
- session.expunge_all()
-
- _perf_log.info(
- "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
- time.perf_counter() - _t_total,
- chat_id,
- )
-
- # Configure LangGraph with thread_id for memory
- # If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
- configurable = {"thread_id": str(chat_id)}
- configurable["request_id"] = request_id or "unknown"
- configurable["turn_id"] = stream_result.turn_id
- if checkpoint_id:
- configurable["checkpoint_id"] = checkpoint_id
-
- config = {
- "configurable": configurable,
- # Effectively uncapped, matching the agent-level
- # ``with_config`` default in ``chat_deepagent.create_agent``
- # and the unbounded ``while(true)`` loop used by OpenCode's
- # ``session/processor.ts``. Real circuit-breakers live in
- # middleware: ``DoomLoopMiddleware`` (sliding-window tool
- # signature check), plus ``enable_tool_call_limit`` /
- # ``enable_model_call_limit`` when those flags are set. The
- # original LangGraph default of 25 (and our previous 80
- # bump) hit users on legitimate multi-tool plans.
- "recursion_limit": 10_000,
- }
-
- # Start the message stream
- yield streaming_service.format_message_start()
- yield streaming_service.format_start_step()
-
- # Surface the per-turn correlation id at the very start of the
- # stream so the frontend can stamp it onto the in-flight
- # assistant message and replay it via ``appendMessage``
- # for durable storage. Tool/action-log events DO carry it later,
- # but pure-text turns never produce action-log events; this
- # event guarantees the frontend learns the turn id regardless.
- yield streaming_service.format_data(
- "turn-info",
- {"chat_turn_id": stream_result.turn_id},
- )
- yield streaming_service.format_data("turn-status", {"status": "busy"})
-
- # Persist the user-side row for this turn before any expensive
- # work runs. Closes the "ghost-thread" abuse vector
- # (authenticated client hits POST /new_chat then never calls
- # /messages — empty new_chat_messages, free LLM completion).
- # Idempotent against the unique index in migration 141 so the
- # legacy frontend appendMessage call is a no-op on the second
- # writer. Hard failure aborts the turn so we never produce a
- # title or assistant row that isn't anchored to a persisted
- # user message.
- from app.tasks.chat.content_builder import AssistantContentBuilder
- from app.tasks.chat.persistence import (
- persist_assistant_shell,
- persist_user_turn,
- )
-
- user_message_id = await persist_user_turn(
- chat_id=chat_id,
- user_id=user_id,
- turn_id=stream_result.turn_id,
- user_query=user_query,
- user_image_data_urls=user_image_data_urls,
- mentioned_documents=mentioned_documents,
- )
- if user_message_id is None:
- yield _emit_stream_error(
- message=(
- "We couldn't save your message. Please try again in a moment."
- ),
- error_kind="server_error",
- error_code="MESSAGE_PERSIST_FAILED",
- )
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
- return
-
- # Emit canonical user message id BEFORE any LLM streaming so the
- # FE can rename its optimistic ``msg-user-XXX`` placeholder to
- # ``msg-{user_message_id}`` and unlock features gated on a real
- # DB id (comments, edit-from-this-message). See B4 in
- # ``sse-based_message_id_handshake`` plan.
- yield streaming_service.format_data(
- "user-message-id",
- {"message_id": user_message_id, "turn_id": stream_result.turn_id},
- )
-
- # Pre-write the assistant row for this turn so we have a stable
- # ``message_id`` to anchor mid-stream metadata (token_usage,
- # future agent_action_log.message_id correlation) and a
- # write-once UPDATE target at finalize time. Idempotent against
- # the (thread_id, turn_id, ASSISTANT) partial unique index from
- # migration 141 — if the legacy frontend appendMessage races
- # this, we recover the existing row's id.
- assistant_message_id = await persist_assistant_shell(
- chat_id=chat_id,
- user_id=user_id,
- turn_id=stream_result.turn_id,
- )
- if assistant_message_id is None:
- # Genuine DB failure — abort the turn rather than stream
- # into a void. The user row is already persisted so the
- # legacy "ghost-thread" gate isn't reopened.
- yield _emit_stream_error(
- message=(
- "We couldn't initialize the assistant message. Please try again."
- ),
- error_kind="server_error",
- error_code="MESSAGE_PERSIST_FAILED",
- )
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
- return
-
- # Emit canonical assistant message id BEFORE any LLM streaming
- # so the FE can rename its optimistic ``msg-assistant-XXX``
- # placeholder to ``msg-{assistant_message_id}`` and bind
- # ``tokenUsageStore`` / ``pendingInterrupt`` to the real id
- # immediately. See B4 in ``sse-based_message_id_handshake``
- # plan.
- yield streaming_service.format_data(
- "assistant-message-id",
- {"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
- )
-
- stream_result.assistant_message_id = assistant_message_id
- stream_result.content_builder = AssistantContentBuilder()
-
- # Initial thinking step - analyzing the request
- initial_title = "Understanding your request"
- action_verb = "Processing"
-
- processing_parts = []
- if user_query.strip():
- query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
- processing_parts.append(query_text)
- elif user_image_data_urls:
- processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
- else:
- processing_parts.append("(message)")
-
- initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
- initial_step_id = "thinking-1"
-
- # Drive the builder for this initial thinking step too — the
- # ``_emit_thinking_step`` helper lives inside ``_stream_agent_events``
- # so it isn't in scope here, but the FE folds this step into
- # the same singleton ``data-thinking-steps`` part as everything
- # the agent stream emits later. Mirror that fold server-side.
- if stream_result.content_builder is not None:
- stream_result.content_builder.on_thinking_step(
- initial_step_id, initial_title, "in_progress", initial_items
- )
- yield streaming_service.format_thinking_step(
- step_id=initial_step_id,
- title=initial_title,
- status="in_progress",
- items=initial_items,
- )
-
- # These ORM objects can be large. They're only needed to build context
- # strings already copied into final_query / langchain_messages —
- # release them before streaming.
- del recent_reports
- del langchain_messages, final_query
-
- # Check if this is the first assistant response so we can generate
- # a title in parallel with the agent stream (better UX than waiting
- # until after the full response).
- # Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because
- # this is now a hot path executed on every turn, and COUNT scales
- # with thread length (server-side persistence can grow rows
- # quickly under power users).
- #
- # IMPORTANT: ``persist_assistant_shell`` above already
- # inserted THIS turn's assistant row. We must therefore exclude
- # it from the probe — otherwise the gate fires on every turn
- # except the very first, and title generation never runs for new
- # threads. Excluding by primary key (``id != assistant_message_id``)
- # is bulletproof regardless of ``turn_id`` shape (legacy NULLs,
- # resume turns, etc.).
- first_assistant_probe = await session.execute(
- select(NewChatMessage.id)
- .filter(
- NewChatMessage.thread_id == chat_id,
- NewChatMessage.role == "assistant",
- NewChatMessage.id != assistant_message_id,
- )
- .limit(1)
- )
- is_first_response = first_assistant_probe.scalars().first() is None
-
- title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
- # Gate title generation on a persisted user message so a stream
- # that fails before persistence (we abort above) can never leave
- # behind a thread with a generated title and no anchoring rows.
- if is_first_response and user_message_id is not None:
-
- async def _generate_title() -> tuple[str | None, dict | None]:
-     """Generate a short thread title via litellm.
-
-     Usage is extracted directly from the response object because litellm
-     fires its async callback via fire-and-forget ``create_task``, so the
-     ``TokenTrackingCallback`` would run too late. The accumulator is also
-     blanked in this child-task context so the late callback doesn't
-     double-count.
-
-     Returns:
-         ``(title, usage_dict)``. ``title`` is ``None`` when the model
-         returned no text (missing choices / ``content=None``) or a title
-         longer than 100 characters; ``usage_dict`` is ``None`` when the
-         response carried no usage. Any exception is logged and reported
-         as ``(None, None)`` rather than propagated.
-     """
-     try:
-         from litellm import acompletion
-
-         from app.services.llm_router_service import LLMRouterService
-         from app.services.provider_api_base import resolve_api_base
-         from app.services.token_tracking_service import _turn_accumulator
-
-         _turn_accumulator.set(None)
-
-         # Resolved once; used for routing, the api_base cascade and the
-         # usage record below.
-         raw_model = getattr(llm, "model", "") or ""
-
-         title_seed = user_query.strip() or (
-             f"[{len(user_image_data_urls or [])} image(s)]"
-             if user_image_data_urls
-             else ""
-         )
-         prompt = TITLE_GENERATION_PROMPT.replace(
-             "{user_query}", title_seed[:500] or "(message)"
-         )
-         messages = [{"role": "user", "content": prompt}]
-
-         if raw_model == "auto":
-             router = LLMRouterService.get_router()
-             response = await router.acompletion(model="auto", messages=messages)
-         else:
-             # Apply the same ``api_base`` cascade chat / vision /
-             # image-gen call sites use so we never inherit
-             # ``litellm.api_base`` (commonly set by
-             # ``AZURE_OPENAI_ENDPOINT``) when the chat config
-             # itself ships an empty ``api_base``. Without this
-             # the title-gen on an OpenRouter chat config would
-             # 404 against the inherited Azure endpoint — see
-             # ``provider_api_base`` docstring for the same
-             # bug repro on the image-gen / vision paths.
-             provider_prefix = (
-                 raw_model.split("/", 1)[0] if "/" in raw_model else None
-             )
-             provider_value = (
-                 agent_config.provider if agent_config is not None else None
-             )
-             title_api_base = resolve_api_base(
-                 provider=provider_value,
-                 provider_prefix=provider_prefix,
-                 config_api_base=getattr(llm, "api_base", None),
-             )
-             response = await acompletion(
-                 model=raw_model,
-                 messages=messages,
-                 api_key=getattr(llm, "api_key", None),
-                 api_base=title_api_base,
-             )
-
-         usage_info = None
-         usage = getattr(response, "usage", None)
-         if usage:
-             model_name = (
-                 raw_model.split("/", 1)[-1]
-                 if "/" in raw_model
-                 else (raw_model or getattr(response, "model", None) or "unknown")
-             )
-             usage_info = {
-                 "model": model_name,
-                 "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
-                 "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
-                 "total_tokens": getattr(usage, "total_tokens", 0) or 0,
-             }
-
-         # ``content`` is nullable and ``choices`` may be empty. Treat both
-         # as "no title" instead of raising, so the usage gathered above is
-         # still returned and accounted for by the caller.
-         choices = getattr(response, "choices", None) or []
-         content = choices[0].message.content if choices else None
-         raw_title = (content or "").strip()
-         if raw_title and len(raw_title) <= 100:
-             return raw_title.strip("\"'"), usage_info
-         return None, usage_info
-     except Exception:
-         # Title generation is best-effort: never let it break the stream.
-         logging.getLogger(__name__).exception(
-             "[TitleGen] _generate_title failed"
-         )
-         return None, None
-
- title_task = asyncio.create_task(_generate_title())
-
- title_emitted = False
-
- # Build the per-invocation runtime context (Phase 1.5).
- # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware``
- # via ``runtime.context.mentioned_document_ids`` instead of its
- # ``__init__`` closure — that way the same compiled-agent instance
- # can serve multiple turns with different mention lists.
- runtime_context = SurfSenseContextSchema(
- search_space_id=search_space_id,
- mentioned_document_ids=list(mentioned_document_ids or []),
- mentioned_folder_ids=list(
- accepted_folder_ids or mentioned_folder_ids or []
- ),
- mentioned_connector_ids=list(mentioned_connector_ids or []),
- mentioned_connectors=list(mentioned_connectors or []),
- request_id=request_id,
- turn_id=stream_result.turn_id,
- )
-
- _t_stream_start = time.perf_counter()
- _first_event_logged = False
- runtime_rate_limit_recovered = False
- while True:
- try:
- async for sse in _stream_agent_events(
- agent=agent,
- config=config,
- input_data=input_state,
- streaming_service=streaming_service,
- result=stream_result,
- step_prefix="thinking",
- initial_step_id=initial_step_id,
- initial_step_title=initial_title,
- initial_step_items=initial_items,
- fallback_commit_search_space_id=search_space_id,
- fallback_commit_created_by_id=user_id,
- fallback_commit_filesystem_mode=(
- filesystem_selection.mode
- if filesystem_selection
- else FilesystemMode.CLOUD
- ),
- fallback_commit_thread_id=chat_id,
- runtime_context=runtime_context,
- content_builder=stream_result.content_builder,
- ):
- if not _first_event_logged:
- _perf_log.info(
- "[stream_new_chat] First agent event in %.3fs (time since stream start), "
- "%.3fs (total since request start) (chat_id=%s)",
- time.perf_counter() - _t_stream_start,
- time.perf_counter() - _t_total,
- chat_id,
- )
- _first_event_logged = True
- yield sse
-
- # Inject title update mid-stream as soon as the background
- # task finishes.
- if (
- title_task is not None
- and title_task.done()
- and not title_emitted
- ):
- generated_title, title_usage = title_task.result()
- if title_usage:
- accumulator.add(**title_usage)
- if generated_title:
- async with shielded_async_session() as title_session:
- title_thread_result = await title_session.execute(
- select(NewChatThread).filter(
- NewChatThread.id == chat_id
- )
- )
- title_thread = title_thread_result.scalars().first()
- if title_thread:
- title_thread.title = generated_title
- await title_session.commit()
- yield streaming_service.format_thread_title_update(
- chat_id, generated_title
- )
- title_emitted = True
- break
- except Exception as stream_exc:
- can_runtime_recover = (
- not runtime_rate_limit_recovered
- and requested_llm_config_id == 0
- and llm_config_id < 0
- and not _first_event_logged
- and _is_provider_rate_limited(stream_exc)
- )
- if not can_runtime_recover:
- raise
-
- runtime_rate_limit_recovered = True
- previous_config_id = llm_config_id
- # The failed attempt may still hold the per-thread busy mutex
- # (middleware teardown can lag behind raised provider errors).
- # Force release before we retry within the same request.
- end_turn(str(chat_id))
- mark_runtime_cooldown(
- previous_config_id,
- reason="provider_rate_limited",
- )
-
- llm_config_id = (
- await resolve_or_get_pinned_llm_config_id(
- session,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- selected_llm_config_id=0,
- exclude_config_ids={previous_config_id},
- requires_image_input=_requires_image_input,
- )
- ).resolved_llm_config_id
-
- llm, agent_config, llm_load_error = await _load_llm_bundle(
- llm_config_id
- )
- if llm_load_error:
- raise stream_exc
-
- # Title generation uses the initial llm object. After a runtime
- # repin we keep the stream focused on response recovery and skip
- # title generation for this turn.
- if title_task is not None and not title_task.done():
- title_task.cancel()
- title_task = None
-
- _t0 = time.perf_counter()
- agent = await _build_main_agent_for_thread(
- agent_factory,
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- disabled_tools=disabled_tools,
- mentioned_document_ids=mentioned_document_ids,
- )
- _perf_log.info(
- "[stream_new_chat] Runtime rate-limit recovery repinned "
- "config_id=%s -> %s and rebuilt agent in %.3fs",
- previous_config_id,
- llm_config_id,
- time.perf_counter() - _t0,
- )
- ot.add_event(
- "chat.rate_limit.recovered",
- {
- "recovery.reason": "provider_rate_limited",
- "recovery.previous_config_id": previous_config_id,
- "recovery.fallback_config_id": llm_config_id,
- },
- )
- _log_chat_stream_error(
- flow=flow,
- error_kind="rate_limited",
- error_code="RATE_LIMITED",
- severity="info",
- is_expected=True,
- request_id=request_id,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- message=(
- "Auto-pinned model hit runtime rate limit; switched to "
- "another eligible model and retried."
- ),
- extra={
- "auto_runtime_recover": True,
- "previous_config_id": previous_config_id,
- "fallback_config_id": llm_config_id,
- },
- )
- continue
-
- _perf_log.info(
- "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
- time.perf_counter() - _t_stream_start,
- chat_id,
- )
- log_system_snapshot("stream_new_chat_END")
-
- if stream_result.is_interrupted:
- ot.add_event(
- "chat.interrupted",
- {
- "chat.flow": flow,
- },
- )
- if title_task is not None and not title_task.done():
- title_task.cancel()
-
- usage_summary = accumulator.per_message_summary()
- _perf_log.info(
- "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
- len(accumulator.calls),
- accumulator.grand_total,
- accumulator.total_cost_micros,
- usage_summary,
- )
- if usage_summary:
- yield streaming_service.format_data(
- "token-usage",
- {
- "usage": usage_summary,
- "prompt_tokens": accumulator.total_prompt_tokens,
- "completion_tokens": accumulator.total_completion_tokens,
- "total_tokens": accumulator.grand_total,
- "cost_micros": accumulator.total_cost_micros,
- "call_details": accumulator.serialized_calls(),
- },
- )
-
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
- return
-
- # If the title task didn't finish during streaming, await it now
- if title_task is not None and not title_emitted:
- generated_title, title_usage = await title_task
- if title_usage:
- accumulator.add(**title_usage)
- if generated_title:
- async with shielded_async_session() as title_session:
- title_thread_result = await title_session.execute(
- select(NewChatThread).filter(NewChatThread.id == chat_id)
- )
- title_thread = title_thread_result.scalars().first()
- if title_thread:
- title_thread.title = generated_title
- await title_session.commit()
- yield streaming_service.format_thread_title_update(
- chat_id, generated_title
- )
-
- # Finalize premium credit debit with the actual provider cost
- # reported by LiteLLM, summed across every call in the turn.
- # Mirrors the pre-cost behaviour of "premium turn → all calls
- # count" so free sub-agent calls during a premium turn still
- # contribute to the bill (they're $0 in practice anyway).
- if _premium_request_id and user_id:
- try:
- from app.services.token_quota_service import TokenQuotaService
-
- async with shielded_async_session() as quota_session:
- await TokenQuotaService.premium_finalize(
- db_session=quota_session,
- user_id=UUID(user_id),
- request_id=_premium_request_id,
- actual_micros=accumulator.total_cost_micros,
- reserved_micros=_premium_reserved_micros,
- )
- _premium_request_id = None
- _premium_reserved_micros = 0
- except Exception:
- logging.getLogger(__name__).warning(
- "Failed to finalize premium quota for user %s",
- user_id,
- exc_info=True,
- )
-
- usage_summary = accumulator.per_message_summary()
- _perf_log.info(
- "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
- len(accumulator.calls),
- accumulator.grand_total,
- accumulator.total_cost_micros,
- usage_summary,
- )
- if usage_summary:
- yield streaming_service.format_data(
- "token-usage",
- {
- "usage": usage_summary,
- "prompt_tokens": accumulator.total_prompt_tokens,
- "completion_tokens": accumulator.total_completion_tokens,
- "total_tokens": accumulator.grand_total,
- "cost_micros": accumulator.total_cost_micros,
- "call_details": accumulator.serialized_calls(),
- },
- )
-
- # Finish the step and message
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
-
- except Exception as e:
- # Handle any errors
- import traceback
-
- # ``BusyError`` fires before the agent acquires the lock; the
- # cleanup path must skip lock release to avoid freeing the
- # in-flight caller's lock. Classification is handled below.
- if isinstance(e, BusyError):
- _busy_error_raised = True
-
- (
- error_kind,
- error_code,
- severity,
- is_expected,
- user_message,
- error_extra,
- ) = _classify_stream_exception(e, flow_label="chat")
- chat_outcome = error_code or error_kind or "error"
- chat_error_category = ot_metrics.categorize_exception(e)
- with contextlib.suppress(Exception):
- chat_span.set_attribute("chat.outcome", chat_outcome)
- chat_span.set_attribute("error.category", chat_error_category)
- ot.record_error(chat_span, e)
- error_message = f"Error during chat: {e!s}"
- print(f"[stream_new_chat] {error_message}")
- print(f"[stream_new_chat] Exception type: {type(e).__name__}")
- print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
- if error_code == "TURN_CANCELLING":
- status_payload: dict[str, Any] = {"status": "cancelling"}
- if error_extra:
- status_payload.update(error_extra)
- yield streaming_service.format_data("turn-status", status_payload)
- else:
- yield streaming_service.format_data("turn-status", {"status": "busy"})
-
- yield _emit_stream_error(
- message=user_message,
- error_kind=error_kind,
- error_code=error_code,
- severity=severity,
- is_expected=is_expected,
- extra=error_extra,
- )
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
-
- finally:
- # Shield the ENTIRE async cleanup from anyio cancel-scope
- # cancellation. Starlette's BaseHTTPMiddleware uses anyio task
- # groups; on client disconnect, it cancels the scope with
- # level-triggered cancellation — every unshielded `await` inside
- # the cancelled scope raises CancelledError immediately. Without
- # this shield the very first `await` (session.rollback) would
- # raise CancelledError, `except Exception` wouldn't catch it
- # (CancelledError is a BaseException), and the rest of the
- # finally block — including session.close() — would never run.
- with anyio.CancelScope(shield=True):
- # Authoritative fallback cleanup for lock/cancel state. Middleware
- # teardown can be skipped on some client-abort paths.
- end_turn(str(chat_id))
-
- # Release premium reservation if not finalized
- if _premium_request_id and _premium_reserved_micros > 0 and user_id:
- try:
- from app.services.token_quota_service import TokenQuotaService
-
- async with shielded_async_session() as quota_session:
- await TokenQuotaService.premium_release(
- db_session=quota_session,
- user_id=UUID(user_id),
- reserved_micros=_premium_reserved_micros,
- )
- _premium_reserved_micros = 0
- except Exception:
- logging.getLogger(__name__).warning(
- "Failed to release premium quota for user %s", user_id
- )
-
- try:
- await session.rollback()
- await clear_ai_responding(session, chat_id)
- except Exception:
- try:
- async with shielded_async_session() as fresh_session:
- await clear_ai_responding(fresh_session, chat_id)
- except Exception:
- logging.getLogger(__name__).warning(
- "Failed to clear AI responding state for thread %s", chat_id
- )
-
- with contextlib.suppress(Exception):
- session.expunge_all()
-
- with contextlib.suppress(Exception):
- await session.close()
-
- # Server-side assistant-message + token_usage finalization.
- # Runs after the main session has been closed (uses its own
- # shielded session) so we don't fight the same DB connection.
- # Idempotent against the legacy frontend appendMessage:
- # * the assistant row was already INSERTed by
- # ``persist_assistant_shell`` above, so this just UPDATEs
- # it with the rich ContentPart[] from the builder.
- # * token_usage uses INSERT ... ON CONFLICT DO NOTHING
- # against migration 142's partial unique index, so a
- # racing append_message recovery branch can never
- # double-write.
- # ``mark_interrupted`` closes any open text/reasoning blocks
- # and flips running tool-calls (no result) to state=aborted
- # so the persisted JSONB reflects a coherent end-state even
- # on client disconnect.
- # Never raises (best-effort, logs only).
- if (
- stream_result
- and stream_result.turn_id
- and stream_result.assistant_message_id
- ):
- from app.tasks.chat.persistence import finalize_assistant_turn
-
- builder_stats: dict[str, int] | None = None
- if stream_result.content_builder is not None:
- stream_result.content_builder.mark_interrupted()
- # Snapshot stats BEFORE deepcopy in ``snapshot()`` so
- # the perf log records the actual finalised payload
- # (post-mark_interrupted), not the live-mutating
- # builder state.
- builder_stats = stream_result.content_builder.stats()
- content_payload = stream_result.content_builder.snapshot()
- else:
- # Defensive fallback — we always set the builder
- # alongside ``assistant_message_id`` above, so this
- # branch only fires if a future refactor ever
- # decouples them. Persist whatever accumulated
- # text we captured so the row at least renders.
- content_payload = [
- {
- "type": "text",
- "text": stream_result.accumulated_text or "",
- }
- ]
-
- if builder_stats is not None:
- _perf_log.info(
- "[stream_new_chat] finalize_payload chat_id=%s "
- "message_id=%s parts=%d bytes=%d text=%d "
- "reasoning=%d tool_calls=%d "
- "tool_calls_completed=%d tool_calls_aborted=%d "
- "thinking_step_parts=%d step_separators=%d",
- chat_id,
- stream_result.assistant_message_id,
- builder_stats["parts"],
- builder_stats["bytes"],
- builder_stats["text"],
- builder_stats["reasoning"],
- builder_stats["tool_calls"],
- builder_stats["tool_calls_completed"],
- builder_stats["tool_calls_aborted"],
- builder_stats["thinking_step_parts"],
- builder_stats["step_separators"],
- )
-
- await finalize_assistant_turn(
- message_id=stream_result.assistant_message_id,
- chat_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- turn_id=stream_result.turn_id,
- content=content_payload,
- accumulator=accumulator,
- )
-
- # Persist any sandbox-produced files to local storage so they
- # remain downloadable after the Daytona sandbox auto-deletes.
- if stream_result and stream_result.sandbox_files:
- with contextlib.suppress(Exception):
- from app.agents.shared.sandbox import (
- is_sandbox_enabled,
- persist_and_delete_sandbox,
- )
-
- if is_sandbox_enabled():
- with anyio.CancelScope(shield=True):
- await persist_and_delete_sandbox(
- chat_id, stream_result.sandbox_files
- )
-
- # ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout.
- # Skip on ``BusyError`` (caller never acquired the lock).
- if not _busy_error_raised:
- with contextlib.suppress(Exception):
- end_turn(str(chat_id))
- _perf_log.info(
- "[stream_new_chat] end_turn cleanup (chat_id=%s)",
- chat_id,
- )
-
- # Break circular refs held by the agent graph, tools, and LLM
- # wrappers so the GC can reclaim them in a single pass.
- agent = llm = connector_service = None
- input_state = stream_result = None
- session = None
-
- collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
- if collected:
- _perf_log.info(
- "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
- collected,
- chat_id,
- )
- trim_native_heap()
- log_system_snapshot("stream_new_chat_END")
- with contextlib.suppress(Exception):
- chat_span.set_attribute("chat.outcome", chat_outcome)
- ot_metrics.record_chat_request_duration(
- (time.perf_counter() - _t_total) * 1000,
- flow=flow,
- outcome=chat_outcome,
- agent_mode=chat_agent_mode,
- )
- ot_metrics.record_chat_request_outcome(
- flow=flow,
- outcome=chat_outcome,
- agent_mode=chat_agent_mode,
- error_category=chat_error_category,
- )
- chat_span_cm.__exit__(*sys.exc_info())
-
-
-async def stream_resume_chat(
- chat_id: int,
- search_space_id: int,
- decisions: list[dict],
- user_id: str | None = None,
- llm_config_id: int = -1,
- thread_visibility: ChatVisibility | None = None,
- filesystem_selection: FilesystemSelection | None = None,
- request_id: str | None = None,
- disabled_tools: list[str] | None = None,
-) -> AsyncGenerator[str, None]:
- streaming_service = VercelStreamingService()
- stream_result = StreamResult()
- _t_total = time.perf_counter()
- fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
- fs_platform = (
- filesystem_selection.client_platform.value if filesystem_selection else "web"
- )
- stream_result.request_id = request_id
- stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
- stream_result.filesystem_mode = fs_mode
- stream_result.client_platform = fs_platform
- chat_agent_mode = "unknown"
- chat_outcome = "success"
- chat_error_category: str | None = None
- chat_span_cm = ot.chat_request_span(
- chat_id=chat_id,
- search_space_id=search_space_id,
- flow="resume",
- request_id=request_id,
- turn_id=stream_result.turn_id,
- filesystem_mode=fs_mode,
- client_platform=fs_platform,
- agent_mode=chat_agent_mode,
- )
- chat_span = chat_span_cm.__enter__()
- _log_file_contract("turn_start", stream_result)
- _perf_log.info(
- "[stream_resume] filesystem_mode=%s client_platform=%s",
- fs_mode,
- fs_platform,
- )
- from app.services.token_tracking_service import start_turn
-
- accumulator = start_turn()
-
- # Skip the finally release on ``BusyError`` (caller never acquired the lock).
- _busy_error_raised = False
-
- _emit_stream_error = partial(
- _emit_stream_terminal_error,
- streaming_service=streaming_service,
- flow="resume",
- request_id=request_id,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- )
-
- session = async_session_maker()
- try:
- if user_id:
- await set_ai_responding(session, chat_id, UUID(user_id))
-
- agent_config: AgentConfig | None = None
- requested_llm_config_id = llm_config_id
-
- async def _load_llm_bundle(
- config_id: int,
- ) -> tuple[Any, AgentConfig | None, str | None]:
- if config_id >= 0:
- loaded_agent_config = await load_agent_config(
- session=session,
- config_id=config_id,
- search_space_id=search_space_id,
- )
- if not loaded_agent_config:
- return (
- None,
- None,
- f"Failed to load NewLLMConfig with id {config_id}",
- )
- return (
- create_chat_litellm_from_agent_config(loaded_agent_config),
- loaded_agent_config,
- None,
- )
-
- loaded_llm_config = load_global_llm_config_by_id(config_id)
- if not loaded_llm_config:
- return None, None, f"Failed to load LLM config with id {config_id}"
- return (
- create_chat_litellm_from_config(loaded_llm_config),
- AgentConfig.from_yaml_config(loaded_llm_config),
- None,
- )
-
- _t0 = time.perf_counter()
- try:
- llm_config_id = (
- await resolve_or_get_pinned_llm_config_id(
- session,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- selected_llm_config_id=llm_config_id,
- )
- ).resolved_llm_config_id
- ot.add_event(
- "model.pin.resolved",
- {
- "pin.requested_id": requested_llm_config_id,
- "pin.resolved_id": llm_config_id,
- "pin.requires_image_input": False,
- },
- )
- except ValueError as pin_error:
- yield _emit_stream_error(
- message=str(pin_error),
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
-
- llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
- if llm_load_error:
- yield _emit_stream_error(
- message=llm_load_error,
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
- _perf_log.info(
- "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
- )
-
- # Premium credit reservation (same logic as stream_new_chat).
- _resume_premium_reserved_micros = 0
- _resume_premium_request_id: str | None = None
- _resume_needs_premium = (
- agent_config is not None and user_id and agent_config.is_premium
- )
- if _resume_needs_premium:
- import uuid as _uuid
-
- from app.services.token_quota_service import (
- TokenQuotaService,
- estimate_call_reserve_micros,
- )
-
- _resume_premium_request_id = _uuid.uuid4().hex[:16]
- _resume_litellm_params = agent_config.litellm_params or {}
- _resume_base_model = (
- _resume_litellm_params.get("base_model")
- or agent_config.model_name
- or ""
- )
- reserve_amount_micros = estimate_call_reserve_micros(
- base_model=_resume_base_model,
- quota_reserve_tokens=agent_config.quota_reserve_tokens,
- )
- async with shielded_async_session() as quota_session:
- quota_result = await TokenQuotaService.premium_reserve(
- db_session=quota_session,
- user_id=UUID(user_id),
- request_id=_resume_premium_request_id,
- reserve_micros=reserve_amount_micros,
- )
- _resume_premium_reserved_micros = reserve_amount_micros
- if not quota_result.allowed:
- ot.add_event(
- "quota.denied",
- {
- "quota.code": "PREMIUM_QUOTA_EXHAUSTED",
- },
- )
- if requested_llm_config_id == 0:
- try:
- llm_config_id = (
- await resolve_or_get_pinned_llm_config_id(
- session,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- selected_llm_config_id=0,
- force_repin_free=True,
- )
- ).resolved_llm_config_id
- ot.add_event(
- "model.repin",
- {
- "repin.reason": "premium_quota_exhausted",
- "repin.to_config_id": llm_config_id,
- },
- )
- except ValueError as pin_error:
- yield _emit_stream_error(
- message=str(pin_error),
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
-
- llm, agent_config, llm_load_error = await _load_llm_bundle(
- llm_config_id
- )
- if llm_load_error:
- yield _emit_stream_error(
- message=llm_load_error,
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
- _resume_premium_request_id = None
- _resume_premium_reserved_micros = 0
- _log_chat_stream_error(
- flow="resume",
- error_kind="premium_quota_exhausted",
- error_code="PREMIUM_QUOTA_EXHAUSTED",
- severity="info",
- is_expected=True,
- request_id=request_id,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- message=(
- "Premium quota exhausted on pinned model; auto-fallback switched to a free model"
- ),
- extra={
- "fallback_config_id": llm_config_id,
- "auto_fallback": True,
- },
- )
- else:
- yield _emit_stream_error(
- message=(
- "Buy more tokens to continue with this model, or switch to a free model"
- ),
- error_kind="premium_quota_exhausted",
- error_code="PREMIUM_QUOTA_EXHAUSTED",
- severity="info",
- is_expected=True,
- extra={
- "resolved_config_id": llm_config_id,
- "auto_fallback": False,
- },
- )
- yield streaming_service.format_done()
- return
-
- if not llm:
- yield _emit_stream_error(
- message="Failed to create LLM instance",
- error_kind="server_error",
- error_code="SERVER_ERROR",
- )
- yield streaming_service.format_done()
- return
-
- _t0 = time.perf_counter()
- connector_service = ConnectorService(session, search_space_id=search_space_id)
-
- firecrawl_api_key = None
- webcrawler_connector = await connector_service.get_connector_by_type(
- SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
- )
- if webcrawler_connector and webcrawler_connector.config:
- firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
- _perf_log.info(
- "[stream_resume] Connector service + firecrawl key in %.3fs",
- time.perf_counter() - _t0,
- )
-
- _t0 = time.perf_counter()
- checkpointer = await get_checkpointer()
- _perf_log.info(
- "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
- )
-
- visibility = thread_visibility or ChatVisibility.PRIVATE
- chat_agent_mode = "multi"
- with contextlib.suppress(Exception):
- chat_span.set_attribute("agent.mode", chat_agent_mode)
- _t0 = time.perf_counter()
- agent_factory = create_multi_agent_chat_deep_agent
- # Build the agent inline. Provider 429s are handled by the
- # in-stream recovery loop, which repins to an eligible
- # alternative config and rebuilds the agent before the user sees
- # any output.
- agent = await _build_main_agent_for_thread(
- agent_factory,
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- disabled_tools=disabled_tools,
- )
- _perf_log.info(
- "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
- )
-
- # Release the transaction before streaming (same rationale as stream_new_chat).
- await session.commit()
- session.expunge_all()
-
- _perf_log.info(
- "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
- time.perf_counter() - _t_total,
- chat_id,
- )
-
- from langgraph.types import Command
-
- from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
- build_lg_resume_map,
- collect_pending_tool_calls,
- slice_decisions_by_tool_call,
- )
-
- # Each pending interrupt is stamped with its originating ``tool_call_id``
- # (see ``checkpointed_subagent_middleware.propagation``) so we can route
- # a flat ``decisions`` list back to the right paused subagent.
- parent_state = await agent.aget_state(
- {"configurable": {"thread_id": str(chat_id)}}
- )
- pending = collect_pending_tool_calls(parent_state)
- _perf_log.info(
- "[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d",
- chat_id,
- len(decisions),
- len(pending),
- )
- routed_resume_value = slice_decisions_by_tool_call(decisions, pending)
- # Langgraph rejects scalar ``Command(resume=...)`` when multiple
- # interrupts are pending (parallel HITL); the mapped form works
- # for the single-pause case too, so we always use it.
- lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value)
-
- config = {
- "configurable": {
- "thread_id": str(chat_id),
- "request_id": request_id or "unknown",
- "turn_id": stream_result.turn_id,
- # Per-``tool_call_id`` resume slices read by
- # ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel
- # siblings each pop their own entry, so they never race.
- "surfsense_resume_value": routed_resume_value,
- },
- # See ``stream_new_chat`` above for rationale: effectively
- # uncapped to mirror the agent default and OpenCode's
- # session loop. Doom-loop / call-limit middleware enforce
- # the real ceiling.
- "recursion_limit": 10_000,
- }
-
- yield streaming_service.format_message_start()
- yield streaming_service.format_start_step()
- # Same rationale as ``stream_new_chat``: emit the turn id so
- # resumed streams can be persisted with their correlation id
- # intact.
- yield streaming_service.format_data(
- "turn-info",
- {"chat_turn_id": stream_result.turn_id},
- )
- yield streaming_service.format_data("turn-status", {"status": "busy"})
-
- # Pre-write a fresh assistant row for this resume turn. The
- # original (interrupted) ``stream_new_chat`` invocation already
- # persisted its own assistant row anchored to a different
- # ``turn_id``; resume allocates a new ``turn_id`` (above) so we
- # need a separate row keyed on the same ``(thread_id, turn_id,
- # ASSISTANT)`` invariant. Idempotent against migration 141's
- # partial unique index — recovers existing id on retry.
- from app.tasks.chat.content_builder import AssistantContentBuilder
- from app.tasks.chat.persistence import persist_assistant_shell
-
- assistant_message_id = await persist_assistant_shell(
- chat_id=chat_id,
- user_id=user_id,
- turn_id=stream_result.turn_id,
- )
- if assistant_message_id is None:
- yield _emit_stream_error(
- message=(
- "We couldn't initialize the assistant message. Please try again."
- ),
- error_kind="server_error",
- error_code="MESSAGE_PERSIST_FAILED",
- )
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
- return
-
- # Emit canonical assistant message id BEFORE any LLM streaming
- # so the FE can rename ``pendingInterrupt.assistantMsgId`` to
- # ``msg-{assistant_message_id}`` immediately. Resume does NOT
- # emit ``data-user-message-id`` because the user row is from
- # the original interrupted turn (different ``turn_id``) and is
- # never re-persisted here. See B5 in the
- # ``sse-based_message_id_handshake`` plan.
- yield streaming_service.format_data(
- "assistant-message-id",
- {"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
- )
-
- stream_result.assistant_message_id = assistant_message_id
- stream_result.content_builder = AssistantContentBuilder()
-
- # Resume path doesn't carry new ``mentioned_document_ids`` —
- # those are seeded in the original turn. We still pass a
- # context so future middleware extensions (Phase 2) can rely on
- # ``runtime.context`` always being populated.
- runtime_context = SurfSenseContextSchema(
- search_space_id=search_space_id,
- request_id=request_id,
- turn_id=stream_result.turn_id,
- )
-
- _t_stream_start = time.perf_counter()
- _first_event_logged = False
- runtime_rate_limit_recovered = False
- while True:
- try:
- async for sse in _stream_agent_events(
- agent=agent,
- config=config,
- input_data=Command(resume=lg_resume_map),
- streaming_service=streaming_service,
- result=stream_result,
- step_prefix=_resume_step_prefix(stream_result.turn_id),
- fallback_commit_search_space_id=search_space_id,
- fallback_commit_created_by_id=user_id,
- fallback_commit_filesystem_mode=(
- filesystem_selection.mode
- if filesystem_selection
- else FilesystemMode.CLOUD
- ),
- fallback_commit_thread_id=chat_id,
- runtime_context=runtime_context,
- content_builder=stream_result.content_builder,
- ):
- if not _first_event_logged:
- _perf_log.info(
- "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
- time.perf_counter() - _t_stream_start,
- time.perf_counter() - _t_total,
- chat_id,
- )
- _first_event_logged = True
- yield sse
- break
- except Exception as stream_exc:
- can_runtime_recover = (
- not runtime_rate_limit_recovered
- and requested_llm_config_id == 0
- and llm_config_id < 0
- and not _first_event_logged
- and _is_provider_rate_limited(stream_exc)
- )
- if not can_runtime_recover:
- raise
-
- runtime_rate_limit_recovered = True
- previous_config_id = llm_config_id
- # Ensure the same-request recovery retry does not trip the
- # BusyMutex lock retained by the failed attempt.
- end_turn(str(chat_id))
- mark_runtime_cooldown(
- previous_config_id,
- reason="provider_rate_limited",
- )
- llm_config_id = (
- await resolve_or_get_pinned_llm_config_id(
- session,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- selected_llm_config_id=0,
- exclude_config_ids={previous_config_id},
- )
- ).resolved_llm_config_id
-
- llm, agent_config, llm_load_error = await _load_llm_bundle(
- llm_config_id
- )
- if llm_load_error:
- raise stream_exc
-
- _t0 = time.perf_counter()
- agent = await _build_main_agent_for_thread(
- agent_factory,
- llm=llm,
- search_space_id=search_space_id,
- db_session=session,
- connector_service=connector_service,
- checkpointer=checkpointer,
- user_id=user_id,
- thread_id=chat_id,
- agent_config=agent_config,
- firecrawl_api_key=firecrawl_api_key,
- thread_visibility=visibility,
- filesystem_selection=filesystem_selection,
- disabled_tools=disabled_tools,
- )
- _perf_log.info(
- "[stream_resume] Runtime rate-limit recovery repinned "
- "config_id=%s -> %s and rebuilt agent in %.3fs",
- previous_config_id,
- llm_config_id,
- time.perf_counter() - _t0,
- )
- ot.add_event(
- "chat.rate_limit.recovered",
- {
- "recovery.reason": "provider_rate_limited",
- "recovery.previous_config_id": previous_config_id,
- "recovery.fallback_config_id": llm_config_id,
- },
- )
- _log_chat_stream_error(
- flow="resume",
- error_kind="rate_limited",
- error_code="RATE_LIMITED",
- severity="info",
- is_expected=True,
- request_id=request_id,
- thread_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- message=(
- "Auto-pinned model hit runtime rate limit; switched to "
- "another eligible model and retried."
- ),
- extra={
- "auto_runtime_recover": True,
- "previous_config_id": previous_config_id,
- "fallback_config_id": llm_config_id,
- },
- )
- continue
- _perf_log.info(
- "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
- time.perf_counter() - _t_stream_start,
- chat_id,
- )
- if stream_result.is_interrupted:
- ot.add_event(
- "chat.interrupted",
- {
- "chat.flow": "resume",
- },
- )
- usage_summary = accumulator.per_message_summary()
- _perf_log.info(
- "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
- len(accumulator.calls),
- accumulator.grand_total,
- accumulator.total_cost_micros,
- usage_summary,
- )
- if usage_summary:
- yield streaming_service.format_data(
- "token-usage",
- {
- "usage": usage_summary,
- "prompt_tokens": accumulator.total_prompt_tokens,
- "completion_tokens": accumulator.total_completion_tokens,
- "total_tokens": accumulator.grand_total,
- "cost_micros": accumulator.total_cost_micros,
- "call_details": accumulator.serialized_calls(),
- },
- )
-
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
- return
-
- # Finalize premium credit debit for resume path with the actual
- # provider cost reported by LiteLLM (sum of cost across all
- # calls in the turn).
- if _resume_premium_request_id and user_id:
- try:
- from app.services.token_quota_service import TokenQuotaService
-
- async with shielded_async_session() as quota_session:
- await TokenQuotaService.premium_finalize(
- db_session=quota_session,
- user_id=UUID(user_id),
- request_id=_resume_premium_request_id,
- actual_micros=accumulator.total_cost_micros,
- reserved_micros=_resume_premium_reserved_micros,
- )
- _resume_premium_request_id = None
- _resume_premium_reserved_micros = 0
- except Exception:
- logging.getLogger(__name__).warning(
- "Failed to finalize premium quota for user %s (resume)",
- user_id,
- exc_info=True,
- )
-
- usage_summary = accumulator.per_message_summary()
- _perf_log.info(
- "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
- len(accumulator.calls),
- accumulator.grand_total,
- accumulator.total_cost_micros,
- usage_summary,
- )
- if usage_summary:
- yield streaming_service.format_data(
- "token-usage",
- {
- "usage": usage_summary,
- "prompt_tokens": accumulator.total_prompt_tokens,
- "completion_tokens": accumulator.total_completion_tokens,
- "total_tokens": accumulator.grand_total,
- "cost_micros": accumulator.total_cost_micros,
- "call_details": accumulator.serialized_calls(),
- },
- )
-
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
-
- except Exception as e:
- import traceback
-
- # ``BusyError`` fires before the agent acquires the lock; the
- # cleanup path must skip lock release to avoid freeing the
- # in-flight caller's lock. Classification is handled below.
- if isinstance(e, BusyError):
- _busy_error_raised = True
-
- (
- error_kind,
- error_code,
- severity,
- is_expected,
- user_message,
- error_extra,
- ) = _classify_stream_exception(e, flow_label="resume")
- chat_outcome = error_code or error_kind or "error"
- chat_error_category = ot_metrics.categorize_exception(e)
- with contextlib.suppress(Exception):
- chat_span.set_attribute("chat.outcome", chat_outcome)
- chat_span.set_attribute("error.category", chat_error_category)
- ot.record_error(chat_span, e)
- error_message = f"Error during resume: {e!s}"
- print(f"[stream_resume_chat] {error_message}")
- print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
- if error_code == "TURN_CANCELLING":
- status_payload: dict[str, Any] = {"status": "cancelling"}
- if error_extra:
- status_payload.update(error_extra)
- yield streaming_service.format_data("turn-status", status_payload)
- else:
- yield streaming_service.format_data("turn-status", {"status": "busy"})
- yield _emit_stream_error(
- message=user_message,
- error_kind=error_kind,
- error_code=error_code,
- severity=severity,
- is_expected=is_expected,
- extra=error_extra,
- )
- yield streaming_service.format_data("turn-status", {"status": "idle"})
- yield streaming_service.format_finish_step()
- yield streaming_service.format_finish()
- yield streaming_service.format_done()
-
- finally:
- with anyio.CancelScope(shield=True):
- # Authoritative fallback cleanup for lock/cancel state. Middleware
- # teardown can be skipped on some client-abort paths.
- end_turn(str(chat_id))
-
- # Release premium reservation if not finalized
- if (
- _resume_premium_request_id
- and _resume_premium_reserved_micros > 0
- and user_id
- ):
- try:
- from app.services.token_quota_service import TokenQuotaService
-
- async with shielded_async_session() as quota_session:
- await TokenQuotaService.premium_release(
- db_session=quota_session,
- user_id=UUID(user_id),
- reserved_micros=_resume_premium_reserved_micros,
- )
- _resume_premium_reserved_micros = 0
- except Exception:
- logging.getLogger(__name__).warning(
- "Failed to release premium quota for user %s (resume)", user_id
- )
-
- try:
- await session.rollback()
- await clear_ai_responding(session, chat_id)
- except Exception:
- try:
- async with shielded_async_session() as fresh_session:
- await clear_ai_responding(fresh_session, chat_id)
- except Exception:
- logging.getLogger(__name__).warning(
- "Failed to clear AI responding state for thread %s", chat_id
- )
-
- with contextlib.suppress(Exception):
- session.expunge_all()
-
- with contextlib.suppress(Exception):
- await session.close()
-
- # Server-side assistant-message + token_usage finalization for
- # the resume flow. The original user message was persisted by
- # the original (interrupted) ``stream_new_chat`` invocation;
- # the resume's own ``persist_assistant_shell`` write lives at
- # the new ``turn_id`` above. This finalize updates that row
- # with the rich ContentPart[] from the builder and writes
- # token_usage idempotently via migration 142's partial
- # unique index. Best-effort, never raises.
- if (
- stream_result
- and stream_result.turn_id
- and stream_result.assistant_message_id
- ):
- from app.tasks.chat.persistence import finalize_assistant_turn
-
- builder_stats: dict[str, int] | None = None
- if stream_result.content_builder is not None:
- stream_result.content_builder.mark_interrupted()
- builder_stats = stream_result.content_builder.stats()
- content_payload = stream_result.content_builder.snapshot()
- else:
- content_payload = [
- {
- "type": "text",
- "text": stream_result.accumulated_text or "",
- }
- ]
-
- if builder_stats is not None:
- _perf_log.info(
- "[stream_resume] finalize_payload chat_id=%s "
- "message_id=%s parts=%d bytes=%d text=%d "
- "reasoning=%d tool_calls=%d "
- "tool_calls_completed=%d tool_calls_aborted=%d "
- "thinking_step_parts=%d step_separators=%d",
- chat_id,
- stream_result.assistant_message_id,
- builder_stats["parts"],
- builder_stats["bytes"],
- builder_stats["text"],
- builder_stats["reasoning"],
- builder_stats["tool_calls"],
- builder_stats["tool_calls_completed"],
- builder_stats["tool_calls_aborted"],
- builder_stats["thinking_step_parts"],
- builder_stats["step_separators"],
- )
-
- await finalize_assistant_turn(
- message_id=stream_result.assistant_message_id,
- chat_id=chat_id,
- search_space_id=search_space_id,
- user_id=user_id,
- turn_id=stream_result.turn_id,
- content=content_payload,
- accumulator=accumulator,
- )
-
- # Release the lock from the original interrupted turn or any
- # re-interrupt/bailout. Skip on ``BusyError`` (lock not held here).
- if not _busy_error_raised:
- with contextlib.suppress(Exception):
- end_turn(str(chat_id))
- _perf_log.info(
- "[stream_resume] end_turn cleanup (chat_id=%s)",
- chat_id,
- )
-
- agent = llm = connector_service = None
- stream_result = None
- session = None
-
- collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
- if collected:
- _perf_log.info(
- "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
- collected,
- chat_id,
- )
- trim_native_heap()
- log_system_snapshot("stream_resume_chat_END")
- with contextlib.suppress(Exception):
- chat_span.set_attribute("chat.outcome", chat_outcome)
- ot_metrics.record_chat_request_duration(
- (time.perf_counter() - _t_total) * 1000,
- flow="resume",
- outcome=chat_outcome,
- agent_mode=chat_agent_mode,
- )
- ot_metrics.record_chat_request_outcome(
- flow="resume",
- outcome=chat_outcome,
- agent_mode=chat_agent_mode,
- error_category=chat_error_category,
- )
- chat_span_cm.__exit__(*sys.exc_info())
diff --git a/surfsense_backend/tests/e2e/run_backend.py b/surfsense_backend/tests/e2e/run_backend.py
index 2567cc7a4..c05783790 100644
--- a/surfsense_backend/tests/e2e/run_backend.py
+++ b/surfsense_backend/tests/e2e/run_backend.py
@@ -247,11 +247,11 @@ def _patch_llm_bindings() -> None:
fake_create_chat_litellm_from_config,
),
(
- "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
+ "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config",
fake_create_chat_litellm_from_agent_config,
),
(
- "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
+ "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
fake_create_chat_litellm_from_config,
),
]
diff --git a/surfsense_backend/tests/e2e/run_celery.py b/surfsense_backend/tests/e2e/run_celery.py
index 9e7576a51..1a77bf45a 100644
--- a/surfsense_backend/tests/e2e/run_celery.py
+++ b/surfsense_backend/tests/e2e/run_celery.py
@@ -220,11 +220,11 @@ def _patch_llm_bindings() -> None:
fake_create_chat_litellm_from_config,
),
(
- "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
+ "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config",
fake_create_chat_litellm_from_agent_config,
),
(
- "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
+ "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
fake_create_chat_litellm_from_config,
),
]
diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py
deleted file mode 100644
index fc8280012..000000000
--- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_frame_parity.py
+++ /dev/null
@@ -1,457 +0,0 @@
-"""Byte-for-byte frame parity: legacy monolith vs refactored flows orchestrators.
-
-The agent-content portion of the stream (`text-*`, tool cards, thinking-step
-updates) flows through **shared** code in both implementations
-(`stream_output` -> `EventRelay.relay` -> handlers), so it cannot diverge. The
-only independently-written part is the *orchestrator glue*: the initial frames,
-persistence-handshake frames, error/terminal branches, and final frames.
-
-This module drives BOTH ``stream_new_chat`` implementations (legacy
-``app.tasks.chat.stream_new_chat`` and the refactored
-``app.tasks.chat.streaming.flows``) through the deterministic glue paths and
-asserts the emitted SSE frame sequences are **byte-for-byte identical**. These
-are the paths where divergence could hide; the agent-streaming portion is shared
-and is covered separately.
-
-Determinism is enforced by:
- * freezing ``time.time`` (so ``turn_id = f"{chat_id}:{ms}"`` is stable),
- * a deterministic ``uuid`` sequence for the streaming-service id generators,
- * stubbing every DB/LLM/agent seam (LLM resolution, persistence, connector,
- checkpointer, session) to fixed values.
-
-Cutover gate: when these are green, the live callers can be flipped to the
-flows orchestrators.
-"""
-
-from __future__ import annotations
-
-from types import SimpleNamespace
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-import app.services.new_streaming_service as _nss
-from app.tasks.chat.stream_new_chat import (
- stream_new_chat as old_stream_new_chat,
- stream_resume_chat as old_stream_resume_chat,
-)
-from app.tasks.chat.streaming.flows import (
- stream_new_chat as new_stream_new_chat,
- stream_resume_chat as new_stream_resume_chat,
-)
-
-pytestmark = pytest.mark.unit
-
-_FIXED_EPOCH = 1_700_000_000.0 # -> turn_id ":1700000000000"
-
-
-# --------------------------------------------------------------------------- #
-# Deterministic uuid for the streaming-service id generators
-# --------------------------------------------------------------------------- #
-
-
-class _SeqUUID:
- """Drop-in for the ``uuid`` module used by ``new_streaming_service``.
-
- Only ``uuid4().hex`` is consumed by the id generators. We hand out a
- monotonic, zero-padded hex so two runs that emit the same number of ids in
- the same order produce identical bytes.
- """
-
- def __init__(self) -> None:
- self._n = 0
-
- def reset(self) -> None:
- self._n = 0
-
- def uuid4(self) -> SimpleNamespace:
- self._n += 1
- return SimpleNamespace(hex=f"{self._n:032x}")
-
-
-_SEQ = _SeqUUID()
-
-
-# --------------------------------------------------------------------------- #
-# Fake session: the orchestrator owns ``async_session_maker()``; for the glue
-# paths every real consumer is stubbed, so a no-op session suffices.
-# --------------------------------------------------------------------------- #
-
-
-class _FakeResult:
- """Empty-everything SQLAlchemy ``Result`` stand-in for pre-stream reads."""
-
- def scalars(self) -> "_FakeResult":
- return self
-
- def first(self) -> None:
- return None
-
- def all(self) -> list[Any]:
- return []
-
- def one_or_none(self) -> None:
- return None
-
- def scalar_one_or_none(self) -> None:
- return None
-
- def scalar(self) -> None:
- return None
-
- def fetchall(self) -> list[Any]:
- return []
-
- def __iter__(self):
- return iter(())
-
-
-class _FakeSession:
- async def commit(self) -> None: # pragma: no cover - trivial
- return None
-
- async def rollback(self) -> None: # pragma: no cover - trivial
- return None
-
- async def close(self) -> None: # pragma: no cover - trivial
- return None
-
- def expunge_all(self) -> None: # pragma: no cover - trivial
- return None
-
- def add(self, *a: Any, **k: Any) -> None: # pragma: no cover - trivial
- return None
-
- async def flush(self, *a: Any, **k: Any) -> None: # pragma: no cover
- return None
-
- async def execute(self, *a: Any, **k: Any) -> _FakeResult:
- return _FakeResult()
-
-
-class _FakeConnectorService:
- def __init__(self, *a: Any, **k: Any) -> None:
- pass
-
- async def get_connector_by_type(self, *a: Any, **k: Any) -> None:
- return None
-
-
-def _patch(monkeypatch: pytest.MonkeyPatch, target: str, value: Any) -> None:
- """``setattr`` that tolerates a missing attr (binding may be local-import)."""
- monkeypatch.setattr(target, value, raising=False)
-
-
-def _apply_common(
- monkeypatch: pytest.MonkeyPatch,
- *,
- pin_raises: ValueError | None = None,
- resolved_id: int = -1,
- llm_load_ok: bool = True,
- persist_user_id: int | None = 101,
- persist_assistant_id: int | None = 102,
-) -> None:
- """Patch every glue seam in BOTH implementations to deterministic values."""
- # Time -> stable turn_id and any retry_after_at.
- monkeypatch.setattr("time.time", lambda: _FIXED_EPOCH)
-
- # Deterministic streaming-service ids.
- monkeypatch.setattr(_nss, "uuid", _SEQ)
-
- fake_model = MagicMock(name="scripted_llm")
-
- # --- session ---
- for tgt in (
- "app.tasks.chat.stream_new_chat.async_session_maker",
- "app.tasks.chat.streaming.flows.new_chat.orchestrator.async_session_maker",
- "app.tasks.chat.streaming.flows.resume_chat.orchestrator.async_session_maker",
- ):
- _patch(monkeypatch, tgt, _FakeSession)
-
- # --- connector service ---
- for tgt in (
- "app.tasks.chat.stream_new_chat.ConnectorService",
- "app.tasks.chat.streaming.flows.shared.pre_stream_setup.ConnectorService",
- ):
- _patch(monkeypatch, tgt, _FakeConnectorService)
-
- # --- checkpointer ---
- for tgt in (
- "app.tasks.chat.stream_new_chat.get_checkpointer",
- "app.tasks.chat.streaming.flows.shared.pre_stream_setup.get_checkpointer",
- ):
- _patch(monkeypatch, tgt, AsyncMock(return_value=MagicMock(name="checkpointer")))
-
- # --- agent factory (built but never streamed on glue paths) ---
- # Resume routing awaits ``agent.aget_state`` before persist, so the fake
- # agent exposes async state accessors returning an empty (no-interrupt)
- # snapshot. ``astream_events`` is never reached on glue paths.
- fake_agent = MagicMock(name="agent")
- fake_agent.aget_state = AsyncMock(
- return_value=SimpleNamespace(values={}, tasks=[], interrupts=[], next=())
- )
- fake_agent.aupdate_state = AsyncMock(return_value=None)
- agent_factory = AsyncMock(return_value=fake_agent)
- for tgt in (
- "app.tasks.chat.stream_new_chat.create_multi_agent_chat_deep_agent",
- "app.tasks.chat.streaming.flows.new_chat.orchestrator.create_multi_agent_chat_deep_agent",
- "app.tasks.chat.streaming.flows.resume_chat.orchestrator.create_multi_agent_chat_deep_agent",
- ):
- _patch(monkeypatch, tgt, agent_factory)
-
- # --- LLM resolution (auto-pin) ---
- if pin_raises is not None:
- async def _resolver(*a: Any, **k: Any):
- raise pin_raises
- else:
- async def _resolver(*a: Any, **k: Any):
- return SimpleNamespace(resolved_llm_config_id=resolved_id)
-
- _patch(monkeypatch, "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id", _resolver)
- _patch(monkeypatch, "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id", _resolver)
- _patch(
- monkeypatch,
- "app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id",
- _resolver,
- )
-
- # --- LLM bundle ---
- sentinel_cfg = object() if llm_load_ok else None
- _patch(monkeypatch, "app.tasks.chat.stream_new_chat.load_global_llm_config_by_id", lambda cid: sentinel_cfg)
- _patch(
- monkeypatch,
- "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
- lambda cid: sentinel_cfg,
- )
- _patch(monkeypatch, "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", lambda cfg: fake_model)
- _patch(
- monkeypatch,
- "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
- lambda cfg: fake_model,
- )
- # agent_config := None keeps premium + capability gates inert and identical.
- from app.agents.shared.llm_config import AgentConfig
-
- monkeypatch.setattr(AgentConfig, "from_yaml_config", staticmethod(lambda cfg: None))
-
- # --- persistence ---
- async def _persist_user(*a: Any, **k: Any):
- return persist_user_id
-
- async def _persist_assistant(*a: Any, **k: Any):
- return persist_assistant_id
-
- async def _finalize(*a: Any, **k: Any):
- return None
-
- for mod in (
- "app.tasks.chat.persistence",
- "app.tasks.chat.streaming.flows.new_chat.persistence_spawn",
- ):
- _patch(monkeypatch, f"{mod}.persist_user_turn", _persist_user)
- _patch(monkeypatch, f"{mod}.persist_assistant_shell", _persist_assistant)
- # Resume binds ``persist_assistant_shell`` in its own assistant_shell module.
- _patch(
- monkeypatch,
- "app.tasks.chat.streaming.flows.resume_chat.assistant_shell.persist_assistant_shell",
- _persist_assistant,
- )
- _patch(monkeypatch, "app.tasks.chat.persistence.finalize_assistant_turn", _finalize)
-
- # --- collaboration flags ---
- async def _noop(*a: Any, **k: Any):
- return None
-
- for tgt in (
- "app.tasks.chat.stream_new_chat.set_ai_responding",
- "app.tasks.chat.stream_new_chat.clear_ai_responding",
- "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.set_ai_responding",
- "app.services.chat_session_state_service.set_ai_responding",
- "app.services.chat_session_state_service.clear_ai_responding",
- ):
- _patch(monkeypatch, tgt, _noop)
-
-
-async def _collect(genfunc: Any, **kwargs: Any) -> list[str]:
- frames: list[str] = []
- async for frame in genfunc(**kwargs):
- frames.append(frame)
- return frames
-
-
-async def _run_both(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]:
- """Drive both NEW-chat implementations on identical inputs."""
- _SEQ.reset()
- old = await _collect(old_stream_new_chat, **kwargs)
- _SEQ.reset()
- new = await _collect(new_stream_new_chat, **kwargs)
- return old, new
-
-
-async def _run_both_resume(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]:
- """Drive both RESUME-chat implementations on identical inputs."""
- _SEQ.reset()
- old = await _collect(old_stream_resume_chat, **kwargs)
- _SEQ.reset()
- new = await _collect(new_stream_resume_chat, **kwargs)
- return old, new
-
-
-def _assert_parity(old: list[str], new: list[str]) -> None:
- """Byte-for-byte equality with a readable first-divergence message."""
- for i, (a, b) in enumerate(zip(old, new, strict=False)):
- assert a == b, f"frame[{i}] differs:\n old={a!r}\n new={b!r}"
- assert len(old) == len(new), (
- f"frame count differs: old={len(old)} new={len(new)}\n"
- f" old tail={old[len(new):]!r}\n new tail={new[len(old):]!r}"
- )
- assert old[-1].strip() == "data: [DONE]"
-
-
-# --------------------------------------------------------------------------- #
-# NEW-chat scenarios
-# --------------------------------------------------------------------------- #
-
-_NEW_KW = dict(user_query="hi", search_space_id=1, chat_id=42, user_id=None)
-
-
-@pytest.mark.asyncio
-async def test_auto_pin_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
- """Auto-pin raises -> identical ``[error, DONE]`` from both."""
- _apply_common(monkeypatch, pin_raises=ValueError("no eligible config"))
- old, new = await _run_both(dict(_NEW_KW))
- _assert_parity(old, new)
- assert len(old) == 2
- assert '"errorCode": "SERVER_ERROR"' in old[0]
-
-
-@pytest.mark.asyncio
-async def test_llm_load_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
- """LLM bundle load fails -> identical ``[error, DONE]`` from both."""
- _apply_common(monkeypatch, llm_load_ok=False)
- old, new = await _run_both(dict(_NEW_KW))
- _assert_parity(old, new)
- assert len(old) == 2
- assert '"errorCode": "SERVER_ERROR"' in old[0]
-
-
-@pytest.mark.asyncio
-async def test_persist_user_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
- """User-turn persist returns None.
-
- Exercises the full initial-frame ordering (start, start-step, turn-info,
- turn-status busy), the MESSAGE_PERSIST_FAILED error, and final frames.
- """
- _apply_common(monkeypatch, persist_user_id=None)
- old, new = await _run_both(dict(_NEW_KW))
- _assert_parity(old, new)
- assert '"type": "start"' in old[0]
- assert '"chat_turn_id": "42:1700000000000"' in old[2]
- assert any('"errorCode": "MESSAGE_PERSIST_FAILED"' in f for f in old)
- assert any('"type": "finish"' in f for f in old)
-
-
-@pytest.mark.asyncio
-async def test_persist_assistant_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
- """Assistant-shell persist returns None.
-
- Adds the ``data-user-message-id`` handshake frame ahead of the error.
- """
- _apply_common(monkeypatch, persist_user_id=101, persist_assistant_id=None)
- old, new = await _run_both(dict(_NEW_KW))
- _assert_parity(old, new)
- assert any('"data-user-message-id"' in f and '"message_id": 101' in f for f in old)
- assert any('"errorCode": "MESSAGE_PERSIST_FAILED"' in f for f in old)
-
-
-@pytest.mark.asyncio
-async def test_prestream_exception_parity(monkeypatch: pytest.MonkeyPatch) -> None:
- """A pre-stream failure routes both through the top-level ``except`` path.
-
- Resolver returns a non-int so ``turn_id`` math / downstream use raises after
- the span opens but before initial frames: both must emit the identical
- ``busy -> error -> idle -> finish-step -> finish -> DONE`` terminal sequence.
- """
-
- async def _bad_resolver(*a: Any, **k: Any):
- raise RuntimeError("boom in pre-stream")
-
- _apply_common(monkeypatch)
- # Override the resolver with a non-ValueError so the classified early-error
- # branches don't catch it -> top-level except path.
- for tgt in (
- "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id",
- "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id",
- "app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id",
- ):
- _patch(monkeypatch, tgt, _bad_resolver)
- old, new = await _run_both(dict(_NEW_KW))
- _assert_parity(old, new)
- assert any('"type": "error"' in f for f in old)
-
-
-# --------------------------------------------------------------------------- #
-# RESUME-chat scenarios (no title-generation path -> fully deterministic)
-# --------------------------------------------------------------------------- #
-
-_RESUME_KW = dict(chat_id=42, search_space_id=1, decisions=[], user_id=None)
-
-
-async def _collect_resume_old() -> list[str]:
- _SEQ.reset()
- return await _collect(old_stream_resume_chat, **dict(_RESUME_KW))
-
-
-# NOTE: KNOWN, INTENTIONAL DIVERGENCE (flows fixes a latent monolith bug).
-#
-# In ``stream_resume_chat`` the monolith defines ``_resume_premium_request_id``
-# (line ~2363) AFTER the auto-pin / LLM-load early-return points (~2346 / ~2356).
-# Its ``finally`` block (line ~2918) reads that variable, so a resume turn whose
-# auto-pin raises or whose LLM bundle fails to load crashes with
-# ``UnboundLocalError`` instead of emitting a clean terminal-error frame. The
-# refactored flows orchestrator does NOT have this bug — it emits the proper
-# ``[error, DONE]`` sequence. We assert the divergence explicitly so the cutover
-# is a documented behavior IMPROVEMENT rather than a silent change.
-
-
-@pytest.mark.asyncio
-async def test_resume_auto_pin_failure_flows_fixes_monolith_crash(
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- _apply_common(monkeypatch, pin_raises=ValueError("no eligible config"))
- # Monolith: latent UnboundLocalError in the finally clause.
- with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"):
- await _collect_resume_old()
- # Flows: clean terminal error.
- _SEQ.reset()
- new = await _collect(new_stream_resume_chat, **dict(_RESUME_KW))
- assert len(new) == 2
- assert new[-1].strip() == "data: [DONE]"
- assert '"type": "error"' in new[0]
-
-
-@pytest.mark.asyncio
-async def test_resume_llm_load_failure_flows_fixes_monolith_crash(
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- _apply_common(monkeypatch, llm_load_ok=False)
- with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"):
- await _collect_resume_old()
- _SEQ.reset()
- new = await _collect(new_stream_resume_chat, **dict(_RESUME_KW))
- assert len(new) == 2
- assert new[-1].strip() == "data: [DONE]"
- assert '"type": "error"' in new[0]
-
-
-@pytest.mark.asyncio
-async def test_resume_persist_assistant_failure_parity(
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- """Resume emits NO user-message-id frame; only the assistant handshake path."""
- _apply_common(monkeypatch, persist_assistant_id=None)
- old, new = await _run_both_resume(dict(_RESUME_KW))
- _assert_parity(old, new)
- assert not any('"data-user-message-id"' in f for f in old)
- assert any('"chat_turn_id": "42:1700000000000"' in f for f in old)
diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py
deleted file mode 100644
index 3a9a834f9..000000000
--- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py
+++ /dev/null
@@ -1,584 +0,0 @@
-"""Parity gate for the parallel refactor of ``stream_new_chat.py``.
-
-The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with
-the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over
-atomically. This file pins externally-observable behaviour at module
-boundaries so a divergence between the two trees fails loudly *before* the
-cutover.
-
-What we verify:
-
- 1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from
- the new tree have the same call signature as the originals.
- 2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the
- same outputs as the inline code in the legacy file for representative
- inputs (initial thinking step, image-capability gate, runtime context,
- SSE frame sequences, token-usage frame shape, persistence guards).
- 3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` /
- ``can_recover_provider_rate_limit`` exist and are addressable.
-
-Delete this file along with ``stream_new_chat.py`` once the cutover is done
-(see the parent refactor plan).
-"""
-
-from __future__ import annotations
-
-import asyncio
-import inspect
-from typing import Any
-from unittest.mock import AsyncMock, patch
-
-import pytest
-
-from app.agents.shared.context import SurfSenseContextSchema
-from app.services.new_streaming_service import VercelStreamingService
-from app.tasks.chat.stream_new_chat import (
- stream_new_chat as old_stream_new_chat,
- stream_resume_chat as old_stream_resume_chat,
-)
-from app.tasks.chat.streaming.flows import (
- stream_new_chat as new_stream_new_chat,
- stream_resume_chat as new_stream_resume_chat,
-)
-from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
- build_initial_thinking_step,
-)
-from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
- check_image_input_capability,
-)
-from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
- await_persist_task,
- spawn_persist_assistant_shell_task,
- spawn_persist_user_task,
- spawn_set_ai_responding_bg,
-)
-from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
- build_new_chat_runtime_context,
-)
-from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
- build_resume_chat_runtime_context,
-)
-from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
-from app.tasks.chat.streaming.flows.shared.first_frames import (
- iter_final_frames,
- iter_initial_frames,
-)
-from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
-from app.tasks.chat.streaming.flows.shared.premium_quota import (
- PremiumReservation,
- needs_premium_quota,
-)
-from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
- can_recover_provider_rate_limit,
-)
-
-pytestmark = pytest.mark.unit
-
-
-# --------------------------------------------------------------------- signature
-
-
-def _normalize_annotation(ann: Any) -> str:
- """Compare-friendly form for an annotation.
-
- The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import
- annotations``, so its annotations are evaluated at import time and come
- back as type objects / typing generics. The new tree DOES use it, so its
- annotations are PEP-563 strings.
-
- Both reprs describe the same types — strip the module prefixes / typing
- namespace + the ```` wrapper so we compare the canonical
- declared form.
- """
- if ann is inspect.Signature.empty:
- return ""
- raw = ann if isinstance(ann, str) else repr(ann)
- cleaned = (
- raw.replace("typing.", "")
- .replace("collections.abc.", "")
- .replace("app.db.", "")
- .replace("app.agents.shared.filesystem_selection.", "")
- .replace("app.agents.shared.context.", "")
- )
- # Unwrap ```` → ``int`` (legacy-side type objects).
- if cleaned.startswith(""):
- cleaned = cleaned[len("")]
- return cleaned
-
-
-def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]:
- return [
- (p.name, p.default, _normalize_annotation(p.annotation))
- for p in sig.parameters.values()
- ]
-
-
-def test_stream_new_chat_signature_matches_legacy() -> None:
- old = inspect.signature(old_stream_new_chat)
- new = inspect.signature(new_stream_new_chat)
- assert _normalize_sig(new) == _normalize_sig(old)
- assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
- old.return_annotation
- )
-
-
-def test_stream_resume_chat_signature_matches_legacy() -> None:
- old = inspect.signature(old_stream_resume_chat)
- new = inspect.signature(new_stream_resume_chat)
- assert _normalize_sig(new) == _normalize_sig(old)
- assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
- old.return_annotation
- )
-
-
-def test_orchestrators_are_async_generator_functions() -> None:
- assert inspect.isasyncgenfunction(new_stream_new_chat)
- assert inspect.isasyncgenfunction(new_stream_resume_chat)
-
-
-# ------------------------------------------------------------ initial thinking
-
-
-@pytest.mark.parametrize(
- "user_query, image_urls, expected_title, expected_action",
- [
- ("hello world", None, "Understanding your request", "Processing"),
- (
- "",
- ["data:image/png;base64,AAA"],
- "Understanding your request",
- "Processing",
- ),
- ("", None, "Understanding your request", "Processing"),
- ],
-)
-def test_initial_thinking_step_branches(
- user_query: str,
- image_urls: list[str] | None,
- expected_title: str,
- expected_action: str,
-) -> None:
- step = build_initial_thinking_step(
- user_query=user_query,
- user_image_data_urls=image_urls,
- )
- assert step.step_id == "thinking-1"
- assert step.title == expected_title
- assert len(step.items) == 1
- assert step.items[0].startswith(f"{expected_action}: ")
-
-
-def test_initial_thinking_step_truncates_long_query() -> None:
- long_query = "x" * 200
- step = build_initial_thinking_step(
- user_query=long_query,
- user_image_data_urls=None,
- )
- # 80-char truncation + ellipsis, sandwiched after "Processing: ".
- assert "..." in step.items[0]
- item = step.items[0]
- payload = item[len("Processing: ") :]
- assert payload.startswith("x" * 80) and payload.endswith("...")
-
-
-# ------------------------------------------------------------ capability gate
-
-
-def test_image_capability_passes_without_images() -> None:
- assert (
- check_image_input_capability(user_image_data_urls=None, agent_config=None)
- is None
- )
-
-
-def test_image_capability_passes_when_capability_unknown() -> None:
- """Unknown / unmapped models are not blocked — only models LiteLLM has
- *explicitly* marked text-only trip the gate."""
-
- class _AgentConfig:
- provider = "openrouter"
- model_name = "unknown-mystery-model"
- custom_provider = None
- config_name = "Unknown"
- litellm_params: dict[str, Any] = {}
-
- with patch(
- "app.services.provider_capabilities.is_known_text_only_chat_model",
- return_value=False,
- ):
- assert (
- check_image_input_capability(
- user_image_data_urls=["data:image/png;base64,AAA"],
- agent_config=_AgentConfig(), # type: ignore[arg-type]
- )
- is None
- )
-
-
-def test_image_capability_blocks_known_text_only_models() -> None:
- class _AgentConfig:
- provider = "openai"
- model_name = "gpt-3.5-turbo"
- custom_provider = None
- config_name = "GPT-3.5"
- litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"}
-
- with patch(
- "app.services.provider_capabilities.is_known_text_only_chat_model",
- return_value=True,
- ):
- result = check_image_input_capability(
- user_image_data_urls=["data:image/png;base64,AAA"],
- agent_config=_AgentConfig(), # type: ignore[arg-type]
- )
- assert result is not None
- message, error_code = result
- assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
- assert "GPT-3.5" in message
-
-
-# ---------------------------------------------------------------- runtime ctx
-
-
-def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None:
- """Post-resolve accepted folder ids win over the raw requested ids."""
- ctx = build_new_chat_runtime_context(
- search_space_id=7,
- mentioned_document_ids=[1, 2],
- accepted_folder_ids=[10],
- mentioned_folder_ids=[20, 30],
- mentioned_connector_ids=None,
- mentioned_connectors=None,
- request_id="req",
- turn_id="t1",
- )
- assert isinstance(ctx, SurfSenseContextSchema)
- assert ctx.search_space_id == 7
- assert list(ctx.mentioned_document_ids) == [1, 2]
- assert list(ctx.mentioned_folder_ids) == [10]
- assert ctx.request_id == "req"
- assert ctx.turn_id == "t1"
-
-
-def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None:
- """With no accepted ids, the raw requested folder ids flow through."""
- ctx = build_new_chat_runtime_context(
- search_space_id=7,
- mentioned_document_ids=None,
- accepted_folder_ids=[],
- mentioned_folder_ids=[20, 30],
- mentioned_connector_ids=None,
- mentioned_connectors=None,
- request_id=None,
- turn_id="t2",
- )
- assert list(ctx.mentioned_folder_ids) == [20, 30]
-
-
-def test_new_chat_runtime_context_propagates_connector_mentions() -> None:
- """@-selected connector ids/accounts ride onto the runtime context schema.
-
- Parity with the legacy ``stream_new_chat`` runtime context, which set both
- ``mentioned_connector_ids`` and ``mentioned_connectors`` on the schema.
- """
- connectors = [{"id": 5, "connector_type": "SLACK_CONNECTOR", "title": "acme"}]
- ctx = build_new_chat_runtime_context(
- search_space_id=7,
- mentioned_document_ids=None,
- accepted_folder_ids=[],
- mentioned_folder_ids=None,
- mentioned_connector_ids=[5],
- mentioned_connectors=connectors,
- request_id=None,
- turn_id="t3",
- )
- assert list(ctx.mentioned_connector_ids) == [5]
- assert list(ctx.mentioned_connectors) == connectors
-
-
-def test_resume_chat_runtime_context_empty_mention_lists() -> None:
- ctx = build_resume_chat_runtime_context(
- search_space_id=42, request_id="req-r", turn_id="t-r"
- )
- assert ctx.search_space_id == 42
- assert ctx.request_id == "req-r"
- assert ctx.turn_id == "t-r"
-
-
-# ---------------------------------------------------------------- SSE frames
-
-
-def test_iter_initial_frames_emits_canonical_sequence() -> None:
- svc = VercelStreamingService()
- frames = list(iter_initial_frames(svc, turn_id="42:1700000000000"))
- # Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy).
- assert len(frames) == 4
- assert "42:1700000000000" in frames[2]
- assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3]
-
-
-def test_iter_final_frames_emits_idle_then_finish_done() -> None:
- svc = VercelStreamingService()
- frames = list(iter_final_frames(svc))
- assert len(frames) == 4
- assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0]
-
-
-# ----------------------------------------------------------- token usage frame
-
-
-class _FakeAccumulator:
- """Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads."""
-
- def __init__(self, summary: Any = None) -> None:
- self._summary = summary
- self.calls = [1, 2, 3]
- self.grand_total = 100
- self.total_cost_micros = 50_000
- self.total_prompt_tokens = 60
- self.total_completion_tokens = 40
-
- def per_message_summary(self) -> Any:
- return self._summary
-
- def serialized_calls(self) -> list[Any]:
- return list(self.calls)
-
-
-def test_token_usage_frame_skipped_when_no_summary() -> None:
- svc = VercelStreamingService()
- frames = list(
- iter_token_usage_frame(
- svc,
- accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type]
- log_label="parity-empty",
- )
- )
- assert frames == []
-
-
-def test_token_usage_frame_emitted_when_summary_present() -> None:
- svc = VercelStreamingService()
- frames = list(
- iter_token_usage_frame(
- svc,
- accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type]
- log_label="parity-populated",
- )
- )
- assert len(frames) == 1
- # Field shape on the wire is fixed by the FE; assert each surfaces.
- payload = frames[0]
- for key in (
- '"prompt_tokens":60',
- '"completion_tokens":40',
- '"total_tokens":100',
- '"cost_micros":50000',
- ):
- assert key in payload.replace(" ", "")
-
-
-# ------------------------------------------------------------------ llm_bundle
-
-
-def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None:
- async def _run() -> tuple[Any, Any, str | None]:
- with (
- patch(
- "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
- return_value=None,
- ),
- ):
- return await load_llm_bundle(
- session=AsyncMock(), # type: ignore[arg-type]
- config_id=-1,
- search_space_id=7,
- )
-
- llm, agent_config, error = asyncio.run(_run())
- assert llm is None
- assert agent_config is None
- assert error is not None and "id -1" in error
-
-
-def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None:
- async def _run() -> tuple[Any, Any, str | None]:
- with (
- patch(
- "app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config",
- new=AsyncMock(return_value=None),
- ),
- ):
- return await load_llm_bundle(
- session=AsyncMock(), # type: ignore[arg-type]
- config_id=12,
- search_space_id=7,
- )
-
- llm, agent_config, error = asyncio.run(_run())
- assert llm is None
- assert agent_config is None
- assert error is not None and "id 12" in error
-
-
-# ----------------------------------------------------------------- premium quota
-
-
-def test_needs_premium_quota_requires_user_and_premium_flag() -> None:
- class _AgentConfig:
- is_premium = True
-
- class _NonPremium:
- is_premium = False
-
- assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type]
- assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type]
- assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type]
- assert needs_premium_quota(None, "user-1") is False
-
-
-def test_premium_reservation_dataclass_shape() -> None:
- # Sanity: the dataclass exists and carries the fields the orchestrator uses.
- r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True)
- assert r.request_id == "abc"
- assert r.reserved_micros == 100
- assert r.allowed is True
-
-
-# ----------------------------------------------------------- rate-limit guard
-
-
-@pytest.mark.parametrize(
- "first_event_seen, recovered, requested_id, current_id, expected",
- [
- (False, False, 0, -1, True),
- # Already recovered: no second pass.
- (False, True, 0, -1, False),
- # User explicitly picked a config: don't silently switch.
- (False, False, 5, -1, False),
- # Already on a database-backed (positive) id.
- (False, False, 0, 7, False),
- # User has already seen output: silent rebuild not possible.
- (True, False, 0, -1, False),
- ],
-)
-def test_can_recover_provider_rate_limit_truth_table(
- first_event_seen: bool,
- recovered: bool,
- requested_id: int,
- current_id: int,
- expected: bool,
-) -> None:
- # Use a known rate-limit-shaped exception so the helper's last condition
- # is satisfied; the guard only short-circuits to False when one of the
- # *other* preconditions fails.
- exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}')
- assert (
- can_recover_provider_rate_limit(
- exc,
- first_event_seen=first_event_seen,
- runtime_rate_limit_recovered=recovered,
- requested_llm_config_id=requested_id,
- current_llm_config_id=current_id,
- )
- is expected
- )
-
-
-def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None:
- assert (
- can_recover_provider_rate_limit(
- ValueError("not a rate limit"),
- first_event_seen=False,
- runtime_rate_limit_recovered=False,
- requested_llm_config_id=0,
- current_llm_config_id=-1,
- )
- is False
- )
-
-
-# --------------------------------------------------------- persistence spawn
-
-
-def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
- async def _run() -> set[asyncio.Task]:
- background: set[asyncio.Task] = set()
- spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background)
- return background
-
- bg = asyncio.run(_run())
- assert bg == set()
-
-
-def test_spawn_persist_user_task_registers_and_self_unregisters() -> None:
- async def _run() -> tuple[int, int]:
- background: set[asyncio.Task] = set()
- with patch(
- "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn",
- new=AsyncMock(return_value=99),
- ):
- task = spawn_persist_user_task(
- chat_id=1,
- user_id="u",
- turn_id="t",
- user_query="hi",
- user_image_data_urls=None,
- mentioned_documents=None,
- background_tasks=background,
- )
- size_before_await = len(background)
- result = await asyncio.shield(task)
- # Give the done-callback one event-loop tick to run.
- await asyncio.sleep(0)
- return size_before_await, result # type: ignore[return-value]
-
- size_before, result = asyncio.run(_run())
- assert size_before == 1
- assert result == 99
-
-
-def test_spawn_persist_assistant_shell_task_registers() -> None:
- async def _run() -> int | None:
- background: set[asyncio.Task] = set()
- with patch(
- "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell",
- new=AsyncMock(return_value=42),
- ):
- task = spawn_persist_assistant_shell_task(
- chat_id=1,
- user_id="u",
- turn_id="t",
- background_tasks=background,
- )
- return await asyncio.shield(task)
-
- assert asyncio.run(_run()) == 42
-
-
-def test_await_persist_task_returns_none_on_failure() -> None:
- async def _run() -> int | None:
- async def _boom() -> int:
- raise RuntimeError("DB down")
-
- task = asyncio.create_task(_boom())
- return await await_persist_task(
- task,
- chat_id=1,
- turn_id="t",
- log_label="parity-failure",
- )
-
- assert asyncio.run(_run()) is None
-
-
-def test_await_persist_task_returns_none_for_none_input() -> None:
- async def _run() -> int | None:
- return await await_persist_task(
- None,
- chat_id=1,
- turn_id="t",
- log_label="parity-none",
- )
-
- assert asyncio.run(_run()) is None
diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py
deleted file mode 100644
index a4bd1d56c..000000000
--- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py
+++ /dev/null
@@ -1,240 +0,0 @@
-"""Pin Stage 1 extractions as faithful copies of the old helpers.
-
-Extractions under ``app.tasks.chat.streaming`` are compared to
-``app.tasks.chat.stream_new_chat`` helpers.
-For each Stage 1 extraction we assert the new function returns the same
-output as the old one for a representative input set. The moment the
-two diverge - intentionally or otherwise - this file fails loudly so
-the divergence is reviewed rather than shipped silently.
-"""
-
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass, field
-from typing import Any
-
-import pytest
-
-from app.agents.shared.errors import BusyError
-from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel
-from app.tasks.chat.stream_new_chat import (
- _classify_stream_exception as old_classify,
- _emit_stream_terminal_error as old_emit_terminal_error,
- _extract_chunk_parts as old_extract_chunk_parts,
- _extract_resolved_file_path as old_extract_resolved_file_path,
- _tool_output_has_error as old_tool_output_has_error,
- _tool_output_to_text as old_tool_output_to_text,
-)
-from app.tasks.chat.streaming.errors.classifier import (
- classify_stream_exception as new_classify,
-)
-from app.tasks.chat.streaming.errors.emitter import (
- emit_stream_terminal_error as new_emit_terminal_error,
-)
-from app.tasks.chat.streaming.helpers.chunk_parts import (
- extract_chunk_parts as new_extract_chunk_parts,
-)
-from app.tasks.chat.streaming.helpers.tool_output import (
- extract_resolved_file_path as new_extract_resolved_file_path,
- tool_output_has_error as new_tool_output_has_error,
- tool_output_to_text as new_tool_output_to_text,
-)
-
-pytestmark = pytest.mark.unit
-
-
-# ---------------------------------------------------------------- chunk parts
-
-
-@dataclass
-class _Chunk:
- content: Any = ""
- additional_kwargs: dict[str, Any] = field(default_factory=dict)
- tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
-
-
-_CHUNK_CASES: list[Any] = [
- None,
- _Chunk(content=""),
- _Chunk(content="hello"),
- _Chunk(content=42), # invalid type, defensively coerced to empty
- _Chunk(
- content=[
- {"type": "text", "text": "Hello "},
- {"type": "text", "text": "world"},
- ]
- ),
- _Chunk(
- content=[
- {"type": "reasoning", "reasoning": "hmm "},
- {"type": "reasoning", "text": "still"},
- {"type": "text", "text": "answer"},
- ]
- ),
- _Chunk(
- content=[
- {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"},
- {"type": "tool_use", "id": "c2", "name": "y"},
- {"type": "image_url", "url": "ignored"},
- ]
- ),
- _Chunk(
- content="visible",
- additional_kwargs={"reasoning_content": "private"},
- ),
- _Chunk(
- tool_call_chunks=[
- {"id": None, "name": None, "args": '{"a":1}', "index": 0},
- {"id": "c", "name": "n", "args": "}", "index": 0},
- ]
- ),
- _Chunk(
- content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
- tool_call_chunks=[{"id": "from-attr", "name": "y"}],
- ),
-]
-
-
-@pytest.mark.parametrize("chunk", _CHUNK_CASES)
-def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
- assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
-
-
-# ----------------------------------------------------------- error classifier
-
-
-def _classify_cases() -> list[Exception]:
- """Inputs that the FE depends on being mapped to specific error codes."""
- return [
- Exception("totally generic error"),
- Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
- Exception(
- 'OpenrouterException - {"error":{"message":"Provider returned error",'
- '"code":429}}'
- ),
- BusyError(request_id="thread-busy-parity"),
- Exception("Thread is busy with another request"),
- ]
-
-
-@pytest.mark.parametrize("exc", _classify_cases())
-def test_classify_stream_exception_matches_old_implementation(
- exc: Exception,
-) -> None:
- new = new_classify(exc, flow_label="parity-test")
- old = old_classify(exc, flow_label="parity-test")
- # Strip the wall-clock retry timestamp before comparing — both
- # implementations call ``time.time()`` independently and the call
- # order is enough to differ by 1 ms in practice. Every other field
- # in the tuple must match exactly.
- new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5]
- old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5]
- if isinstance(new_extra, dict) and isinstance(old_extra, dict):
- new_extra.pop("retry_after_at", None)
- old_extra.pop("retry_after_at", None)
- assert new[:5] == old[:5]
- assert new_extra == old_extra
-
-
-def test_classify_turn_cancelling_branch_parity() -> None:
- """The TURN_CANCELLING branch reads cancel state for the busy thread id;
- both implementations must agree on retry-window semantics, not just the
- plain THREAD_BUSY code."""
- thread_id = "parity-cancelling-thread"
- reset_cancel(thread_id)
- request_cancel(thread_id)
- exc = BusyError(request_id=thread_id)
- new = new_classify(exc, flow_label="parity-test")
- old = old_classify(exc, flow_label="parity-test")
- assert new[0] == old[0] == "thread_busy"
- assert new[1] == old[1] == "TURN_CANCELLING"
- assert isinstance(new[5], dict) and isinstance(old[5], dict)
- assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"]
-
-
-# ------------------------------------------------------------ terminal emitter
-
-
-class _FakeStreamingService:
- """Duck-types ``format_error`` for both old and new emitters."""
-
- def __init__(self) -> None:
- self.calls: list[dict[str, Any]] = []
-
- def format_error(
- self, message: str, *, error_code: str, extra: dict[str, Any] | None = None
- ) -> str:
- self.calls.append(
- {"message": message, "error_code": error_code, "extra": extra}
- )
- return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
-
-
-def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
- """The new emitter must produce the same SSE frame and log the same
- structured payload as the old one for the same arguments."""
- args: dict[str, Any] = {
- "flow": "new",
- "request_id": "req-parity",
- "thread_id": 7,
- "search_space_id": 9,
- "user_id": "user-parity",
- "message": "boom",
- "error_kind": "server_error",
- "error_code": "SERVER_ERROR",
- "severity": "error",
- "is_expected": False,
- "extra": {"foo": "bar"},
- }
-
- new_svc = _FakeStreamingService()
- old_svc = _FakeStreamingService()
-
- with caplog.at_level(logging.ERROR):
- new_frame = new_emit_terminal_error(streaming_service=new_svc, **args)
- old_frame = old_emit_terminal_error(streaming_service=old_svc, **args)
-
- assert new_frame == old_frame
- assert new_svc.calls == old_svc.calls
- chat_error_records = [
- r for r in caplog.records if "[chat_stream_error]" in r.message
- ]
- # One log line per emit call (two emits -> two records).
- assert len(chat_error_records) == 2
-
-
-# ---------------------------------------------------------------- tool output
-
-
-def test_tool_output_helpers_match_old_implementation() -> None:
- samples: list[Any] = [
- {"result": "ok"},
- {"error": "bad"},
- {"result": "Error: x"},
- "Error: plain",
- "fine",
- {"nested": {"a": 1}},
- ]
- for s in samples:
- assert new_tool_output_to_text(s) == old_tool_output_to_text(s)
- assert new_tool_output_has_error(s) == old_tool_output_has_error(s)
-
- assert new_extract_resolved_file_path(
- tool_name="write_file",
- tool_output={"path": " /tmp/x "},
- tool_input=None,
- ) == old_extract_resolved_file_path(
- tool_name="write_file",
- tool_output={"path": " /tmp/x "},
- tool_input=None,
- )
- assert new_extract_resolved_file_path(
- tool_name="write_file",
- tool_output={},
- tool_input={"file_path": " /fallback "},
- ) == old_extract_resolved_file_path(
- tool_name="write_file",
- tool_output={},
- tool_input={"file_path": " /fallback "},
- )
diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py
deleted file mode 100644
index 3ee1ab622..000000000
--- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py
+++ /dev/null
@@ -1,241 +0,0 @@
-"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events)."""
-
-from __future__ import annotations
-
-from typing import Any
-from unittest.mock import MagicMock
-
-import pytest
-
-from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match
-from app.tasks.chat.streaming.handlers.custom_events import (
- handle_action_log,
- handle_action_log_updated,
- handle_document_created,
- handle_report_progress,
-)
-from app.tasks.chat.streaming.helpers.tool_call_matching import (
- match_buffered_langchain_tool_call_id as new_legacy_match,
-)
-from app.tasks.chat.streaming.relay.state import AgentEventRelayState
-from app.tasks.chat.streaming.relay.thinking_step_completion import (
- complete_active_thinking_step,
-)
-from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
-
-pytestmark = pytest.mark.unit
-
-
-def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
- return [dict(x) for x in raw]
-
-
-def test_legacy_tool_call_match_matches_old_implementation() -> None:
- cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [
- (
- [
- {"name": "write_file", "id": "lc-a"},
- {"name": "other", "id": "lc-b"},
- ],
- "write_file",
- "run-1",
- {},
- ),
- (
- [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}],
- "write_file",
- "run-2",
- {},
- ),
- ([{"name": "no_id"}], "write_file", "run-3", {}),
- ]
- for chunks_template, tool_name, run_id, lc_map_seed in cases:
- old_chunks = _copy_chunk_buffer(chunks_template)
- new_chunks = _copy_chunk_buffer(chunks_template)
- old_map = dict(lc_map_seed)
- new_map = dict(lc_map_seed)
- old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map)
- new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map)
- assert new_out == old_out
- assert new_chunks == old_chunks
- assert new_map == old_map
-
-
-def test_emit_thinking_step_frame_invokes_builder_before_service() -> None:
- order: list[str] = []
- builder = MagicMock()
-
- def on_ts(*args: Any, **kwargs: Any) -> None:
- order.append("builder")
-
- builder.on_thinking_step.side_effect = on_ts
-
- svc = MagicMock()
-
- def fmt(**kwargs: Any) -> str:
- order.append("service")
- return "frame"
-
- svc.format_thinking_step.side_effect = fmt
-
- out = emit_thinking_step_frame(
- streaming_service=svc,
- content_builder=builder,
- step_id="thinking-1",
- title="Working",
- status="in_progress",
- items=["a"],
- )
- assert out == "frame"
- assert order == ["builder", "service"]
- builder.on_thinking_step.assert_called_once()
- svc.format_thinking_step.assert_called_once()
-
-
-def test_emit_thinking_step_frame_skips_builder_when_none() -> None:
- svc = MagicMock(return_value="x")
- svc.format_thinking_step.return_value = "frame"
- assert (
- emit_thinking_step_frame(
- streaming_service=svc,
- content_builder=None,
- step_id="s",
- title="t",
- )
- == "frame"
- )
- svc.format_thinking_step.assert_called_once()
-
-
-def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
- svc = MagicMock()
- svc.format_thinking_step.return_value = "done-frame"
- completed: set[str] = set()
- relay_state = AgentEventRelayState.for_invocation()
-
- frame, new_id = complete_active_thinking_step(
- state=relay_state,
- streaming_service=svc,
- content_builder=None,
- last_active_step_id="thinking-1",
- last_active_step_title="T",
- last_active_step_items=["x"],
- completed_step_ids=completed,
- )
- assert frame == "done-frame"
- assert new_id is None
- assert "thinking-1" in completed
-
- frame2, id2 = complete_active_thinking_step(
- state=relay_state,
- streaming_service=svc,
- content_builder=None,
- last_active_step_id="thinking-1",
- last_active_step_title="T",
- last_active_step_items=[],
- completed_step_ids=completed,
- )
- assert frame2 is None
- assert id2 == "thinking-1"
-
-
-def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
- s0 = AgentEventRelayState.for_invocation()
- assert s0.thinking_step_counter == 0
- assert s0.last_active_step_id is None
-
- s1 = AgentEventRelayState.for_invocation(
- initial_step_id="thinking-resume-1",
- initial_step_title="Inherited",
- initial_step_items=["Topic: X"],
- )
- assert s1.thinking_step_counter == 1
- assert s1.last_active_step_id == "thinking-resume-1"
- assert s1.next_thinking_step_id("thinking") == "thinking-2"
-
-
-@pytest.mark.parametrize(
- ("phase", "message", "start_items", "expected_tail"),
- [
- (
- "revising_section",
- "progress line",
- ["Topic: Foo", "Modifying bar", "stale..."],
- ["Topic: Foo", "Modifying bar", "progress line"],
- ),
- (
- "other",
- "phase msg",
- ["Topic: Foo", "old line"],
- ["Topic: Foo", "phase msg"],
- ),
- ],
-)
-def test_report_progress_items_match_reference(
- phase: str,
- message: str,
- start_items: list[str],
- expected_tail: list[str],
-) -> None:
- svc = MagicMock()
- svc.format_thinking_step.return_value = "sse"
-
- items = list(start_items)
- frame, new_items = handle_report_progress(
- {"message": message, "phase": phase},
- last_active_step_id="step-1",
- last_active_step_title="Report",
- last_active_step_items=items,
- streaming_service=svc,
- content_builder=None,
- )
- assert frame == "sse"
- assert new_items == expected_tail
- kwargs = svc.format_thinking_step.call_args.kwargs
- assert kwargs["items"] == expected_tail
-
-
-def test_report_progress_noop_when_missing_message_or_step() -> None:
- svc = MagicMock()
- items = ["Topic: A"]
- f1, i1 = handle_report_progress(
- {"message": "", "phase": "x"},
- last_active_step_id="s",
- last_active_step_title="t",
- last_active_step_items=items,
- streaming_service=svc,
- content_builder=None,
- )
- assert f1 is None and i1 is items
-
- f2, i2 = handle_report_progress(
- {"message": "m", "phase": "x"},
- last_active_step_id=None,
- last_active_step_title="t",
- last_active_step_items=items,
- streaming_service=svc,
- content_builder=None,
- )
- assert f2 is None and i2 is items
-
-
-def test_document_action_handlers_match_format_data_guards() -> None:
- svc = MagicMock()
- svc.format_data.return_value = "data-frame"
-
- assert handle_document_created({}, streaming_service=svc) is None
- assert handle_document_created({"id": 0}, streaming_service=svc) is None
- handle_document_created({"id": 42, "title": "x"}, streaming_service=svc)
- svc.format_data.assert_called_with(
- "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}}
- )
-
- svc.reset_mock()
- assert handle_action_log({"id": None}, streaming_service=svc) is None
- handle_action_log({"id": 1}, streaming_service=svc)
- svc.format_data.assert_called_once_with("action-log", {"id": 1})
-
- svc.reset_mock()
- assert handle_action_log_updated({"id": None}, streaming_service=svc) is None
- handle_action_log_updated({"id": 2}, streaming_service=svc)
- svc.format_data.assert_called_once_with("action-log-updated", {"id": 2})
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py
index 1263a5fe1..0f154f1dc 100644
--- a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py
+++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py
@@ -1,4 +1,4 @@
-"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
+"""Unit tests for ``streaming.helpers.chunk_parts.extract_chunk_parts``.
Earlier versions only handled ``isinstance(chunk.content, str)`` and
silently dropped every other shape (Anthropic typed-block lists,
@@ -14,7 +14,9 @@ from typing import Any
import pytest
-from app.tasks.chat.stream_new_chat import _extract_chunk_parts
+from app.tasks.chat.streaming.helpers.chunk_parts import (
+ extract_chunk_parts as _extract_chunk_parts,
+)
@dataclass
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py
index 50f7b8070..7f8285e98 100644
--- a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py
+++ b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py
@@ -7,7 +7,7 @@ constructs a fresh :class:`AgentEventRelayState` with
``thinking_step_counter=0``), React renders sibling timeline rows with the
same key — the warning the user reported in production.
-The contract this module pins: each ``_stream_agent_events`` invocation must
+The contract this module pins: each ``stream_agent_events`` invocation must
receive a ``step_prefix`` that is unique within the thread (we salt with the
per-turn ``turn_id``), so the resulting step IDs across consecutive turns
are always disjoint.
@@ -23,10 +23,12 @@ from typing import Any
import pytest
from app.services.new_streaming_service import VercelStreamingService
-from app.tasks.chat.stream_new_chat import (
- StreamResult,
- _resume_step_prefix,
- _stream_agent_events,
+from app.tasks.chat.streaming.agent.event_loop import (
+ stream_agent_events as _stream_agent_events,
+)
+from app.tasks.chat.streaming.shared.stream_result import StreamResult
+from app.tasks.chat.streaming.shared.utils import (
+ resume_step_prefix as _resume_step_prefix,
)
pytestmark = pytest.mark.unit
diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
index ada32d168..f3f28eb1c 100644
--- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
+++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py
@@ -1,6 +1,6 @@
"""Unit tests for live tool-call argument streaming.
-Pins the wire format that ``_stream_agent_events`` emits:
+Pins the wire format that ``stream_agent_events`` emits:
``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` →
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
when the model streams indexed chunks.
@@ -20,11 +20,13 @@ from typing import Any
import pytest
from app.services.new_streaming_service import VercelStreamingService
-from app.tasks.chat.stream_new_chat import (
- StreamResult,
- _legacy_match_lc_id,
- _stream_agent_events,
+from app.tasks.chat.streaming.agent.event_loop import (
+ stream_agent_events as _stream_agent_events,
)
+from app.tasks.chat.streaming.helpers.tool_call_matching import (
+ match_buffered_langchain_tool_call_id as _legacy_match_lc_id,
+)
+from app.tasks.chat.streaming.shared.stream_result import StreamResult
pytestmark = pytest.mark.unit
diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
deleted file mode 100644
index 8ff576e2d..000000000
--- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py
+++ /dev/null
@@ -1,438 +0,0 @@
-import inspect
-import json
-import logging
-import re
-from pathlib import Path
-
-import pytest
-
-import app.tasks.chat.stream_new_chat as stream_new_chat_module
-from app.agents.shared.errors import BusyError
-from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel
-from app.tasks.chat.stream_new_chat import (
- StreamResult,
- _classify_stream_exception,
- _contract_enforcement_active,
- _evaluate_file_contract_outcome,
- _extract_resolved_file_path,
- _log_chat_stream_error,
- _tool_output_has_error,
-)
-
-pytestmark = pytest.mark.unit
-
-
-def test_tool_output_error_detection():
- assert _tool_output_has_error("Error: failed to write file")
- assert _tool_output_has_error({"error": "boom"})
- assert _tool_output_has_error({"result": "Error: disk is full"})
- assert not _tool_output_has_error({"result": "Updated file /notes.md"})
-
-
-def test_extract_resolved_file_path_prefers_structured_path():
- assert (
- _extract_resolved_file_path(
- tool_name="write_file",
- tool_output={"status": "completed", "path": "/docs/note.md"},
- tool_input=None,
- )
- == "/docs/note.md"
- )
-
-
-def test_extract_resolved_file_path_falls_back_to_tool_input():
- assert (
- _extract_resolved_file_path(
- tool_name="edit_file",
- tool_output={"status": "completed", "result": "updated"},
- tool_input={"file_path": "/docs/edited.md"},
- )
- == "/docs/edited.md"
- )
-
-
-def test_extract_resolved_file_path_does_not_parse_result_text():
- assert (
- _extract_resolved_file_path(
- tool_name="write_file",
- tool_output={"result": "Updated file /docs/from-text.md"},
- tool_input=None,
- )
- is None
- )
-
-
-def test_file_write_contract_outcome_reasons():
- result = StreamResult(intent_detected="file_write")
- passed, reason = _evaluate_file_contract_outcome(result)
- assert not passed
- assert reason == "no_write_attempt"
-
- result.write_attempted = True
- passed, reason = _evaluate_file_contract_outcome(result)
- assert not passed
- assert reason == "write_failed"
-
- result.write_succeeded = True
- passed, reason = _evaluate_file_contract_outcome(result)
- assert not passed
- assert reason == "verification_failed"
-
- result.verification_succeeded = True
- passed, reason = _evaluate_file_contract_outcome(result)
- assert passed
- assert reason == ""
-
-
-def test_contract_enforcement_local_only():
- result = StreamResult(filesystem_mode="desktop_local_folder")
- assert _contract_enforcement_active(result)
-
- result.filesystem_mode = "cloud"
- assert not _contract_enforcement_active(result)
-
-
-def _extract_chat_stream_payload(record_message: str) -> dict:
- prefix = "[chat_stream_error] "
- assert record_message.startswith(prefix)
- return json.loads(record_message[len(prefix) :])
-
-
-def test_unified_chat_stream_error_log_schema(caplog):
- with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
- _log_chat_stream_error(
- flow="new",
- error_kind="server_error",
- error_code="SERVER_ERROR",
- severity="warn",
- is_expected=False,
- request_id="req-123",
- thread_id=101,
- search_space_id=202,
- user_id="user-1",
- message="Error during chat: boom",
- )
-
- record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
- payload = _extract_chat_stream_payload(record.message)
-
- required_keys = {
- "event",
- "flow",
- "error_kind",
- "error_code",
- "severity",
- "is_expected",
- "request_id",
- "thread_id",
- "search_space_id",
- "user_id",
- "message",
- }
- assert required_keys.issubset(payload.keys())
- assert payload["event"] == "chat_stream_error"
- assert payload["flow"] == "new"
- assert payload["error_code"] == "SERVER_ERROR"
-
-
-def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
- with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
- _log_chat_stream_error(
- flow="resume",
- error_kind="premium_quota_exhausted",
- error_code="PREMIUM_QUOTA_EXHAUSTED",
- severity="info",
- is_expected=True,
- request_id="req-premium",
- thread_id=303,
- search_space_id=404,
- user_id="user-2",
- message="Buy more tokens to continue with this model, or switch to a free model",
- extra={"auto_fallback": False},
- )
-
- record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
- payload = _extract_chat_stream_payload(record.message)
- assert payload["event"] == "chat_stream_error"
- assert payload["error_kind"] == "premium_quota_exhausted"
- assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
- assert payload["flow"] == "resume"
- assert payload["is_expected"] is True
- assert payload["auto_fallback"] is False
-
-
-def test_stream_error_emission_keeps_machine_error_codes():
- source = inspect.getsource(stream_new_chat_module)
- format_error_calls = re.findall(r"format_error\(", source)
- emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
-
- # All stream paths should route through one shared terminal error emitter.
- assert len(format_error_calls) == 1
- assert {
- "PREMIUM_QUOTA_EXHAUSTED",
- "SERVER_ERROR",
- }.issubset(emitted_error_codes)
- assert 'flow: Literal["new", "regenerate"] = "new"' in source
- assert "_emit_stream_terminal_error" in source
- assert "flow=flow" in source
- assert 'flow="resume"' in source
-
-
-def test_stream_exception_classifies_rate_limited():
- exc = Exception(
- '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
- )
- kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
- exc, flow_label="chat"
- )
- assert kind == "rate_limited"
- assert code == "RATE_LIMITED"
- assert severity == "warn"
- assert is_expected is True
- assert "temporarily rate-limited" in user_message
- assert extra is None
-
-
-def test_stream_exception_classifies_openrouter_429_payload():
- exc = Exception(
- 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
- '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
- )
- kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
- exc, flow_label="chat"
- )
- assert kind == "rate_limited"
- assert code == "RATE_LIMITED"
- assert severity == "warn"
- assert is_expected is True
- assert "temporarily rate-limited" in user_message
- assert extra is None
-
-
-def test_stream_exception_classifies_thread_busy():
- exc = BusyError(request_id="thread-123")
- kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
- exc, flow_label="chat"
- )
- assert kind == "thread_busy"
- assert code == "THREAD_BUSY"
- assert severity == "warn"
- assert is_expected is True
- assert "still finishing for this thread" in user_message
- assert extra is None
-
-
-def test_stream_exception_classifies_thread_busy_from_message():
- exc = Exception("Thread is busy with another request")
- kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
- exc, flow_label="chat"
- )
- assert kind == "thread_busy"
- assert code == "THREAD_BUSY"
- assert severity == "warn"
- assert is_expected is True
- assert "still finishing for this thread" in user_message
- assert extra is None
-
-
-def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
- thread_id = "thread-cancelling-1"
- reset_cancel(thread_id)
- request_cancel(thread_id)
- exc = BusyError(request_id=thread_id)
- kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
- exc, flow_label="chat"
- )
- assert kind == "thread_busy"
- assert code == "TURN_CANCELLING"
- assert severity == "info"
- assert is_expected is True
- assert "stopping" in user_message
- assert isinstance(extra, dict)
- assert "retry_after_ms" in extra
-
-
-def test_premium_classification_is_error_code_driven():
- classifier_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/lib/chat/chat-error-classifier.ts"
- )
- source = classifier_path.read_text(encoding="utf-8")
-
- assert "PREMIUM_KEYWORDS" not in source
- assert "RATE_LIMIT_KEYWORDS" not in source
- assert "normalized.includes(" not in source
- assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
-
-
-def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
- page_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
- )
- source = page_path.read_text(encoding="utf-8")
-
- assert "onPreAcceptFailure?: () => Promise;" in source
- assert "if (!accepted) {" in source
- assert "await onPreAcceptFailure?.();" in source
- assert "await onAcceptedStreamError?.();" in source
- assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
- assert "setMessageDocumentsMap((prev) => {" in source
-
-
-def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
- user_message_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/components/assistant-ui/user-message.tsx"
- )
- source = user_message_path.read_text(encoding="utf-8")
-
- assert "Not sent. Edit and retry." not in source
- assert "failed_pre_accept" not in source
-
-
-def test_network_send_failures_use_unified_retry_toast_message():
- classifier_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/lib/chat/chat-error-classifier.ts"
- )
- classifier_source = classifier_path.read_text(encoding="utf-8")
- request_errors_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/lib/chat/chat-request-errors.ts"
- )
- request_errors_source = request_errors_path.read_text(encoding="utf-8")
-
- assert '"send_failed_pre_accept"' in classifier_source
- assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
- assert 'errorCode === "TURN_CANCELLING"' in classifier_source
- assert "if (withCode.code) return withCode.code;" in classifier_source
- assert 'userMessage: "Message not sent. Please retry."' in classifier_source
- assert 'userMessage: "Connection issue. Please try again."' in classifier_source
- assert "const passthroughCodes = new Set([" in request_errors_source
- assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
- assert '"THREAD_BUSY"' in request_errors_source
- assert '"TURN_CANCELLING"' in request_errors_source
- assert '"AUTH_EXPIRED"' in request_errors_source
- assert '"UNAUTHORIZED"' in request_errors_source
- assert '"RATE_LIMITED"' in request_errors_source
- assert '"NETWORK_ERROR"' in request_errors_source
- assert '"STREAM_PARSE_ERROR"' in request_errors_source
- assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
- assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
- assert '"SERVER_ERROR"' in request_errors_source
- assert "passthroughCodes.has(existingCode)" in request_errors_source
- assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
- assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
- assert "Failed to start chat. Please try again." not in classifier_source
-
-
-def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
- page_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
- )
- source = page_path.read_text(encoding="utf-8")
-
- # Each flow tracks accepted boundary and passes it into shared terminal handling.
- # The acceptance boundary is still meaningful post-refactor: it gates
- # local-state cleanup (onPreAcceptFailure path) and lets the shared
- # terminal handler distinguish pre-accept aborts from in-stream errors.
- assert "let newAccepted = false;" in source
- assert "let resumeAccepted = false;" in source
- assert "let regenerateAccepted = false;" in source
- assert "accepted: newAccepted," in source
- assert "accepted: resumeAccepted," in source
- assert "accepted: regenerateAccepted," in source
-
- # NOTE: The FE-side persistence guards previously asserted here
- # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
- # "if (newAccepted && !userPersisted) {") have been intentionally
- # removed by the SSE-based message-id handshake refactor. Persistence
- # is now server-authoritative: persist_user_turn / persist_assistant_shell
- # run inside stream_new_chat / stream_resume_chat unconditionally and
- # the FE consumes data-user-message-id / data-assistant-message-id
- # SSE events to learn the canonical primary keys. There is therefore
- # no FE call-site to guard, and the shared terminal handler relies
- # purely on the `accepted` field above (forwarded to onAbort /
- # onAcceptedStreamError) to drive UI cleanup. See
- # tests/integration/chat/test_message_id_sse.py for the new
- # cross-tier ID coherence guarantees.
-
- # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
- # of the persistence refactor and must still exist on every
- # start-stream fetch.
- assert "const fetchWithTurnCancellingRetry = useCallback(" in source
- assert "computeFallbackTurnCancellingRetryDelay" in source
- assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
- assert 'withMeta.errorCode === "THREAD_BUSY"' in source
- assert "await fetchWithTurnCancellingRetry(() =>" in source
-
-
-def test_cancel_active_turn_route_contract_exists():
- routes_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_backend/app/routes/new_chat_routes.py"
- )
- source = routes_path.read_text(encoding="utf-8")
-
- assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
- assert "response_model=CancelActiveTurnResponse" in source
- assert 'status="cancelling",' in source
- assert 'error_code="TURN_CANCELLING",' in source
- assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
- assert "retry_after_at=" in source
- assert 'status="idle",' in source
- assert 'error_code="NO_ACTIVE_TURN",' in source
-
-
-def test_turn_status_route_contract_exists():
- routes_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_backend/app/routes/new_chat_routes.py"
- )
- source = routes_path.read_text(encoding="utf-8")
-
- assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
- assert "response_model=TurnStatusResponse" in source
- assert "_build_turn_status_payload(thread_id)" in source
- assert "Permission.CHATS_READ.value" in source
- assert "_raise_if_thread_busy_for_start(" in source
-
-
-def test_turn_cancelling_retry_policy_contract_exists():
- routes_path = (
- Path(__file__).resolve().parents[3]
- / "surfsense_backend/app/routes/new_chat_routes.py"
- )
- source = routes_path.read_text(encoding="utf-8")
-
- assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
- assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
- assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
- assert "def _compute_turn_cancelling_retry_delay(" in source
- assert "retry-after-ms" in source
- assert '"Retry-After"' in source
- assert '"errorCode": "TURN_CANCELLING"' in source
-
-
-def test_turn_status_sse_contract_exists():
- stream_source = (
- Path(__file__).resolve().parents[3]
- / "surfsense_backend/app/tasks/chat/stream_new_chat.py"
- ).read_text(encoding="utf-8")
- state_source = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/lib/chat/streaming-state.ts"
- ).read_text(encoding="utf-8")
- pipeline_source = (
- Path(__file__).resolve().parents[3]
- / "surfsense_web/lib/chat/stream-pipeline.ts"
- ).read_text(encoding="utf-8")
-
- assert '"turn-status"' in stream_source
- assert '"status": "busy"' in stream_source
- assert '"status": "idle"' in stream_source
- assert 'type: "data-turn-status"' in state_source
- assert 'case "data-turn-status":' in pipeline_source
- assert "end_turn(str(chat_id))" in stream_source