diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py new file mode 100644 index 000000000..1ea549975 --- /dev/null +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -0,0 +1,65 @@ +"""138_add_thread_auto_model_pinning_fields + +Revision ID: 138 +Revises: 137 +Create Date: 2026-04-30 + +Add thread-level fields to persist Auto (Fastest) model pinning metadata: +- pinned_llm_config_id: concrete resolved config id used for this thread +- pinned_auto_mode: auto policy identifier (currently "auto_fastest") +- pinned_at: timestamp when the pin was created/refreshed +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "138" +down_revision: str | None = "137" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER" + ) + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)" + ) + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE" + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id " + "ON new_chat_threads (pinned_llm_config_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode " + "ON new_chat_threads (pinned_auto_mode)" + ) + + +def downgrade() -> None: + op.execute( + "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode" + ) + op.execute( + "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id" + ) + + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at" + ) + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode" + ) + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id" + ) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 91d19fb4f..ca3334f8b 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -638,6 +638,13 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) + # Auto model pinning metadata: + # - pinned_llm_config_id stores the concrete resolved model config id. + # - pinned_auto_mode indicates which auto policy produced the pin. + # This allows Auto (Fastest) to resolve once per thread and stay stable. + pinned_llm_config_id = Column(Integer, nullable=True, index=True) + pinned_auto_mode = Column(String(32), nullable=True, index=True) + pinned_at = Column(TIMESTAMP(timezone=True), nullable=True) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 26c72bd45..e04cce1b5 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1924,6 +1924,7 @@ async def regenerate_response( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=regenerate_image_urls or None, + flow="regenerate", ): yield chunk streaming_completed = True diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 828137518..7944e7d66 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -3,7 +3,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException from langchain_core.messages import HumanMessage from pydantic import BaseModel as PydanticBaseModel -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem from app.config import config from app.db import ( ImageGenerationConfig, + NewChatThread, NewLLMConfig, Permission, SearchSpace, @@ -790,9 +791,31 @@ async def update_llm_preferences( # Update preferences update_data = preferences.model_dump(exclude_unset=True) + previous_agent_llm_id = search_space.agent_llm_id for key, value in update_data.items(): setattr(search_space, key, value) + agent_llm_changed = ( + "agent_llm_id" in update_data + and update_data["agent_llm_id"] != previous_agent_llm_id + ) + if agent_llm_changed: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values( + pinned_llm_config_id=None, + pinned_auto_mode=None, + pinned_at=None, + ) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", + search_space_id, + previous_agent_llm_id, + update_data["agent_llm_id"], + ) + await session.commit() await session.refresh(search_space) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py new file mode 100644 index 000000000..6bdb60f57 --- /dev/null +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -0,0 +1,218 @@ +"""Resolve and persist Auto (Fastest) model pins per chat thread. + +Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we +resolve that virtual mode to one concrete global LLM config exactly once and +persist the chosen config id on ``new_chat_threads`` so subsequent turns are +stable. +""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import NewChatThread +from app.services.token_quota_service import TokenQuotaService + +logger = logging.getLogger(__name__) + +AUTO_FASTEST_ID = 0 +AUTO_FASTEST_MODE = "auto_fastest" + + +@dataclass +class AutoPinResolution: + resolved_llm_config_id: int + resolved_tier: str + from_existing_pin: bool + + +def _is_usable_global_config(cfg: dict) -> bool: + return bool( + cfg.get("id") is not None + and cfg.get("model_name") + and cfg.get("provider") + and cfg.get("api_key") + ) + + +def _global_candidates() -> list[dict]: + candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)] + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +def _tier_of(cfg: dict) -> str: + return str(cfg.get("billing_tier", "free")).lower() + + +def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict: + digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + idx = int.from_bytes(digest[:8], "big") % len(candidates) + return candidates[idx] + + +def _to_uuid(user_id: str | UUID | None) -> UUID | None: + if user_id is None: + return None + if isinstance(user_id, UUID): + return user_id + try: + return UUID(str(user_id)) + except Exception: + return None + + +async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool: + parsed = _to_uuid(user_id) + if parsed is None: + return False + usage = await TokenQuotaService.premium_get_usage(session, parsed) + return bool(usage.allowed) + + +async def resolve_or_get_pinned_llm_config_id( + session: AsyncSession, + *, + thread_id: int, + search_space_id: int, + user_id: str | UUID | None, + selected_llm_config_id: int, + force_repin_free: bool = False, +) -> AutoPinResolution: + """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. + + For non-auto selections, this function clears existing auto pin metadata and + returns the selected id as-is. + """ + thread = ( + ( + await session.execute( + select(NewChatThread) + .where(NewChatThread.id == thread_id) + .with_for_update(of=NewChatThread) + ) + ) + .unique() + .scalar_one_or_none() + ) + if thread is None: + raise ValueError(f"Thread {thread_id} not found") + if thread.search_space_id != search_space_id: + raise ValueError( + f"Thread {thread_id} does not belong to search space {search_space_id}" + ) + + # Explicit model selected: clear stale auto pin metadata. + if selected_llm_config_id != AUTO_FASTEST_ID: + if ( + thread.pinned_llm_config_id is not None + or thread.pinned_auto_mode is not None + or thread.pinned_at is not None + ): + thread.pinned_llm_config_id = None + thread.pinned_auto_mode = None + thread.pinned_at = None + await session.commit() + return AutoPinResolution( + resolved_llm_config_id=selected_llm_config_id, + resolved_tier="explicit", + from_existing_pin=False, + ) + + candidates = _global_candidates() + if not candidates: + raise ValueError("No usable global LLM configs are available for Auto mode") + candidate_by_id = {int(c["id"]): c for c in candidates} + + # Reuse existing valid pin without re-checking current quota (no silent tier switch), + # unless the caller explicitly requests a forced repin to free. + pinned_id = thread.pinned_llm_config_id + if ( + not force_repin_free + and + thread.pinned_auto_mode == AUTO_FASTEST_MODE + and pinned_id is not None + and int(pinned_id) in candidate_by_id + ): + pinned_cfg = candidate_by_id[int(pinned_id)] + logger.info( + "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s", + thread_id, + search_space_id, + pinned_id, + _tier_of(pinned_cfg), + ) + return AutoPinResolution( + resolved_llm_config_id=int(pinned_id), + resolved_tier=_tier_of(pinned_cfg), + from_existing_pin=True, + ) + if pinned_id is not None: + logger.info( + "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s", + thread_id, + search_space_id, + pinned_id, + thread.pinned_auto_mode, + ) + + premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id) + if premium_eligible: + eligible = candidates + else: + eligible = [c for c in candidates if _tier_of(c) != "premium"] + + if not eligible: + raise ValueError( + "Auto mode could not find an eligible LLM config for this user and quota state" + ) + + selected_cfg = _deterministic_pick(eligible, thread_id) + selected_id = int(selected_cfg["id"]) + selected_tier = _tier_of(selected_cfg) + + thread.pinned_llm_config_id = selected_id + thread.pinned_auto_mode = AUTO_FASTEST_MODE + thread.pinned_at = datetime.now(UTC) + await session.commit() + + if force_repin_free: + logger.info( + "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + ) + + if pinned_id is None: + logger.info( + "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + selected_id, + selected_tier, + premium_eligible, + ) + else: + logger.info( + "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + selected_tier, + premium_eligible, + ) + return AutoPinResolution( + resolved_llm_config_id=selected_id, + resolved_tier=selected_tier, + from_existing_pin=False, + ) diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 3531d37af..842481f1c 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,20 +565,24 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str) -> str: + def format_error(self, error_text: str, error_code: str | None = None) -> str: """ Format an error message. Args: error_text: The error message text + error_code: Optional machine-readable error code for frontend branching Returns: str: SSE formatted error part Example output: - data: {"type":"error","errorText":"Something went wrong"} + data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - return self._format_sse({"type": "error", "errorText": error_text}) + payload: dict[str, str] = {"type": "error", "errorText": error_text} + if error_code: + payload["errorCode"] = error_code + return self._format_sse(payload) # ========================================================================= # Tool Parts diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c94945bb1..2afa851b5 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -19,7 +19,8 @@ import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing import Any +from functools import partial +from typing import Any, Literal from uuid import UUID import anyio @@ -30,6 +31,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.errors import BusyError from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( @@ -57,6 +59,7 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT +from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -338,6 +341,138 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + "message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: + raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + return ( + "thread_busy", + "THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + ) + + parsed = _parse_error_payload(raw) + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + + if provider_error_type == "rate_limit_error": + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + ) + + +def _emit_stream_terminal_error( + *, + streaming_service: VercelStreamingService, + flow: str, + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + _log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + + def _legacy_match_lc_id( pending_tool_call_chunks: list[dict[str, Any]], tool_name: str, @@ -1913,6 +2048,7 @@ async def stream_new_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, + flow: Literal["new", "regenerate"] = "new", ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. @@ -1964,6 +2100,16 @@ async def stream_new_chat( _premium_reserved = 0 _premium_request_id: str | None = None + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow=flow, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) + session = async_session_maker() try: # Mark AI as responding to this user for live collaboration @@ -1971,49 +2117,78 @@ async def stream_new_chat( await set_ai_responding(session, chat_id, UUID(user_id)) # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) agent_config: AgentConfig | None = None + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) _t0 = time.perf_counter() - if llm_config_id >= 0: - # Positive ID: Load from NewLLMConfig database table - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" - ) - yield streaming_service.format_done() - return + yield streaming_service.format_done() + return - # Create ChatLiteLLM from AgentConfig - llm = create_chat_litellm_from_agent_config(agent_config) - else: - # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models) - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - - # Create ChatLiteLLM from global config dict - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return _perf_log.info( "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", time.perf_counter() - _t0, llm_config_id, ) - # Premium quota reservation — applies to explicitly premium configs - # AND Auto mode (which may route to premium models). + # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( agent_config is not None and user_id - and (agent_config.is_premium or agent_config.is_auto_mode) + and agent_config.is_premium ) if _needs_premium_quota: import uuid as _uuid @@ -2036,19 +2211,79 @@ async def stream_new_chat( ) _premium_reserved = reserve_amount if not quota_result.allowed: - if agent_config.is_premium: - yield streaming_service.format_error( - "Premium token quota exceeded. Please purchase more tokens to continue using premium models." + if requested_llm_config_id == 0: + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + _premium_request_id = None + _premium_reserved = 0 + _log_chat_stream_error( + flow=flow, + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, + ) + else: + yield _emit_stream_error( + message=( + "Buy more tokens to continue with this model, or switch to a free model" + ), + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, ) yield streaming_service.format_done() return - # Auto mode: quota exhausted but we can still proceed - # (the router may pick a free model). Reset reservation. - _premium_request_id = None - _premium_reserved = 0 if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return @@ -2499,28 +2734,20 @@ async def stream_new_chat( ) # Finalize premium quota with actual tokens. - # For Auto mode, only count tokens from calls that used premium models. if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=actual_premium_tokens, + actual_tokens=accumulator.grand_total, reserved_tokens=_premium_reserved, ) + _premium_request_id = None + _premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -2586,12 +2813,25 @@ async def stream_new_chat( # Handle any errors import traceback + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2706,36 +2946,83 @@ async def stream_resume_chat( accumulator = start_turn() + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow="resume", + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) + session = async_session_maker() try: if user_id: await set_ai_responding(session, chat_id, UUID(user_id)) agent_config: AgentConfig | None = None - _t0 = time.perf_counter() - if llm_config_id >= 0: - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" + + _t0 = time.perf_counter() + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_agent_config(agent_config) - else: - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return _perf_log.info( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) @@ -2746,7 +3033,7 @@ async def stream_resume_chat( _resume_needs_premium = ( agent_config is not None and user_id - and (agent_config.is_premium or agent_config.is_auto_mode) + and agent_config.is_premium ) if _resume_needs_premium: import uuid as _uuid @@ -2769,17 +3056,79 @@ async def stream_resume_chat( ) _resume_premium_reserved = reserve_amount if not quota_result.allowed: - if agent_config.is_premium: - yield streaming_service.format_error( - "Premium token quota exceeded. Please purchase more tokens to continue using premium models." + if requested_llm_config_id == 0: + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + _resume_premium_request_id = None + _resume_premium_reserved = 0 + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, + ) + else: + yield _emit_stream_error( + message=( + "Buy more tokens to continue with this model, or switch to a free model" + ), + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, ) yield streaming_service.format_done() return - _resume_premium_request_id = None - _resume_premium_reserved = 0 if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return @@ -2920,23 +3269,16 @@ async def stream_resume_chat( try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=actual_premium_tokens, + actual_tokens=accumulator.grand_total, reserved_tokens=_resume_premium_reserved, ) + _resume_premium_request_id = None + _resume_premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", @@ -2970,10 +3312,23 @@ async def stream_resume_chat( except Exception as e: import traceback + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py new file mode 100644 index 000000000..f08e50ba2 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + AUTO_FASTEST_MODE, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread( + *, + search_space_id: int = 10, + pinned_llm_config_id: int | None = None, + pinned_auto_mode: str | None = None, +): + return SimpleNamespace( + id=1, + search_space_id=search_space_id, + pinned_llm_config_id=pinned_llm_config_id, + pinned_auto_mode=pinned_auto_mode, + pinned_at=None, + ) + + +@pytest.mark.asyncio +async def test_auto_first_turn_pins_one_model(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id in {-1, -2} + assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id + assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE + assert session.thread.pinned_at is not None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_next_turn_reuses_existing_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not be called for valid pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_can_pin_premium(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_premium_ineligible_auto_pins_free_only(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + + +@pytest.mark.asyncio +async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + force_repin_free=True, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_explicit_user_model_change_clears_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=7, + ) + assert result.resolved_llm_config_id == 7 + assert session.thread.pinned_llm_config_id is None + assert session.thread.pinned_auto_mode is None + assert session.thread.pinned_at is None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert session.thread.pinned_llm_config_id == -2 + assert session.commit_count == 1 diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 034aa484c..86ea7edd1 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -1,9 +1,19 @@ +import inspect +import json +import logging +import re +from pathlib import Path + import pytest +import app.tasks.chat.stream_new_chat as stream_new_chat_module +from app.agents.new_chat.errors import BusyError from app.tasks.chat.stream_new_chat import ( StreamResult, + _classify_stream_exception, _contract_enforcement_active, _evaluate_file_contract_outcome, + _log_chat_stream_error, _tool_output_has_error, ) @@ -45,3 +55,217 @@ def test_contract_enforcement_local_only(): result.filesystem_mode = "cloud" assert not _contract_enforcement_active(result) + + +def _extract_chat_stream_payload(record_message: str) -> dict: + prefix = "[chat_stream_error] " + assert record_message.startswith(prefix) + return json.loads(record_message[len(prefix) :]) + + +def test_unified_chat_stream_error_log_schema(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="new", + error_kind="server_error", + error_code="SERVER_ERROR", + severity="warn", + is_expected=False, + request_id="req-123", + thread_id=101, + search_space_id=202, + user_id="user-1", + message="Error during chat: boom", + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + + required_keys = { + "event", + "flow", + "error_kind", + "error_code", + "severity", + "is_expected", + "request_id", + "thread_id", + "search_space_id", + "user_id", + "message", + } + assert required_keys.issubset(payload.keys()) + assert payload["event"] == "chat_stream_error" + assert payload["flow"] == "new" + assert payload["error_code"] == "SERVER_ERROR" + + +def test_premium_quota_uses_unified_chat_stream_log_shape(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id="req-premium", + thread_id=303, + search_space_id=404, + user_id="user-2", + message="Buy more tokens to continue with this model, or switch to a free model", + extra={"auto_fallback": False}, + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + assert payload["event"] == "chat_stream_error" + assert payload["error_kind"] == "premium_quota_exhausted" + assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED" + assert payload["flow"] == "resume" + assert payload["is_expected"] is True + assert payload["auto_fallback"] is False + + +def test_stream_error_emission_keeps_machine_error_codes(): + source = inspect.getsource(stream_new_chat_module) + format_error_calls = re.findall(r"format_error\(", source) + emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) + + # All stream paths should route through one shared terminal error emitter. + assert len(format_error_calls) == 1 + assert { + "PREMIUM_QUOTA_EXHAUSTED", + "SERVER_ERROR", + }.issubset(emitted_error_codes) + assert 'flow: Literal["new", "regenerate"] = "new"' in source + assert "_emit_stream_terminal_error" in source + assert "flow=flow" in source + assert 'flow="resume"' in source + + +def test_stream_exception_classifies_rate_limited(): + exc = Exception( + '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' + ) + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + + +def test_stream_exception_classifies_thread_busy(): + exc = BusyError(request_id="thread-123") + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + + +def test_stream_exception_classifies_thread_busy_from_message(): + exc = Exception("Thread is busy with another request") + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + + +def test_premium_classification_is_error_code_driven(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + source = classifier_path.read_text(encoding="utf-8") + + assert "PREMIUM_KEYWORDS" not in source + assert "RATE_LIMIT_KEYWORDS" not in source + assert "normalized.includes(" not in source + assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source + + +def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + assert "onPreAcceptFailure?: () => Promise;" in source + assert "if (!accepted) {" in source + assert "await onPreAcceptFailure?.();" in source + assert "await onAcceptedStreamError?.();" in source + assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source + assert "setMessageDocumentsMap((prev) => {" in source + + +def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): + user_message_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/components/assistant-ui/user-message.tsx" + ) + source = user_message_path.read_text(encoding="utf-8") + + assert "Not sent. Edit and retry." not in source + assert "failed_pre_accept" not in source + + +def test_network_send_failures_use_unified_retry_toast_message(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_source = classifier_path.read_text(encoding="utf-8") + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + page_source = page_path.read_text(encoding="utf-8") + + assert '"send_failed_pre_accept"' in classifier_source + assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert "if (withCode.code) return withCode.code;" in classifier_source + assert 'userMessage: "Message not sent. Please retry."' in classifier_source + assert 'userMessage: "Connection issue. Please try again."' in classifier_source + assert "tagPreAcceptSendFailure(error)" in page_source + assert "const passthroughCodes = new Set([" in page_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source + assert '"THREAD_BUSY"' in page_source + assert '"AUTH_EXPIRED"' in page_source + assert '"UNAUTHORIZED"' in page_source + assert '"RATE_LIMITED"' in page_source + assert '"NETWORK_ERROR"' in page_source + assert '"STREAM_PARSE_ERROR"' in page_source + assert '"TOOL_EXECUTION_ERROR"' in page_source + assert '"PERSIST_MESSAGE_FAILED"' in page_source + assert '"SERVER_ERROR"' in page_source + assert "passthroughCodes.has(existingCode)" in page_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source + assert 'errorCode: "NETWORK_ERROR"' not in page_source + assert "Failed to start chat. Please try again." not in page_source + + +def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + # Each flow tracks accepted boundary and passes it into shared terminal handling. + assert "let newAccepted = false;" in source + assert "let resumeAccepted = false;" in source + assert "let regenerateAccepted = false;" in source + assert "accepted: newAccepted," in source + assert "accepted: resumeAccepted," in source + assert "accepted: regenerateAccepted," in source + + # Pre-accept abort in resume/regenerate exits without persistence. + assert "if (!resumeAccepted) return;" in source + assert "if (!regenerateAccepted) return;" in source + + # New flow persists only when accepted and not already persisted. + assert "if (newAccepted && !userPersisted) {" in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index e5ac61cd9..fe625f169 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -19,6 +19,7 @@ import { currentThreadAtom, setTargetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; +import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { type MentionedDocumentInfo, mentionedDocumentIdsAtom, @@ -59,6 +60,10 @@ import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; +import { + classifyChatError, + type ChatFlow, +} from "@/lib/chat/chat-error-classifier"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, @@ -77,6 +82,7 @@ import { endReasoning, FrameBatchedUpdater, readSSEStream, + type SSEEvent, type ThinkingStepData, type ToolUIGate, updateThinkingSteps, @@ -99,7 +105,8 @@ import { import { NotFoundError } from "@/lib/error"; import { trackChatCreated, - trackChatError, + trackChatBlocked, + trackChatErrorDetailed, trackChatMessageSent, trackChatResponseReceived, } from "@/lib/posthog/events"; @@ -146,6 +153,105 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un } } +function toStreamTerminalError( + event: Extract +): Error & { errorCode?: string } { + return Object.assign(new Error(event.errorText || "Server error"), { + errorCode: event.errorCode, + }); +} + +async function toHttpResponseError(response: Response): Promise { + const statusDefaultCode = + response.status === 409 + ? "THREAD_BUSY" + : response.status === 429 + ? "RATE_LIMITED" + : response.status === 401 || response.status === 403 + ? "AUTH_EXPIRED" + : "SERVER_ERROR"; + + let rawBody = ""; + try { + rawBody = await response.text(); + } catch { + // noop + } + + let parsedBody: Record | null = null; + if (rawBody) { + try { + const parsed = JSON.parse(rawBody); + if (typeof parsed === "object" && parsed !== null) { + parsedBody = parsed as Record; + } + } catch { + // noop + } + } + + const detail = parsedBody?.detail; + const detailObject = + typeof detail === "object" && detail !== null ? (detail as Record) : null; + const detailMessage = typeof detail === "string" ? detail : undefined; + const topLevelMessage = + typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; + const detailNestedMessage = + typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; + + const topLevelCode = + typeof parsedBody?.errorCode === "string" + ? parsedBody.errorCode + : typeof parsedBody?.error_code === "string" + ? parsedBody.error_code + : undefined; + const detailCode = + typeof detailObject?.errorCode === "string" + ? detailObject.errorCode + : typeof detailObject?.error_code === "string" + ? detailObject.error_code + : undefined; + + const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + const message = + detailNestedMessage ?? + detailMessage ?? + topLevelMessage ?? + `Backend error: ${response.status}`; + + return Object.assign(new Error(message), { errorCode }); +} + +function tagPreAcceptSendFailure(error: unknown): unknown { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + const existingCode = withCode.errorCode ?? withCode.code; + const passthroughCodes = new Set([ + "PREMIUM_QUOTA_EXHAUSTED", + "THREAD_BUSY", + "AUTH_EXPIRED", + "UNAUTHORIZED", + "RATE_LIMITED", + "NETWORK_ERROR", + "STREAM_PARSE_ERROR", + "TOOL_EXECUTION_ERROR", + "PERSIST_MESSAGE_FAILED", + "SERVER_ERROR", + ]); + if ( + existingCode && + passthroughCodes.has(existingCode) + ) { + return Object.assign(error, { errorCode: existingCode }); + } + return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); + } + + return Object.assign(new Error("Failed to send message before stream acceptance"), { + errorCode: "SEND_FAILED_PRE_ACCEPT", + }); +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -226,6 +332,164 @@ export default function NewChatPage() { interruptData: Record; } | null>(null); const toolsWithUI = TOOLS_WITH_UI_ALL; + const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); + + const persistAssistantErrorMessage = useCallback( + async ({ + threadId, + assistantMsgId, + text, + }: { + threadId: number | null; + assistantMsgId: string; + text: string; + }) => { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [{ type: "text", text }], + } + : m + ) + ); + + if (!threadId) return; + + // Persist only temporary assistant placeholders to avoid duplicate rows + // when the message already has a database-backed ID. + if (!assistantMsgId.startsWith("msg-assistant-")) return; + + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: [{ type: "text", text }], + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + } catch (persistErr) { + console.error("Failed to persist assistant error message:", persistErr); + } + }, + [tokenUsageStore] + ); + + const persistUserTurn = useCallback( + async ({ + threadId, + userMsgId, + content, + mentionedDocs, + turnId, + logContext, + }: { + threadId: number | null; + userMsgId: string; + content: unknown; + mentionedDocs?: MentionedDocumentInfo[]; + turnId?: string | null; + logContext: string; + }) => { + if (!threadId) return null; + try { + const normalizedContent = Array.isArray(content) + ? ([...content] as unknown[]) + : [content]; + const hasMentionedDocumentsPart = normalizedContent.some((part) => + MentionedDocumentsPartSchema.safeParse(part).success + ); + if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) { + normalizedContent.push({ + type: "mentioned-documents", + documents: mentionedDocs, + }); + } + + const savedUserMessage = await appendMessage(threadId, { + role: "user", + content: normalizedContent as AppendMessage["content"], + turn_id: turnId, + }); + const newUserMsgId = `msg-${savedUserMessage.id}`; + setMessages((prev) => + prev.map((m) => + m.id === userMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newUserMsgId }, + savedUserMessage.turn_id + ) + : m + ) + ); + if (mentionedDocs && mentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + const { [userMsgId]: _, ...rest } = prev; + return { + ...rest, + [newUserMsgId]: mentionedDocs, + }; + }); + } + return newUserMsgId; + } catch (err) { + console.error(`Failed to persist ${logContext} user message:`, err); + return null; + } + }, + [setMessageDocumentsMap] + ); + + const persistAssistantTurn = useCallback( + async ({ + threadId, + assistantMsgId, + content, + tokenUsage, + turnId, + logContext, + onRemapped, + }: { + threadId: number | null; + assistantMsgId: string; + content: unknown; + tokenUsage?: Record; + turnId?: string | null; + logContext: string; + onRemapped?: (newMsgId: string) => void; + }) => { + if (!threadId) return null; + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: content as AppendMessage["content"], + token_usage: tokenUsage, + turn_id: turnId, + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newMsgId }, + savedMessage.turn_id + ) + : m + ) + ); + onRemapped?.(newMsgId); + return newMsgId; + } catch (err) { + console.error(`Failed to persist ${logContext} assistant message:`, err); + return null; + } + }, + [tokenUsageStore] + ); // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -233,9 +497,10 @@ export default function NewChatPage() { // Get mentioned document IDs from the composer. const mentionedDocumentIds = useAtomValue(mentionedDocumentIdsAtom); const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); + const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); - const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setCurrentThreadState = useSetAtom(currentThreadAtom); + const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom); const closeReportPanel = useSetAtom(closeReportPanelAtom); @@ -350,6 +615,122 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.chat_id]); + const handleChatFailure = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + }) => { + const normalized = classifyChatError({ + error, + flow, + context: { + searchSpaceId, + threadId, + }, + }); + + const logger = + normalized.severity === "error" + ? console.error + : normalized.severity === "warn" + ? console.warn + : console.info; + logger(`[NewChatPage] ${flow} ${normalized.kind}:`, error); + + const telemetryPayload = { + flow, + kind: normalized.kind, + error_code: normalized.errorCode, + severity: normalized.severity, + is_expected: normalized.isExpected, + message: normalized.userMessage, + }; + if (normalized.telemetryEvent === "chat_blocked") { + trackChatBlocked(searchSpaceId, threadId, telemetryPayload); + } else { + trackChatErrorDetailed(searchSpaceId, threadId, telemetryPayload); + } + + if (normalized.channel === "silent") { + return; + } + + if (normalized.channel === "pinned_inline") { + if (threadId) { + setPremiumAlertForThread({ + threadId, + message: normalized.userMessage, + userId: currentUser?.id ?? null, + }); + } + if (normalized.assistantMessage) { + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: normalized.assistantMessage, + }); + } + return; + } + + toast.error(normalized.userMessage); + }, + [ + currentUser?.id, + persistAssistantErrorMessage, + searchSpaceId, + setPremiumAlertForThread, + ] + ); + + const handleStreamTerminalError = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + accepted, + onAbort, + onPreAcceptFailure, + onAcceptedStreamError, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + accepted: boolean; + onAbort?: () => Promise; + onPreAcceptFailure?: () => Promise; + onAcceptedStreamError?: () => Promise; + }) => { + if (error instanceof Error && error.name === "AbortError") { + await onAbort?.(); + return; + } + + if (!accepted) { + await onPreAcceptFailure?.(); + } else { + await onAcceptedStreamError?.(); + } + + await handleChatFailure({ + error: !accepted ? tagPreAcceptSendFailure(error) : error, + flow, + threadId, + assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant", + }); + }, + [handleChatFailure] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -576,7 +957,12 @@ export default function NewChatPage() { ); } catch (error) { console.error("[NewChatPage] Failed to create thread:", error); - toast.error("Failed to start chat. Please try again."); + await handleChatFailure({ + error: tagPreAcceptSendFailure(error), + flow: "new", + threadId: currentThreadId, + assistantMsgId: "no-persist-assistant", + }); return; } } @@ -661,27 +1047,6 @@ export default function NewChatPage() { }); } - appendMessage(currentThreadId, { - role: "user", - content: persistContent, - }) - .then((savedMessage) => { - const newUserMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); - setMessageDocumentsMap((prev) => { - const docs = prev[userMsgId]; - if (!docs) return prev; - const { [userMsgId]: _, ...rest } = prev; - return { ...rest, [newUserMsgId]: docs }; - }); - if (isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - }) - .catch((err) => console.error("Failed to persist user message:", err)); - // Start streaming response setIsRunning(true); const controller = new AbortController(); @@ -701,20 +1066,11 @@ export default function NewChatPage() { const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; let tokenUsageData: Record | null = null; + let newAccepted = false; + let userPersisted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); - try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; const selection = await getAgentFilesystemSelection(searchSpaceId); @@ -774,8 +1130,18 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + newAccepted = true; + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); const flushMessages = () => { setMessages((prev) => @@ -1015,7 +1381,7 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw toStreamTerminalError(parsed); } } @@ -1024,99 +1390,107 @@ export default function NewChatPage() { // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0 && !wasInterrupted) { - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - turn_id: streamedChatTurnId, + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + turnId: streamedChatTurnId, + logContext: "new chat", }); - - // Update message ID from temporary to database ID so comments work immediately - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - - // Update pending interrupt with the new persisted message ID - setPendingInterrupt((prev) => - prev && prev.assistantMsgId === assistantMsgId - ? { ...prev, assistantMsgId: newMsgId } - : prev - ); - } catch (err) { - console.error("Failed to persist assistant message:", err); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } } + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + turnId: streamedChatTurnId, + logContext: "new chat", + onRemapped: (newMsgId) => { + setPendingInterrupt((prev) => + prev && prev.assistantMsgId === assistantMsgId + ? { ...prev, assistantMsgId: newMsgId } + : prev + ); + }, + }); + // Track successful response trackChatResponseReceived(searchSpaceId, currentThreadId); } } catch (error) { batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - // Request was cancelled by user - persist partial response if any content was received - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); - if (hasContent && currentThreadId) { - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: partialContent, - turn_id: streamedChatTurnId, + await handleStreamTerminalError({ + error, + flow: "new", + threadId: currentThreadId, + assistantMsgId, + accepted: newAccepted, + onAbort: async () => { + if (newAccepted && !userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + turnId: streamedChatTurnId, + logContext: "new chat (aborted)", }); - - // Update message ID from temporary to database ID - const newMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - } catch (err) { - console.error("Failed to persist partial assistant message:", err); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } } - } - return; - } - console.error("[NewChatPage] Chat error:", error); - // Track chat error - trackChatError( - searchSpaceId, - currentThreadId, - error instanceof Error ? error.message : "Unknown error" - ); - - toast.error("Failed to get response. Please try again."); - // Update assistant message with error - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [ - { - type: "text", - text: "Sorry, there was an error. Please try again.", - }, - ], - } - : m - ) - ); + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); + if (hasContent && currentThreadId) { + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: partialContent, + turnId: streamedChatTurnId, + logContext: "partial new chat", + }); + } + }, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + turnId: streamedChatTurnId, + logContext: "new chat (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + }, + onPreAcceptFailure: async () => { + setMessages((prev) => prev.filter((m) => m.id !== userMsgId)); + setMessageDocumentsMap((prev) => { + if (!(userMsgId in prev)) return prev; + const { [userMsgId]: _removed, ...rest } = prev; + return rest; + }); + }, + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1138,7 +1512,10 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, - toolsWithUI, + handleStreamTerminalError, + handleChatFailure, + persistAssistantTurn, + persistUserTurn, ] ); @@ -1176,6 +1553,7 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let tokenUsageData: Record | null = null; + let resumeAccepted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; @@ -1273,8 +1651,9 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + resumeAccepted = true; const flushMessages = () => { setMessages((prev) => @@ -1458,7 +1837,7 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw toStreamTerminalError(parsed); } } @@ -1466,39 +1845,56 @@ export default function NewChatPage() { const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0) { - try { - const savedMessage = await appendMessage(resumeThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - turn_id: streamedChatTurnId, - }); - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - } catch (err) { - console.error("Failed to persist resumed assistant message:", err); - } + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + turnId: streamedChatTurnId, + logContext: "resumed chat", + }); } } catch (error) { batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - return; - } - console.error("[NewChatPage] Resume error:", error); - toast.error("Failed to resume. Please try again."); + await handleStreamTerminalError({ + error, + flow: "resume", + threadId: resumeThreadId, + assistantMsgId, + accepted: resumeAccepted, + onAbort: async () => { + if (!resumeAccepted) return; + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: partialContent, + turnId: streamedChatTurnId, + logContext: "partial resumed chat", + }); + }, + }); } finally { setIsRunning(false); abortControllerRef.current = null; } }, - [pendingInterrupt, messages, searchSpaceId, tokenUsageStore, toolsWithUI] + [ + pendingInterrupt, + messages, + searchSpaceId, + tokenUsageStore, + handleStreamTerminalError, + persistAssistantTurn, + ] ); useEffect(() => { @@ -1580,6 +1976,7 @@ export default function NewChatPage() { editExtras?: { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; + sourceUserMessageId?: string; }, editFromPosition?: { /** Message id (numeric, parsed from ``msg-``) to rewind to. */ @@ -1611,11 +2008,13 @@ export default function NewChatPage() { let userQueryToDisplay: string | undefined; let originalUserMessageContent: ThreadMessageLike["content"] | null = null; let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined; + let sourceUserMessageId: string | undefined = editExtras?.sourceUserMessageId; if (!isEdit) { // Reload mode - find and preserve the last user message content const lastUserMessage = [...messages].reverse().find((m) => m.role === "user"); if (lastUserMessage) { + sourceUserMessageId = lastUserMessage.id; originalUserMessageContent = lastUserMessage.content; originalUserMessageMetadata = lastUserMessage.metadata; // Extract text for the API request @@ -1630,26 +2029,6 @@ export default function NewChatPage() { userQueryToDisplay = newUserQuery; } - // Remove downstream messages from the UI immediately. The - // backend will also delete them from the database. - // - // When an explicit ``fromMessageId`` is passed, slice from - // that message forward; otherwise fall back to the legacy - // "drop the last 2" behaviour. - setMessages((prev) => { - if (editFromPosition?.fromMessageId != null) { - const targetId = `msg-${editFromPosition.fromMessageId}`; - const sliceIndex = prev.findIndex((m) => m.id === targetId); - if (sliceIndex >= 0) { - return prev.slice(0, sliceIndex); - } - } - if (prev.length >= 2) { - return prev.slice(0, -2); - } - return prev; - }); - // Start streaming setIsRunning(true); const controller = new AbortController(); @@ -1669,6 +2048,8 @@ export default function NewChatPage() { const { contentParts, toolCallIndices } = contentPartsState; const batcher = new FrameBatchedUpdater(); let tokenUsageData: Record | null = null; + let regenerateAccepted = false; + let userPersisted = false; // Captured from ``data-turn-info`` at stream start; stamped // onto persisted messages so future edits can locate the // right LangGraph checkpoint. @@ -1685,19 +2066,13 @@ export default function NewChatPage() { createdAt: new Date(), metadata: isEdit ? undefined : originalUserMessageMetadata, }; - setMessages((prev) => [...prev, userMessage]); - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); - + const userContentToPersist = isEdit + ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) + : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; + const sourceMentionedDocs = + sourceUserMessageId && messageDocumentsMap[sourceUserMessageId] + ? messageDocumentsMap[sourceUserMessageId] + : []; try { const selection = await getAgentFilesystemSelection(searchSpaceId); const requestBody: Record = { @@ -1732,7 +2107,43 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); + } + regenerateAccepted = true; + + // Only switch UI to regenerated placeholder messages after the backend accepts + // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). + // + // When an explicit ``editFromPosition.fromMessageId`` is passed, slice from + // that message forward so edit-from-arbitrary-position drops every downstream + // message; otherwise fall back to the legacy "drop the last 2" behaviour. + setMessages((prev) => { + let base = prev; + if (editFromPosition?.fromMessageId != null) { + const targetId = `msg-${editFromPosition.fromMessageId}`; + const sliceIndex = prev.findIndex((m) => m.id === targetId); + if (sliceIndex >= 0) { + base = prev.slice(0, sliceIndex); + } + } else if (prev.length >= 2) { + base = prev.slice(0, -2); + } + return [ + ...base, + userMessage, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]; + }); + if (sourceMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => ({ + ...prev, + [userMsgId]: sourceMentionedDocs, + })); } const flushMessages = () => { @@ -1922,7 +2333,7 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw toStreamTerminalError(parsed); } } @@ -1931,79 +2342,97 @@ export default function NewChatPage() { // Persist messages after streaming completes const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0) { - try { - // Persist user message (for both edit and reload modes, since backend deleted it) - const userContentToPersist = isEdit - ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) - : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + turnId: streamedChatTurnId, + logContext: "regenerated", + }); + userPersisted = Boolean(persistedUserMsgId); - const savedUserMessage = await appendMessage(threadId, { - role: "user", - content: userContentToPersist, - turn_id: streamedChatTurnId, - }); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + turnId: streamedChatTurnId, + logContext: "regenerated", + }); - // Update user message ID to database ID - const newUserMsgId = `msg-${savedUserMessage.id}`; - setMessages((prev) => - prev.map((m) => - m.id === userMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) - : m - ) - ); - - // Persist assistant message - const savedMessage = await appendMessage(threadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - turn_id: streamedChatTurnId, - }); - - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - - trackChatResponseReceived(searchSpaceId, threadId); - } catch (err) { - console.error("Failed to persist regenerated message:", err); - } + trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - if (error instanceof Error && error.name === "AbortError") { - return; - } batcher.dispose(); - console.error("[NewChatPage] Regeneration error:", error); - trackChatError( - searchSpaceId, + await handleStreamTerminalError({ + error, + flow: "regenerate", threadId, - error instanceof Error ? error.message : "Unknown error" - ); - toast.error("Failed to regenerate response. Please try again."); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [{ type: "text", text: "Sorry, there was an error. Please try again." }], - } - : m - ) - ); + assistantMsgId, + accepted: regenerateAccepted, + onAbort: async () => { + if (!regenerateAccepted) return; + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + turnId: streamedChatTurnId, + logContext: "regenerated (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: partialContent, + tokenUsage: tokenUsageData ?? undefined, + turnId: streamedChatTurnId, + logContext: "partial regenerated chat", + }); + }, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + turnId: streamedChatTurnId, + logContext: "regenerated (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + }, + }); } finally { setIsRunning(false); abortControllerRef.current = null; } }, - [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] + [ + threadId, + searchSpaceId, + messages, + disabledTools, + messageDocumentsMap, + setMessageDocumentsMap, + tokenUsageStore, + handleStreamTerminalError, + persistAssistantTurn, + persistUserTurn, + ] ); // Handle editing a message - truncates history and regenerates with new query. @@ -2037,7 +2466,11 @@ export default function NewChatPage() { if (fromMessageId == null) { // No source id (or non-DB id) — fall back to today's // last-2 behaviour. The user gets the legacy edit flow. - await handleRegenerate(queryForApi, { userMessageContent, userImages }); + await handleRegenerate(queryForApi, { + userMessageContent, + userImages, + sourceUserMessageId: sourceId, + }); return; } @@ -2086,7 +2519,7 @@ export default function NewChatPage() { // Nothing to revert — submit silently. await handleRegenerate( queryForApi, - { userMessageContent, userImages }, + { userMessageContent, userImages, sourceUserMessageId: sourceId }, { fromMessageId, revertActions: false } ); return; @@ -2115,6 +2548,7 @@ export default function NewChatPage() { { userMessageContent: pending.userMessageContent, userImages: pending.userImages, + sourceUserMessageId: `msg-${pending.fromMessageId}`, }, { fromMessageId: pending.fromMessageId, diff --git a/surfsense_web/atoms/chat/premium-alert.atom.ts b/surfsense_web/atoms/chat/premium-alert.atom.ts new file mode 100644 index 000000000..1c837dd65 --- /dev/null +++ b/surfsense_web/atoms/chat/premium-alert.atom.ts @@ -0,0 +1,45 @@ +import { atom } from "jotai"; + +export type PremiumAlertState = { + message: string; +}; + +export const premiumAlertByThreadAtom = atom>({}); + +export const setPremiumAlertForThreadAtom = atom( + null, + ( + get, + set, + payload: { + threadId: number; + message: string; + userId?: string | null; + } + ) => { + const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`; + + if (typeof window !== "undefined") { + const hasSeen = localStorage.getItem(storageKey) === "true"; + if (hasSeen) return; + } + + const current = get(premiumAlertByThreadAtom); + set(premiumAlertByThreadAtom, { + ...current, + [payload.threadId]: { message: payload.message }, + }); + + if (typeof window !== "undefined") { + localStorage.setItem(storageKey, "true"); + } + } +); + +export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => { + const current = get(premiumAlertByThreadAtom); + if (!(threadId in current)) return; + const next = { ...current }; + delete next[threadId]; + set(premiumAlertByThreadAtom, next); +}); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index e58783c87..3e27e7adb 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -37,10 +37,13 @@ import { toggleToolAtom, } from "@/atoms/agent-tools/agent-tools.atoms"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; -import { - mentionedDocumentsAtom, -} from "@/atoms/chat/mentioned-documents.atom"; +import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; +import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; +import { + clearPremiumAlertForThreadAtom, + premiumAlertByThreadAtom, +} from "@/atoms/chat/premium-alert.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; @@ -135,6 +138,9 @@ const ThreadContent: FC = () => { style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} > + !thread.isEmpty}> + + !thread.isEmpty}> @@ -144,6 +150,37 @@ const ThreadContent: FC = () => { ); }; +const PremiumQuotaPinnedAlert: FC = () => { + const currentThreadState = useAtomValue(currentThreadAtom); + const alertsByThread = useAtomValue(premiumAlertByThreadAtom); + const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom); + + const currentThreadId = currentThreadState?.id; + if (!currentThreadId) return null; + + const alert = alertsByThread[currentThreadId]; + if (!alert) return null; + + return ( +
+
+ +
+

{alert.message}

+
+ +
+
+ ); +}; + const ThreadScrollToBottom: FC = () => { return ( diff --git a/surfsense_web/components/free-chat/anonymous-chat.tsx b/surfsense_web/components/free-chat/anonymous-chat.tsx index b286c5316..3de2ca434 100644 --- a/surfsense_web/components/free-chat/anonymous-chat.tsx +++ b/surfsense_web/components/free-chat/anonymous-chat.tsx @@ -104,7 +104,13 @@ export function AnonymousChat({ model }: AnonymousChatProps) { setMessages((prev) => prev.filter((m) => m.id !== assistantId)); return; } - throw new Error(`Stream error: ${response.status}`); + const body = await response.text().catch(() => ""); + const errorCode = response.status === 409 ? "THREAD_BUSY" : "SERVER_ERROR"; + const message = + errorCode === "THREAD_BUSY" + ? "A previous response is still stopping. Please try again in a moment." + : `Stream error: ${response.status}`; + throw Object.assign(new Error(body || message), { errorCode }); } for await (const event of readSSEStream(response)) { @@ -115,10 +121,12 @@ export function AnonymousChat({ model }: AnonymousChatProps) { prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + event.delta } : m)) ); } else if (event.type === "error") { + const message = + event.errorCode === "THREAD_BUSY" + ? "A previous response is still stopping. Please try again in a moment." + : event.errorText; setMessages((prev) => - prev.map((m) => - m.id === assistantId ? { ...m, content: m.content || event.errorText } : m - ) + prev.map((m) => (m.id === assistantId ? { ...m, content: m.content || message } : m)) ); } else if ("type" in event && event.type === "data-token-usage") { // After streaming completes, refresh quota diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index 05db99407..080d9a2b6 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -55,6 +55,48 @@ function parseCaptchaError(status: number, body: string): string | null { return null; } +function normalizeFreeChatErrorMessage(error: unknown): string { + if (!(error instanceof Error)) return "An unexpected error occurred"; + const code = (error as Error & { errorCode?: string }).errorCode; + if (code === "THREAD_BUSY") { + return "A previous response is still stopping. Please try again in a moment."; + } + return error.message || "An unexpected error occurred"; +} + +function toFreeChatHttpError(status: number, body: string): Error & { errorCode?: string } { + let errorCode: string | undefined; + let message = body || `Server error: ${status}`; + try { + const parsed = JSON.parse(body) as Record; + const detail = + typeof parsed.detail === "object" && parsed.detail !== null + ? (parsed.detail as Record) + : null; + errorCode = + (typeof detail?.error_code === "string" ? detail.error_code : undefined) ?? + (typeof detail?.errorCode === "string" ? detail.errorCode : undefined) ?? + (typeof parsed.error_code === "string" ? parsed.error_code : undefined) ?? + (typeof parsed.errorCode === "string" ? parsed.errorCode : undefined); + message = + (typeof detail?.message === "string" ? detail.message : undefined) ?? + (typeof parsed.message === "string" ? parsed.message : undefined) ?? + (typeof parsed.detail === "string" ? parsed.detail : undefined) ?? + message; + } catch { + // non-json response + } + + if (!errorCode) { + if (status === 409) errorCode = "THREAD_BUSY"; + else if (status === 429) errorCode = "RATE_LIMITED"; + else if (status === 401 || status === 403) errorCode = "AUTH_EXPIRED"; + else errorCode = "SERVER_ERROR"; + } + + return Object.assign(new Error(message), { errorCode }); +} + export function FreeChatPage() { const anonMode = useAnonymousMode(); const modelSlug = anonMode.isAnonymous ? anonMode.modelSlug : ""; @@ -124,7 +166,7 @@ export function FreeChatPage() { const body = await response.text().catch(() => ""); const captchaCode = parseCaptchaError(response.status, body); if (captchaCode) return "captcha"; - throw new Error(body || `Server error: ${response.status}`); + throw toFreeChatHttpError(response.status, body); } const currentThinkingSteps = new Map(); @@ -244,7 +286,9 @@ export function FreeChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } batcher.flush(); @@ -334,7 +378,7 @@ export function FreeChatPage() { } catch (error) { if (error instanceof Error && error.name === "AbortError") return; console.error("[FreeChatPage] Chat error:", error); - const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; + const errorText = normalizeFreeChatErrorMessage(error); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -393,7 +437,7 @@ export function FreeChatPage() { } catch (error) { if (error instanceof Error && error.name === "AbortError") return; console.error("[FreeChatPage] Retry error:", error); - const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; + const errorText = normalizeFreeChatErrorMessage(error); setMessages((prev) => prev.map((m) => m.id === assistantMsgId diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts new file mode 100644 index 000000000..57341a4c3 --- /dev/null +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -0,0 +1,304 @@ +export type ChatFlow = "new" | "resume" | "regenerate"; + +export type ChatErrorKind = + | "premium_quota_exhausted" + | "thread_busy" + | "send_failed_pre_accept" + | "auth_expired" + | "rate_limited" + | "network_offline" + | "stream_interrupted" + | "stream_parse_error" + | "tool_execution_error" + | "persist_message_failed" + | "server_error" + | "unknown"; + +export type ChatErrorChannel = "pinned_inline" | "toast" | "silent"; +export type ChatTelemetryEvent = "chat_blocked" | "chat_error"; +export type ChatErrorSeverity = "info" | "warn" | "error"; + +export interface NormalizedChatError { + kind: ChatErrorKind; + channel: ChatErrorChannel; + severity: ChatErrorSeverity; + telemetryEvent: ChatTelemetryEvent; + isExpected: boolean; + userMessage: string; + assistantMessage?: string; + rawMessage?: string; + errorCode?: string; + details?: Record; +} + +export interface RawChatErrorInput { + error: unknown; + flow: ChatFlow; + context?: { + searchSpaceId?: number; + threadId?: number | null; + }; +} + +export const PREMIUM_QUOTA_ASSISTANT_MESSAGE = + "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; + +function getErrorMessage(error: unknown): string { + if (error instanceof Error) return error.message; + if (typeof error === "string") return error; + try { + return JSON.stringify(error); + } catch { + return "Unknown error"; + } +} + +function getErrorCode(error: unknown, parsedJson: Record | null): string | undefined { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + if (withCode.errorCode) return withCode.errorCode; + if (withCode.code) return withCode.code; + } + + if (typeof error === "object" && error !== null) { + const withCode = error as { errorCode?: unknown }; + if (typeof withCode.errorCode === "string" && withCode.errorCode) { + return withCode.errorCode; + } + } + + if (parsedJson) { + const topLevelCode = parsedJson.errorCode; + if (typeof topLevelCode === "string" && topLevelCode) { + return topLevelCode; + } + } + + return undefined; +} + +function parseEmbeddedJson(text: string): Record | null { + const candidates = [text]; + const firstBraceIdx = text.indexOf("{"); + if (firstBraceIdx >= 0) { + candidates.push(text.slice(firstBraceIdx)); + } + for (const candidate of candidates) { + try { + const parsed = JSON.parse(candidate); + if (typeof parsed === "object" && parsed !== null) { + return parsed as Record; + } + } catch { + // noop + } + } + return null; +} + +function inferProviderErrorType(parsedJson: Record | null): string | undefined { + if (!parsedJson) return undefined; + const topLevelType = parsedJson.type; + if (typeof topLevelType === "string" && topLevelType) return topLevelType; + const nestedError = parsedJson.error; + if (typeof nestedError === "object" && nestedError !== null) { + const nestedType = (nestedError as Record).type; + if (typeof nestedType === "string" && nestedType) return nestedType; + } + return undefined; +} + +export function classifyChatError(input: RawChatErrorInput): NormalizedChatError { + const { error } = input; + const rawMessage = getErrorMessage(error); + const parsedJson = parseEmbeddedJson(rawMessage); + const errorCode = getErrorCode(error, parsedJson); + const providerErrorType = inferProviderErrorType(parsedJson); + const providerTypeNormalized = providerErrorType?.toLowerCase() ?? ""; + const errorName = error instanceof Error ? error.name : undefined; + + if (errorName === "AbortError") { + return { + kind: "stream_interrupted", + channel: "silent", + severity: "info", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Request canceled.", + rawMessage, + errorCode, + details: { flow: input.flow }, + }; + } + + if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") { + return { + kind: "premium_quota_exhausted", + channel: "pinned_inline", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "Buy more tokens to continue with this model, or switch to a free model.", + assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE, + rawMessage, + errorCode: errorCode ?? "PREMIUM_QUOTA_EXHAUSTED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "THREAD_BUSY" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? "THREAD_BUSY", + details: { flow: input.flow }, + }; + } + + if (errorCode === "SEND_FAILED_PRE_ACCEPT") { + return { + kind: "send_failed_pre_accept", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "Message not sent. Please retry.", + rawMessage, + errorCode: errorCode ?? "SEND_FAILED_PRE_ACCEPT", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "AUTH_EXPIRED" || + errorCode === "UNAUTHORIZED" + ) { + return { + kind: "auth_expired", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Your session expired. Please sign in again.", + rawMessage, + errorCode: errorCode ?? "AUTH_EXPIRED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "RATE_LIMITED" || + providerTypeNormalized === "rate_limit_error" + ) { + return { + kind: "rate_limited", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + rawMessage, + errorCode: errorCode ?? "RATE_LIMITED", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "NETWORK_ERROR") { + return { + kind: "network_offline", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Connection issue. Please try again.", + rawMessage, + errorCode: errorCode ?? "NETWORK_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "STREAM_PARSE_ERROR" + ) { + return { + kind: "stream_parse_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We hit a response formatting issue. Please try again.", + rawMessage, + errorCode: errorCode ?? "STREAM_PARSE_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "TOOL_EXECUTION_ERROR" + ) { + return { + kind: "tool_execution_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "A tool failed while processing your request. Please try again.", + rawMessage, + errorCode: errorCode ?? "TOOL_EXECUTION_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "PERSIST_MESSAGE_FAILED" + ) { + return { + kind: "persist_message_failed", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "Response generated, but saving failed. Please retry once.", + rawMessage, + errorCode: errorCode ?? "PERSIST_MESSAGE_FAILED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "SERVER_ERROR" + ) { + return { + kind: "server_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We couldn’t complete this response right now. Please try again.", + rawMessage, + errorCode: errorCode ?? "SERVER_ERROR", + details: { flow: input.flow, providerErrorType }, + }; + } + + return { + kind: "unknown", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We couldn’t complete this response right now. Please try again.", + rawMessage, + errorCode, + details: { flow: input.flow, providerErrorType }, + }; +} diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 54faf7e7c..445bbe83d 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -546,7 +546,7 @@ export type SSEEvent = }>; }; } - | { type: "error"; errorText: string }; + | { type: "error"; errorText: string; errorCode?: string }; /** * Async generator that reads an SSE stream and yields parsed JSON objects. diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 34ed3044d..30e58215a 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -1,5 +1,6 @@ import posthog from "posthog-js"; import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants"; +import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier"; /** * PostHog Analytics Event Definitions @@ -139,6 +140,55 @@ export function trackChatError(searchSpaceId: number, chatId: number, error?: st }); } +export interface ChatFailureTelemetry { + flow: ChatFlow; + kind: ChatErrorKind; + error_code?: string; + severity: ChatErrorSeverity; + is_expected: boolean; + message?: string; +} + +export function trackChatBlocked( + searchSpaceId: number, + chatId: number | null, + payload: ChatFailureTelemetry +) { + safeCapture( + "chat_blocked", + compact({ + search_space_id: searchSpaceId, + chat_id: chatId ?? undefined, + flow: payload.flow, + kind: payload.kind, + error_code: payload.error_code, + severity: payload.severity, + is_expected: payload.is_expected, + message: payload.message, + }) + ); +} + +export function trackChatErrorDetailed( + searchSpaceId: number, + chatId: number | null, + payload: ChatFailureTelemetry +) { + safeCapture( + "chat_error", + compact({ + search_space_id: searchSpaceId, + chat_id: chatId ?? undefined, + flow: payload.flow, + kind: payload.kind, + error_code: payload.error_code, + severity: payload.severity, + is_expected: payload.is_expected, + message: payload.message, + }) + ); +} + /** * Track a message sent from the unauthenticated "free" / anonymous chat * flow. This is intentionally a separate event from `chat_message_sent`