Merge pull request #1325 from AnishSarkar22/feat/split-auto-free-premium

feat(chat): thread-level auto model pinning & structured chat errors
This commit is contained in:
Rohan Verma 2026-04-30 15:18:26 -07:00 committed by GitHub
commit f129207c50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2501 additions and 353 deletions

View file

@ -0,0 +1,65 @@
"""138_add_thread_auto_model_pinning_fields
Revision ID: 138
Revises: 137
Create Date: 2026-04-30
Add thread-level fields to persist Auto (Fastest) model pinning metadata:
- pinned_llm_config_id: concrete resolved config id used for this thread
- pinned_auto_mode: auto policy identifier (currently "auto_fastest")
- pinned_at: timestamp when the pin was created/refreshed
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
revision: str = "138"
down_revision: str | None = "137"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add the auto-pin columns to ``new_chat_threads`` and index two of them.

    Three nullable columns are added (``pinned_llm_config_id``,
    ``pinned_auto_mode``, ``pinned_at``), followed by single-column indexes on
    ``pinned_llm_config_id`` and ``pinned_auto_mode``. Every statement carries
    ``IF NOT EXISTS``.
    """
    ddl_statements = (
        # Columns first: the indexes below reference them.
        (
            "ALTER TABLE new_chat_threads "
            "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
        ),
        (
            "ALTER TABLE new_chat_threads "
            "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)"
        ),
        (
            "ALTER TABLE new_chat_threads "
            "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE"
        ),
        (
            "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id "
            "ON new_chat_threads (pinned_llm_config_id)"
        ),
        (
            "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode "
            "ON new_chat_threads (pinned_auto_mode)"
        ),
    )
    for ddl in ddl_statements:
        op.execute(ddl)
def downgrade() -> None:
    """Remove the auto-pin indexes and columns added by ``upgrade``.

    Indexes are dropped before the columns they cover, and the columns are
    dropped in the reverse of the order in which ``upgrade`` adds them. Every
    statement carries ``IF EXISTS``.
    """
    ddl_statements = (
        "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode",
        "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id",
    )
    for ddl in ddl_statements:
        op.execute(ddl)

View file

@ -638,6 +638,13 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False, default=False,
server_default="false", server_default="false",
) )
# Auto model pinning metadata:
# - pinned_llm_config_id stores the concrete resolved model config id.
# - pinned_auto_mode indicates which auto policy produced the pin.
# This allows Auto (Fastest) to resolve once per thread and stay stable.
pinned_llm_config_id = Column(Integer, nullable=True, index=True)
pinned_auto_mode = Column(String(32), nullable=True, index=True)
pinned_at = Column(TIMESTAMP(timezone=True), nullable=True)
# Relationships # Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads") search_space = relationship("SearchSpace", back_populates="new_chat_threads")

View file

@ -1924,6 +1924,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"), request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None, user_image_data_urls=regenerate_image_urls or None,
flow="regenerate",
): ):
yield chunk yield chunk
streaming_completed = True streaming_completed = True

View file

@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from pydantic import BaseModel as PydanticBaseModel from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy import func from sqlalchemy import func, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
from app.config import config from app.config import config
from app.db import ( from app.db import (
ImageGenerationConfig, ImageGenerationConfig,
NewChatThread,
NewLLMConfig, NewLLMConfig,
Permission, Permission,
SearchSpace, SearchSpace,
@ -790,9 +791,31 @@ async def update_llm_preferences(
# Update preferences # Update preferences
update_data = preferences.model_dump(exclude_unset=True) update_data = preferences.model_dump(exclude_unset=True)
previous_agent_llm_id = search_space.agent_llm_id
for key, value in update_data.items(): for key, value in update_data.items():
setattr(search_space, key, value) setattr(search_space, key, value)
agent_llm_changed = (
"agent_llm_id" in update_data
and update_data["agent_llm_id"] != previous_agent_llm_id
)
if agent_llm_changed:
await session.execute(
update(NewChatThread)
.where(NewChatThread.search_space_id == search_space_id)
.values(
pinned_llm_config_id=None,
pinned_auto_mode=None,
pinned_at=None,
)
)
logger.info(
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
search_space_id,
previous_agent_llm_id,
update_data["agent_llm_id"],
)
await session.commit() await session.commit()
await session.refresh(search_space) await session.refresh(search_space)

View file

@ -0,0 +1,218 @@
"""Resolve and persist Auto (Fastest) model pins per chat thread.
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
resolve that virtual mode to one concrete global LLM config and persist the
chosen config id on ``new_chat_threads`` so subsequent turns are stable. The
pin is re-resolved only when it is missing or no longer valid, or when the
caller forces a re-pin to a non-premium config.
"""
from __future__ import annotations
import hashlib
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import NewChatThread
from app.services.token_quota_service import TokenQuotaService
logger = logging.getLogger(__name__)
AUTO_FASTEST_ID = 0
AUTO_FASTEST_MODE = "auto_fastest"
@dataclass
class AutoPinResolution:
    """Result returned by ``resolve_or_get_pinned_llm_config_id``."""

    # Concrete LLM config id the caller should load for this turn.
    resolved_llm_config_id: int
    # "explicit" when a non-auto model was selected; otherwise the billing
    # tier of the config chosen for (or already pinned to) the thread.
    resolved_tier: str
    # True only when a previously stored pin was reused unchanged.
    from_existing_pin: bool
def _is_usable_global_config(cfg: dict) -> bool:
return bool(
cfg.get("id") is not None
and cfg.get("model_name")
and cfg.get("provider")
and cfg.get("api_key")
)
def _global_candidates() -> list[dict]:
    """Return the usable entries of ``config.GLOBAL_LLM_CONFIGS`` ordered by id.

    Sorting by the integer value of ``id`` gives the list a stable order, so
    an index computed into it refers to the same config on every call as long
    as the set of usable configs is unchanged.
    """
    usable = list(filter(_is_usable_global_config, config.GLOBAL_LLM_CONFIGS))
    usable.sort(key=lambda entry: int(entry.get("id", 0)))
    return usable
def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
idx = int.from_bytes(digest[:8], "big") % len(candidates)
return candidates[idx]
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
if user_id is None:
return None
if isinstance(user_id, UUID):
return user_id
try:
return UUID(str(user_id))
except Exception:
return None
async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool:
    """Report whether the user may currently be given a premium config.

    A user id that cannot be converted to a ``UUID`` is never eligible.
    Otherwise the answer is the ``allowed`` flag of
    ``TokenQuotaService.premium_get_usage`` for that user.
    """
    user_uuid = _to_uuid(user_id)
    if user_uuid is not None:
        quota_usage = await TokenQuotaService.premium_get_usage(session, user_uuid)
        return bool(quota_usage.allowed)
    return False
async def resolve_or_get_pinned_llm_config_id(
    session: AsyncSession,
    *,
    thread_id: int,
    search_space_id: int,
    user_id: str | UUID | None,
    selected_llm_config_id: int,
    force_repin_free: bool = False,
) -> AutoPinResolution:
    """Resolve Auto (Fastest) to one concrete config id and persist pin metadata.

    The thread row is loaded with ``SELECT ... FOR UPDATE``. Every successful
    return path commits, so the row lock taken here is not carried into the
    caller's subsequent work on the same session.

    For non-auto selections, this function clears existing auto pin metadata and
    returns the selected id as-is.

    Args:
        session: session used to load, lock and update the thread row. It is
            committed before a successful return.
        thread_id: id of the ``new_chat_threads`` row.
        search_space_id: search space the thread is expected to belong to.
        user_id: user whose premium quota decides whether premium configs are
            eligible for a new pin.
        selected_llm_config_id: the requested config id; ``AUTO_FASTEST_ID``
            selects Auto mode.
        force_repin_free: when true, ignore any existing pin and the user's
            quota, and pin a non-premium config.

    Returns:
        The resolved config id, its tier, and whether an existing pin was reused.

    Raises:
        ValueError: the thread does not exist, belongs to another search space,
            no usable global config exists, or no config is eligible.
    """
    thread = (
        (
            await session.execute(
                select(NewChatThread)
                .where(NewChatThread.id == thread_id)
                .with_for_update(of=NewChatThread)
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if thread is None:
        raise ValueError(f"Thread {thread_id} not found")
    if thread.search_space_id != search_space_id:
        raise ValueError(
            f"Thread {thread_id} does not belong to search space {search_space_id}"
        )

    # Explicit model selected: clear stale auto pin metadata.
    if selected_llm_config_id != AUTO_FASTEST_ID:
        if (
            thread.pinned_llm_config_id is not None
            or thread.pinned_auto_mode is not None
            or thread.pinned_at is not None
        ):
            thread.pinned_llm_config_id = None
            thread.pinned_auto_mode = None
            thread.pinned_at = None
        # Commit even when nothing was cleared: this ends the transaction and
        # releases the FOR UPDATE lock taken above. The pin-writing paths
        # already commit, so callers already tolerate a commit here.
        await session.commit()
        return AutoPinResolution(
            resolved_llm_config_id=selected_llm_config_id,
            resolved_tier="explicit",
            from_existing_pin=False,
        )

    candidates = _global_candidates()
    if not candidates:
        raise ValueError("No usable global LLM configs are available for Auto mode")
    candidate_by_id = {int(c["id"]): c for c in candidates}

    # Read pin state into locals before any commit so later logging does not
    # depend on the ORM object's post-commit state.
    pinned_id = thread.pinned_llm_config_id
    pinned_mode = thread.pinned_auto_mode
    pin_is_valid = (
        pinned_mode == AUTO_FASTEST_MODE
        and pinned_id is not None
        and int(pinned_id) in candidate_by_id
    )

    # Reuse existing valid pin without re-checking current quota (no silent tier switch),
    # unless the caller explicitly requests a forced repin to free.
    if pin_is_valid and not force_repin_free:
        pinned_cfg = candidate_by_id[int(pinned_id)]
        pinned_tier = _tier_of(pinned_cfg)
        # Nothing changed, but commit to release the row lock before returning.
        await session.commit()
        logger.info(
            "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
            thread_id,
            search_space_id,
            pinned_id,
            pinned_tier,
        )
        return AutoPinResolution(
            resolved_llm_config_id=int(pinned_id),
            resolved_tier=pinned_tier,
            from_existing_pin=True,
        )

    # Only report the stored pin as invalid when it really is; a forced free
    # re-pin of a valid pin is reported by auto_pin_forced_free_repin below.
    if pinned_id is not None and not pin_is_valid:
        logger.info(
            "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
            thread_id,
            search_space_id,
            pinned_id,
            pinned_mode,
        )

    premium_eligible = (
        False if force_repin_free else await _is_premium_eligible(session, user_id)
    )
    if premium_eligible:
        eligible = candidates
    else:
        eligible = [c for c in candidates if _tier_of(c) != "premium"]

    if not eligible:
        raise ValueError(
            "Auto mode could not find an eligible LLM config for this user and quota state"
        )

    selected_cfg = _deterministic_pick(eligible, thread_id)
    selected_id = int(selected_cfg["id"])
    selected_tier = _tier_of(selected_cfg)

    thread.pinned_llm_config_id = selected_id
    thread.pinned_auto_mode = AUTO_FASTEST_MODE
    thread.pinned_at = datetime.now(UTC)
    await session.commit()

    if force_repin_free:
        logger.info(
            "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
            thread_id,
            search_space_id,
            pinned_id,
            selected_id,
        )
    if pinned_id is None:
        logger.info(
            "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
            thread_id,
            search_space_id,
            selected_id,
            selected_tier,
            premium_eligible,
        )
    else:
        logger.info(
            "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
            thread_id,
            search_space_id,
            pinned_id,
            selected_id,
            selected_tier,
            premium_eligible,
        )
    return AutoPinResolution(
        resolved_llm_config_id=selected_id,
        resolved_tier=selected_tier,
        from_existing_pin=False,
    )

View file

@ -565,20 +565,24 @@ class VercelStreamingService:
# Error Part # Error Part
# ========================================================================= # =========================================================================
def format_error(self, error_text: str) -> str: def format_error(self, error_text: str, error_code: str | None = None) -> str:
""" """
Format an error message. Format an error message.
Args: Args:
error_text: The error message text error_text: The error message text
error_code: Optional machine-readable error code for frontend branching
Returns: Returns:
str: SSE formatted error part str: SSE formatted error part
Example output: Example output:
data: {"type":"error","errorText":"Something went wrong"} data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
""" """
return self._format_sse({"type": "error", "errorText": error_text}) payload: dict[str, str] = {"type": "error", "errorText": error_text}
if error_code:
payload["errorCode"] = error_code
return self._format_sse(payload)
# ========================================================================= # =========================================================================
# Tool Parts # Tool Parts

View file

@ -19,7 +19,8 @@ import re
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from functools import partial
from typing import Any, Literal
from uuid import UUID from uuid import UUID
import anyio import anyio
@ -30,6 +31,7 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import ( from app.agents.new_chat.llm_config import (
@ -57,6 +59,7 @@ from app.db import (
shielded_async_session, shielded_async_session,
) )
from app.prompts import TITLE_GENERATION_PROMPT from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
from app.services.chat_session_state_service import ( from app.services.chat_session_state_service import (
clear_ai_responding, clear_ai_responding,
set_ai_responding, set_ai_responding,
@ -338,6 +341,138 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
) )
def _log_chat_stream_error(
*,
flow: Literal["new", "resume", "regenerate"],
error_kind: str,
error_code: str | None,
severity: Literal["info", "warn", "error"],
is_expected: bool,
request_id: str | None,
thread_id: int | None,
search_space_id: int | None,
user_id: str | None,
message: str,
extra: dict[str, Any] | None = None,
) -> None:
payload: dict[str, Any] = {
"event": "chat_stream_error",
"flow": flow,
"error_kind": error_kind,
"error_code": error_code,
"severity": severity,
"is_expected": is_expected,
"request_id": request_id or "unknown",
"thread_id": thread_id,
"search_space_id": search_space_id,
"user_id": user_id,
"message": message,
}
if extra:
payload.update(extra)
logger = logging.getLogger(__name__)
rendered = json.dumps(payload, ensure_ascii=False)
if severity == "error":
logger.error("[chat_stream_error] %s", rendered)
elif severity == "warn":
logger.warning("[chat_stream_error] %s", rendered)
else:
logger.info("[chat_stream_error] %s", rendered)
def _parse_error_payload(message: str) -> dict[str, Any] | None:
candidates = [message]
first_brace_idx = message.find("{")
if first_brace_idx >= 0:
candidates.append(message[first_brace_idx:])
for candidate in candidates:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
return parsed
except Exception:
continue
return None
def _classify_stream_exception(
    exc: Exception,
    *,
    flow_label: str,
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
    """Map a streaming exception to ``(kind, code, severity, is_expected, user_message)``.

    Classification order:

    1. A ``BusyError``, or any exception whose text contains
       ``"Thread is busy with another request"``, is ``thread_busy``.
    2. If the exception text carries a JSON object whose ``type`` (or nested
       ``error.type``, which takes precedence) is ``rate_limit_error``, the
       result is ``rate_limited``.
    3. Everything else is ``server_error``; its user message embeds
       ``flow_label`` and the raw exception text.
    """
    raw = str(exc)

    is_busy = isinstance(exc, BusyError) or "Thread is busy with another request" in raw
    if is_busy:
        return (
            "thread_busy",
            "THREAD_BUSY",
            "warn",
            True,
            "Another response is still finishing for this thread. Please try again in a moment.",
        )

    provider_error_type = ""
    payload = _parse_error_payload(raw)
    if payload:
        # Visit the top-level object first and the nested "error" object
        # second, so a nested string type overrides the top-level one.
        for holder in (payload, payload.get("error")):
            if not isinstance(holder, dict):
                continue
            type_value = holder.get("type")
            if isinstance(type_value, str):
                provider_error_type = type_value.lower()

    if provider_error_type == "rate_limit_error":
        return (
            "rate_limited",
            "RATE_LIMITED",
            "warn",
            True,
            "This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
        )

    return (
        "server_error",
        "SERVER_ERROR",
        "error",
        False,
        f"Error during {flow_label}: {raw}",
    )
def _emit_stream_terminal_error(
    *,
    streaming_service: VercelStreamingService,
    flow: Literal["new", "resume", "regenerate"],
    request_id: str | None,
    thread_id: int,
    search_space_id: int,
    user_id: str | None,
    message: str,
    error_kind: str = "server_error",
    error_code: str = "SERVER_ERROR",
    severity: Literal["info", "warn", "error"] = "error",
    is_expected: bool = False,
    extra: dict[str, Any] | None = None,
) -> str:
    """Log a stream-ending error and return the matching SSE error part.

    The error is first recorded through ``_log_chat_stream_error`` with all
    the context passed in, then rendered for the client with
    ``streaming_service.format_error`` using the same ``message`` and
    ``error_code``.

    ``flow`` is typed with the same literal set that ``_log_chat_stream_error``
    accepts, since the value is forwarded to it unchanged.

    Args:
        streaming_service: formatter used to build the SSE error part.
        flow: which streaming entry point is reporting the error.
        request_id: request correlation id, if any.
        thread_id: chat thread id.
        search_space_id: search space id.
        user_id: user id, if any.
        message: text that is both logged and sent to the client.
        error_kind: classification stored in the log record.
        error_code: machine-readable code, logged and sent to the client.
        severity: log severity.
        is_expected: flag stored in the log record.
        extra: additional keys for the log record only.

    Returns:
        The string produced by ``streaming_service.format_error``.
    """
    _log_chat_stream_error(
        flow=flow,
        error_kind=error_kind,
        error_code=error_code,
        severity=severity,
        is_expected=is_expected,
        request_id=request_id,
        thread_id=thread_id,
        search_space_id=search_space_id,
        user_id=user_id,
        message=message,
        extra=extra,
    )
    return streaming_service.format_error(message, error_code=error_code)
def _legacy_match_lc_id( def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]], pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str, tool_name: str,
@ -1913,6 +2048,7 @@ async def stream_new_chat(
filesystem_selection: FilesystemSelection | None = None, filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None, request_id: str | None = None,
user_image_data_urls: list[str] | None = None, user_image_data_urls: list[str] | None = None,
flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
Stream chat responses from the new SurfSense deep agent. Stream chat responses from the new SurfSense deep agent.
@ -1964,6 +2100,16 @@ async def stream_new_chat(
_premium_reserved = 0 _premium_reserved = 0
_premium_request_id: str | None = None _premium_request_id: str | None = None
_emit_stream_error = partial(
_emit_stream_terminal_error,
streaming_service=streaming_service,
flow=flow,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
)
session = async_session_maker() session = async_session_maker()
try: try:
# Mark AI as responding to this user for live collaboration # Mark AI as responding to this user for live collaboration
@ -1971,49 +2117,78 @@ async def stream_new_chat(
await set_ai_responding(session, chat_id, UUID(user_id)) await set_ai_responding(session, chat_id, UUID(user_id))
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs) # Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None agent_config: AgentConfig | None = None
requested_llm_config_id = llm_config_id
async def _load_llm_bundle(
config_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
if config_id >= 0:
loaded_agent_config = await load_agent_config(
session=session,
config_id=config_id,
search_space_id=search_space_id,
)
if not loaded_agent_config:
return (
None,
None,
f"Failed to load NewLLMConfig with id {config_id}",
)
return (
create_chat_litellm_from_agent_config(loaded_agent_config),
loaded_agent_config,
None,
)
loaded_llm_config = load_global_llm_config_by_id(config_id)
if not loaded_llm_config:
return None, None, f"Failed to load LLM config with id {config_id}"
return (
create_chat_litellm_from_config(loaded_llm_config),
AgentConfig.from_yaml_config(loaded_llm_config),
None,
)
_t0 = time.perf_counter() _t0 = time.perf_counter()
if llm_config_id >= 0: try:
# Positive ID: Load from NewLLMConfig database table llm_config_id = (
agent_config = await load_agent_config( await resolve_or_get_pinned_llm_config_id(
session=session, session,
config_id=llm_config_id, thread_id=chat_id,
search_space_id=search_space_id, search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
) )
if not agent_config: yield streaming_service.format_done()
yield streaming_service.format_error( return
f"Failed to load NewLLMConfig with id {llm_config_id}"
)
yield streaming_service.format_done()
return
# Create ChatLiteLLM from AgentConfig llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
llm = create_chat_litellm_from_agent_config(agent_config) if llm_load_error:
else: yield _emit_stream_error(
# Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models) message=llm_load_error,
llm_config = load_global_llm_config_by_id(llm_config_id) error_kind="server_error",
if not llm_config: error_code="SERVER_ERROR",
yield streaming_service.format_error( )
f"Failed to load LLM config with id {llm_config_id}" yield streaming_service.format_done()
) return
yield streaming_service.format_done()
return
# Create ChatLiteLLM from global config dict
llm = create_chat_litellm_from_config(llm_config)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info( _perf_log.info(
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
llm_config_id, llm_config_id,
) )
# Premium quota reservation — applies to explicitly premium configs # Premium quota reservation for pinned premium model only.
# AND Auto mode (which may route to premium models).
_needs_premium_quota = ( _needs_premium_quota = (
agent_config is not None agent_config is not None
and user_id and user_id
and (agent_config.is_premium or agent_config.is_auto_mode) and agent_config.is_premium
) )
if _needs_premium_quota: if _needs_premium_quota:
import uuid as _uuid import uuid as _uuid
@ -2036,19 +2211,79 @@ async def stream_new_chat(
) )
_premium_reserved = reserve_amount _premium_reserved = reserve_amount
if not quota_result.allowed: if not quota_result.allowed:
if agent_config.is_premium: if requested_llm_config_id == 0:
yield streaming_service.format_error( try:
"Premium token quota exceeded. Please purchase more tokens to continue using premium models." llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
force_repin_free=True,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
if llm_load_error:
yield _emit_stream_error(
message=llm_load_error,
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
_premium_request_id = None
_premium_reserved = 0
_log_chat_stream_error(
flow=flow,
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Premium quota exhausted on pinned model; auto-fallback switched to a free model"
),
extra={
"fallback_config_id": llm_config_id,
"auto_fallback": True,
},
)
else:
yield _emit_stream_error(
message=(
"Buy more tokens to continue with this model, or switch to a free model"
),
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
extra={
"resolved_config_id": llm_config_id,
"auto_fallback": False,
},
) )
yield streaming_service.format_done() yield streaming_service.format_done()
return return
# Auto mode: quota exhausted but we can still proceed
# (the router may pick a free model). Reset reservation.
_premium_request_id = None
_premium_reserved = 0
if not llm: if not llm:
yield streaming_service.format_error("Failed to create LLM instance") yield _emit_stream_error(
message="Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done() yield streaming_service.format_done()
return return
@ -2499,28 +2734,20 @@ async def stream_new_chat(
) )
# Finalize premium quota with actual tokens. # Finalize premium quota with actual tokens.
# For Auto mode, only count tokens from calls that used premium models.
if _premium_request_id and user_id: if _premium_request_id and user_id:
try: try:
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
if agent_config and agent_config.is_auto_mode:
from app.services.llm_router_service import LLMRouterService
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
accumulator.calls
)
else:
actual_premium_tokens = accumulator.grand_total
async with shielded_async_session() as quota_session: async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize( await TokenQuotaService.premium_finalize(
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
request_id=_premium_request_id, request_id=_premium_request_id,
actual_tokens=actual_premium_tokens, actual_tokens=accumulator.grand_total,
reserved_tokens=_premium_reserved, reserved_tokens=_premium_reserved,
) )
_premium_request_id = None
_premium_reserved = 0
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s", "Failed to finalize premium quota for user %s",
@ -2586,12 +2813,25 @@ async def stream_new_chat(
# Handle any errors # Handle any errors
import traceback import traceback
(
error_kind,
error_code,
severity,
is_expected,
user_message,
) = _classify_stream_exception(e, flow_label="chat")
error_message = f"Error during chat: {e!s}" error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Exception type: {type(e).__name__}")
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
yield streaming_service.format_error(error_message) yield _emit_stream_error(
message=user_message,
error_kind=error_kind,
error_code=error_code,
severity=severity,
is_expected=is_expected,
)
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()
yield streaming_service.format_done() yield streaming_service.format_done()
@ -2706,36 +2946,83 @@ async def stream_resume_chat(
accumulator = start_turn() accumulator = start_turn()
_emit_stream_error = partial(
_emit_stream_terminal_error,
streaming_service=streaming_service,
flow="resume",
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
)
session = async_session_maker() session = async_session_maker()
try: try:
if user_id: if user_id:
await set_ai_responding(session, chat_id, UUID(user_id)) await set_ai_responding(session, chat_id, UUID(user_id))
agent_config: AgentConfig | None = None agent_config: AgentConfig | None = None
_t0 = time.perf_counter() requested_llm_config_id = llm_config_id
if llm_config_id >= 0:
agent_config = await load_agent_config( async def _load_llm_bundle(
session=session, config_id: int,
config_id=llm_config_id, ) -> tuple[Any, AgentConfig | None, str | None]:
search_space_id=search_space_id, if config_id >= 0:
loaded_agent_config = await load_agent_config(
session=session,
config_id=config_id,
search_space_id=search_space_id,
)
if not loaded_agent_config:
return (
None,
None,
f"Failed to load NewLLMConfig with id {config_id}",
)
return (
create_chat_litellm_from_agent_config(loaded_agent_config),
loaded_agent_config,
None,
)
loaded_llm_config = load_global_llm_config_by_id(config_id)
if not loaded_llm_config:
return None, None, f"Failed to load LLM config with id {config_id}"
return (
create_chat_litellm_from_config(loaded_llm_config),
AgentConfig.from_yaml_config(loaded_llm_config),
None,
) )
if not agent_config:
yield streaming_service.format_error( _t0 = time.perf_counter()
f"Failed to load NewLLMConfig with id {llm_config_id}" try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
) )
yield streaming_service.format_done() ).resolved_llm_config_id
return except ValueError as pin_error:
llm = create_chat_litellm_from_agent_config(agent_config) yield _emit_stream_error(
else: message=str(pin_error),
llm_config = load_global_llm_config_by_id(llm_config_id) error_kind="server_error",
if not llm_config: error_code="SERVER_ERROR",
yield streaming_service.format_error( )
f"Failed to load LLM config with id {llm_config_id}" yield streaming_service.format_done()
) return
yield streaming_service.format_done()
return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
llm = create_chat_litellm_from_config(llm_config) if llm_load_error:
agent_config = AgentConfig.from_yaml_config(llm_config) yield _emit_stream_error(
message=llm_load_error,
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
_perf_log.info( _perf_log.info(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
) )
@ -2746,7 +3033,7 @@ async def stream_resume_chat(
_resume_needs_premium = ( _resume_needs_premium = (
agent_config is not None agent_config is not None
and user_id and user_id
and (agent_config.is_premium or agent_config.is_auto_mode) and agent_config.is_premium
) )
if _resume_needs_premium: if _resume_needs_premium:
import uuid as _uuid import uuid as _uuid
@ -2769,17 +3056,79 @@ async def stream_resume_chat(
) )
_resume_premium_reserved = reserve_amount _resume_premium_reserved = reserve_amount
if not quota_result.allowed: if not quota_result.allowed:
if agent_config.is_premium: if requested_llm_config_id == 0:
yield streaming_service.format_error( try:
"Premium token quota exceeded. Please purchase more tokens to continue using premium models." llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
force_repin_free=True,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
if llm_load_error:
yield _emit_stream_error(
message=llm_load_error,
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
_resume_premium_request_id = None
_resume_premium_reserved = 0
_log_chat_stream_error(
flow="resume",
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Premium quota exhausted on pinned model; auto-fallback switched to a free model"
),
extra={
"fallback_config_id": llm_config_id,
"auto_fallback": True,
},
)
else:
yield _emit_stream_error(
message=(
"Buy more tokens to continue with this model, or switch to a free model"
),
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
extra={
"resolved_config_id": llm_config_id,
"auto_fallback": False,
},
) )
yield streaming_service.format_done() yield streaming_service.format_done()
return return
_resume_premium_request_id = None
_resume_premium_reserved = 0
if not llm: if not llm:
yield streaming_service.format_error("Failed to create LLM instance") yield _emit_stream_error(
message="Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done() yield streaming_service.format_done()
return return
@ -2920,23 +3269,16 @@ async def stream_resume_chat(
try: try:
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
if agent_config and agent_config.is_auto_mode:
from app.services.llm_router_service import LLMRouterService
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
accumulator.calls
)
else:
actual_premium_tokens = accumulator.grand_total
async with shielded_async_session() as quota_session: async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize( await TokenQuotaService.premium_finalize(
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
request_id=_resume_premium_request_id, request_id=_resume_premium_request_id,
actual_tokens=actual_premium_tokens, actual_tokens=accumulator.grand_total,
reserved_tokens=_resume_premium_reserved, reserved_tokens=_resume_premium_reserved,
) )
_resume_premium_request_id = None
_resume_premium_reserved = 0
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)", "Failed to finalize premium quota for user %s (resume)",
@ -2970,10 +3312,23 @@ async def stream_resume_chat(
except Exception as e: except Exception as e:
import traceback import traceback
(
error_kind,
error_code,
severity,
is_expected,
user_message,
) = _classify_stream_exception(e, flow_label="resume")
error_message = f"Error during resume: {e!s}" error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
yield streaming_service.format_error(error_message) yield _emit_stream_error(
message=user_message,
error_kind=error_kind,
error_code=error_code,
severity=severity,
is_expected=is_expected,
)
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()
yield streaming_service.format_done() yield streaming_service.format_done()

View file

@ -0,0 +1,329 @@
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
import pytest
from app.services.auto_model_pin_service import (
AUTO_FASTEST_MODE,
resolve_or_get_pinned_llm_config_id,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeQuotaResult:
allowed: bool
class _FakeExecResult:
def __init__(self, thread):
self._thread = thread
def unique(self):
return self
def scalar_one_or_none(self):
return self._thread
class _FakeSession:
def __init__(self, thread):
self.thread = thread
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
async def commit(self):
self.commit_count += 1
def _thread(
*,
search_space_id: int = 10,
pinned_llm_config_id: int | None = None,
pinned_auto_mode: str | None = None,
):
return SimpleNamespace(
id=1,
search_space_id=search_space_id,
pinned_llm_config_id=pinned_llm_config_id,
pinned_auto_mode=pinned_auto_mode,
pinned_at=None,
)
@pytest.mark.asyncio
async def test_auto_first_turn_pins_one_model(monkeypatch):
    """Auto mode on an unpinned thread picks one config and stores the pin with one commit."""
    from app.config import config

    free_cfg = {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id in {-1, -2}
    pinned_thread = fake_session.thread
    assert pinned_thread.pinned_llm_config_id == outcome.resolved_llm_config_id
    assert pinned_thread.pinned_auto_mode == AUTO_FASTEST_MODE
    assert pinned_thread.pinned_at is not None
    assert fake_session.commit_count == 1
@pytest.mark.asyncio
async def test_next_turn_reuses_existing_pin(monkeypatch):
    """A valid existing pin is returned as-is: no quota lookup and no commit."""
    from app.config import config

    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    already_pinned = _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
    fake_session = _FakeSession(already_pinned)
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [premium_cfg])

    async def _fail_if_called(*_args, **_kwargs):
        raise AssertionError("premium_get_usage should not be called for valid pin reuse")

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _fail_if_called,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == -1
    assert outcome.from_existing_pin is True
    assert fake_session.commit_count == 0
@pytest.mark.asyncio
async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
    """With quota allowed and only a premium config available, the premium config is pinned."""
    from app.config import config

    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [premium_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == -1
    assert outcome.resolved_tier == "premium"
@pytest.mark.asyncio
async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
    """When the quota check disallows premium, auto mode resolves to the free config."""
    from app.config import config

    free_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == -2
    assert outcome.resolved_tier == "free"
@pytest.mark.asyncio
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
    """An existing premium pin is kept even when the quota stub reports not allowed."""
    from app.config import config

    free_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    already_pinned = _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
    fake_session = _FakeSession(already_pinned)
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == -1
    assert outcome.from_existing_pin is True
@pytest.mark.asyncio
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
    """``force_repin_free=True`` replaces an existing premium pin with the free config."""
    from app.config import config

    free_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    already_pinned = _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
    fake_session = _FakeSession(already_pinned)
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
        "force_repin_free": True,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == -2
    assert outcome.resolved_tier == "free"
    assert outcome.from_existing_pin is False
    assert fake_session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_explicit_user_model_change_clears_pin(monkeypatch):
    """Selecting a concrete config id (7) returns that id and wipes all pin fields."""
    from app.config import config

    free_cfg = {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}
    already_pinned = _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE)
    fake_session = _FakeSession(already_pinned)
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg])

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 7,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == 7
    cleared_thread = fake_session.thread
    assert cleared_thread.pinned_llm_config_id is None
    assert cleared_thread.pinned_auto_mode is None
    assert cleared_thread.pinned_at is None
    assert fake_session.commit_count == 1
@pytest.mark.asyncio
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
    """A pin pointing at an id missing from the config list (-999) is replaced and committed."""
    from app.config import config

    free_cfg = {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}
    stale_pin = _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE)
    fake_session = _FakeSession(stale_pin)
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    call_kwargs = {
        "thread_id": 1,
        "search_space_id": 10,
        "user_id": "00000000-0000-0000-0000-000000000001",
        "selected_llm_config_id": 0,
    }
    outcome = await resolve_or_get_pinned_llm_config_id(fake_session, **call_kwargs)

    assert outcome.resolved_llm_config_id == -2
    assert fake_session.thread.pinned_llm_config_id == -2
    assert fake_session.commit_count == 1

View file

@ -1,9 +1,19 @@
import inspect
import json
import logging
import re
from pathlib import Path
import pytest import pytest
import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.new_chat.errors import BusyError
from app.tasks.chat.stream_new_chat import ( from app.tasks.chat.stream_new_chat import (
StreamResult, StreamResult,
_classify_stream_exception,
_contract_enforcement_active, _contract_enforcement_active,
_evaluate_file_contract_outcome, _evaluate_file_contract_outcome,
_log_chat_stream_error,
_tool_output_has_error, _tool_output_has_error,
) )
@ -45,3 +55,217 @@ def test_contract_enforcement_local_only():
result.filesystem_mode = "cloud" result.filesystem_mode = "cloud"
assert not _contract_enforcement_active(result) assert not _contract_enforcement_active(result)
def _extract_chat_stream_payload(record_message: str) -> dict:
prefix = "[chat_stream_error] "
assert record_message.startswith(prefix)
return json.loads(record_message[len(prefix) :])
def test_unified_chat_stream_error_log_schema(caplog):
    """The structured chat-stream error log carries every required key with the expected values."""
    log_fields = {
        "flow": "new",
        "error_kind": "server_error",
        "error_code": "SERVER_ERROR",
        "severity": "warn",
        "is_expected": False,
        "request_id": "req-123",
        "thread_id": 101,
        "search_space_id": 202,
        "user_id": "user-1",
        "message": "Error during chat: boom",
    }
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(**log_fields)

    record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
    payload = _extract_chat_stream_payload(record.message)

    required_keys = (
        "event",
        "flow",
        "error_kind",
        "error_code",
        "severity",
        "is_expected",
        "request_id",
        "thread_id",
        "search_space_id",
        "user_id",
        "message",
    )
    assert all(key in payload.keys() for key in required_keys)
    assert payload["event"] == "chat_stream_error"
    assert payload["flow"] == "new"
    assert payload["error_code"] == "SERVER_ERROR"
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
    """Premium-quota logs share the common payload shape and surface ``extra`` keys at top level."""
    log_fields = {
        "flow": "resume",
        "error_kind": "premium_quota_exhausted",
        "error_code": "PREMIUM_QUOTA_EXHAUSTED",
        "severity": "info",
        "is_expected": True,
        "request_id": "req-premium",
        "thread_id": 303,
        "search_space_id": 404,
        "user_id": "user-2",
        "message": "Buy more tokens to continue with this model, or switch to a free model",
        "extra": {"auto_fallback": False},
    }
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(**log_fields)

    record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
    payload = _extract_chat_stream_payload(record.message)

    assert payload["event"] == "chat_stream_error"
    assert payload["error_kind"] == "premium_quota_exhausted"
    assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
    assert payload["flow"] == "resume"
    assert payload["is_expected"] is True
    assert payload["auto_fallback"] is False
def test_stream_error_emission_keeps_machine_error_codes():
    """Source-level check: one ``format_error(`` call site and the expected codes/flow markers."""
    module_source = inspect.getsource(stream_new_chat_module)

    call_sites = re.findall(r"format_error\(", module_source)
    codes_in_source = set(re.findall(r'error_code="([A-Z_]+)"', module_source))

    # All stream paths should route through one shared terminal error emitter.
    assert len(call_sites) == 1
    expected_codes = {"PREMIUM_QUOTA_EXHAUSTED", "SERVER_ERROR"}
    assert expected_codes.issubset(codes_in_source)

    for snippet in (
        'flow: Literal["new", "regenerate"] = "new"',
        "_emit_stream_terminal_error",
        "flow=flow",
        'flow="resume"',
    ):
        assert snippet in module_source
def test_stream_exception_classifies_rate_limited():
    """A provider payload with type ``rate_limit_error`` is classified as RATE_LIMITED."""
    provider_payload = (
        '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
    )
    classification = _classify_stream_exception(Exception(provider_payload), flow_label="chat")
    kind, code, severity, is_expected, user_message = classification

    assert kind == "rate_limited"
    assert code == "RATE_LIMITED"
    assert severity == "warn"
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
def test_stream_exception_classifies_thread_busy():
    """A ``BusyError`` instance is classified as THREAD_BUSY."""
    busy = BusyError(request_id="thread-123")
    classification = _classify_stream_exception(busy, flow_label="chat")
    kind, code, severity, is_expected, user_message = classification

    assert kind == "thread_busy"
    assert code == "THREAD_BUSY"
    assert severity == "warn"
    assert is_expected is True
    assert "still finishing for this thread" in user_message
def test_stream_exception_classifies_thread_busy_from_message():
    """A generic exception whose text says the thread is busy is also classified as THREAD_BUSY."""
    generic = Exception("Thread is busy with another request")
    classification = _classify_stream_exception(generic, flow_label="chat")
    kind, code, severity, is_expected, user_message = classification

    assert kind == "thread_busy"
    assert code == "THREAD_BUSY"
    assert severity == "warn"
    assert is_expected is True
    assert "still finishing for this thread" in user_message
def test_premium_classification_is_error_code_driven():
    """The web classifier source must branch on error codes, not keyword matching."""
    repo_root = Path(__file__).resolve().parents[3]
    classifier_source = (repo_root / "surfsense_web/lib/chat/chat-error-classifier.ts").read_text(
        encoding="utf-8"
    )

    for forbidden in ("PREMIUM_KEYWORDS", "RATE_LIMIT_KEYWORDS", "normalized.includes("):
        assert forbidden not in classifier_source
    assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in classifier_source
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
    """The chat page source must contain the pre-accept failure hook and rollback snippets."""
    repo_root = Path(__file__).resolve().parents[3]
    page_file = (
        repo_root
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    )
    page_source = page_file.read_text(encoding="utf-8")

    expected_snippets = (
        "onPreAcceptFailure?: () => Promise<void>;",
        "if (!accepted) {",
        "await onPreAcceptFailure?.();",
        "await onAcceptedStreamError?.();",
        "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));",
        "setMessageDocumentsMap((prev) => {",
    )
    for snippet in expected_snippets:
        assert snippet in page_source
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
    """The user-message component source must not contain the inline failed-send markers."""
    repo_root = Path(__file__).resolve().parents[3]
    component_file = repo_root / "surfsense_web/components/assistant-ui/user-message.tsx"
    component_source = component_file.read_text(encoding="utf-8")

    for forbidden in ("Not sent. Edit and retry.", "failed_pre_accept"):
        assert forbidden not in component_source
def test_network_send_failures_use_unified_retry_toast_message():
    """Source-level check of the pre-accept send-failure contract in the classifier and chat page."""
    repo_root = Path(__file__).resolve().parents[3]
    classifier_source = (repo_root / "surfsense_web/lib/chat/chat-error-classifier.ts").read_text(
        encoding="utf-8"
    )
    page_source = (
        repo_root
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    ).read_text(encoding="utf-8")

    classifier_snippets = (
        '"send_failed_pre_accept"',
        'errorCode === "SEND_FAILED_PRE_ACCEPT"',
        "if (withCode.code) return withCode.code;",
        'userMessage: "Message not sent. Please retry."',
        'userMessage: "Connection issue. Please try again."',
    )
    for snippet in classifier_snippets:
        assert snippet in classifier_source

    page_snippets = (
        "tagPreAcceptSendFailure(error)",
        "const passthroughCodes = new Set([",
        '"PREMIUM_QUOTA_EXHAUSTED"',
        '"THREAD_BUSY"',
        '"AUTH_EXPIRED"',
        '"UNAUTHORIZED"',
        '"RATE_LIMITED"',
        '"NETWORK_ERROR"',
        '"STREAM_PARSE_ERROR"',
        '"TOOL_EXECUTION_ERROR"',
        '"PERSIST_MESSAGE_FAILED"',
        '"SERVER_ERROR"',
        "passthroughCodes.has(existingCode)",
        'errorCode: "SEND_FAILED_PRE_ACCEPT"',
    )
    for snippet in page_snippets:
        assert snippet in page_source

    # Legacy strings that must no longer appear in the page.
    for forbidden in (
        'errorCode: "NETWORK_ERROR"',
        "Failed to start chat. Please try again.",
    ):
        assert forbidden not in page_source
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
    """Source-level check that each chat flow tracks its accepted boundary in the page."""
    repo_root = Path(__file__).resolve().parents[3]
    page_source = (
        repo_root
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    ).read_text(encoding="utf-8")

    expected_snippets = (
        # Each flow tracks accepted boundary and passes it into shared terminal handling.
        "let newAccepted = false;",
        "let resumeAccepted = false;",
        "let regenerateAccepted = false;",
        "accepted: newAccepted,",
        "accepted: resumeAccepted,",
        "accepted: regenerateAccepted,",
        # Pre-accept abort in resume/regenerate exits without persistence.
        "if (!resumeAccepted) return;",
        "if (!regenerateAccepted) return;",
        # New flow persists only when accepted and not already persisted.
        "if (newAccepted && !userPersisted) {",
    )
    for snippet in expected_snippets:
        assert snippet in page_source

View file

@ -0,0 +1,45 @@
import { atom } from "jotai";
/** Alert payload stored per thread; currently just the text to display. */
export type PremiumAlertState = {
message: string;
};
/** Map of thread id to its active premium alert. Starts empty. */
export const premiumAlertByThreadAtom = atom<Record<number, PremiumAlertState>>({});
/**
 * Write-only atom that registers a premium alert for a thread.
 *
 * The alert is recorded at most once per user: a "seen" flag is kept in
 * `localStorage` under a key derived from `userId` (or "anonymous"), and later
 * writes for the same user are ignored once that flag is set.
 *
 * `localStorage` access can throw (for example when storage is blocked or the
 * quota is exceeded), so reads and writes are guarded. If storage is unusable
 * the alert is still shown; it just cannot be remembered as seen.
 */
export const setPremiumAlertForThreadAtom = atom(
	null,
	(
		get,
		set,
		payload: {
			threadId: number;
			message: string;
			userId?: string | null;
		}
	) => {
		const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`;
		const inBrowser = typeof window !== "undefined";

		if (inBrowser) {
			let hasSeen = false;
			try {
				hasSeen = localStorage.getItem(storageKey) === "true";
			} catch {
				// Storage unavailable: treat as not seen so the alert is not silently lost.
			}
			if (hasSeen) return;
		}

		const current = get(premiumAlertByThreadAtom);
		set(premiumAlertByThreadAtom, {
			...current,
			[payload.threadId]: { message: payload.message },
		});

		if (inBrowser) {
			try {
				localStorage.setItem(storageKey, "true");
			} catch {
				// Best effort only: failing to persist the flag must not break the caller.
			}
		}
	}
);
/**
 * Write-only atom that removes the alert stored for `threadId`.
 * Does nothing (and does not write state) when no entry exists for that thread.
 */
export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => {
	const alerts = get(premiumAlertByThreadAtom);
	if (!(threadId in alerts)) return;

	// Copy every entry except the dismissed one into a fresh object.
	const { [threadId]: _dismissed, ...remaining } = alerts;
	set(premiumAlertByThreadAtom, remaining);
});

View file

@ -37,10 +37,13 @@ import {
toggleToolAtom, toggleToolAtom,
} from "@/atoms/agent-tools/agent-tools.atoms"; } from "@/atoms/agent-tools/agent-tools.atoms";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import { import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
mentionedDocumentsAtom, import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
} from "@/atoms/chat/mentioned-documents.atom";
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
import {
clearPremiumAlertForThreadAtom,
premiumAlertByThreadAtom,
} from "@/atoms/chat/premium-alert.atom";
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { membersAtom } from "@/atoms/members/members-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms";
@ -135,6 +138,9 @@ const ThreadContent: FC = () => {
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
> >
<ThreadScrollToBottom /> <ThreadScrollToBottom />
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<PremiumQuotaPinnedAlert />
</AuiIf>
<AuiIf condition={({ thread }) => !thread.isEmpty}> <AuiIf condition={({ thread }) => !thread.isEmpty}>
<Composer /> <Composer />
</AuiIf> </AuiIf>
@ -144,6 +150,37 @@ const ThreadContent: FC = () => {
); );
}; };
/**
 * Dismissible banner showing the premium alert stored for the current thread.
 * Renders nothing when there is no current thread id or no alert for it.
 */
const PremiumQuotaPinnedAlert: FC = () => {
	const threadState = useAtomValue(currentThreadAtom);
	const alerts = useAtomValue(premiumAlertByThreadAtom);
	const dismissAlert = useSetAtom(clearPremiumAlertForThreadAtom);

	const threadId = threadState?.id;
	const activeAlert = threadId ? alerts[threadId] : undefined;
	if (!threadId || !activeAlert) return null;

	const handleDismiss = () => dismissAlert(threadId);

	return (
		<div className="mx-0 overflow-hidden rounded-2xl border-input bg-muted px-4 py-4 text-foreground select-none">
			<div className="flex items-center gap-2">
				<AlertCircle className="size-4 shrink-0 text-muted-foreground" />
				<div className="min-w-0 flex-1">
					<p className="text-sm">{activeAlert.message}</p>
				</div>
				<button
					type="button"
					className="inline-flex size-6 items-center justify-center text-muted-foreground transition-colors hover:text-foreground"
					aria-label="Dismiss premium quota alert"
					onClick={handleDismiss}
				>
					<X className="size-4" />
				</button>
			</div>
		</div>
	);
};
const ThreadScrollToBottom: FC = () => { const ThreadScrollToBottom: FC = () => {
return ( return (
<ThreadPrimitive.ScrollToBottom asChild> <ThreadPrimitive.ScrollToBottom asChild>

View file

@ -104,7 +104,13 @@ export function AnonymousChat({ model }: AnonymousChatProps) {
setMessages((prev) => prev.filter((m) => m.id !== assistantId)); setMessages((prev) => prev.filter((m) => m.id !== assistantId));
return; return;
} }
throw new Error(`Stream error: ${response.status}`); const body = await response.text().catch(() => "");
const errorCode = response.status === 409 ? "THREAD_BUSY" : "SERVER_ERROR";
const message =
errorCode === "THREAD_BUSY"
? "A previous response is still stopping. Please try again in a moment."
: `Stream error: ${response.status}`;
throw Object.assign(new Error(body || message), { errorCode });
} }
for await (const event of readSSEStream(response)) { for await (const event of readSSEStream(response)) {
@ -115,10 +121,12 @@ export function AnonymousChat({ model }: AnonymousChatProps) {
prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + event.delta } : m)) prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + event.delta } : m))
); );
} else if (event.type === "error") { } else if (event.type === "error") {
const message =
event.errorCode === "THREAD_BUSY"
? "A previous response is still stopping. Please try again in a moment."
: event.errorText;
setMessages((prev) => setMessages((prev) =>
prev.map((m) => prev.map((m) => (m.id === assistantId ? { ...m, content: m.content || message } : m))
m.id === assistantId ? { ...m, content: m.content || event.errorText } : m
)
); );
} else if ("type" in event && event.type === "data-token-usage") { } else if ("type" in event && event.type === "data-token-usage") {
// After streaming completes, refresh quota // After streaming completes, refresh quota

View file

@ -55,6 +55,48 @@ function parseCaptchaError(status: number, body: string): string | null {
return null; return null;
} }
/**
 * Turn a caught value into the text shown in the free chat.
 *
 * Errors tagged with `errorCode === "THREAD_BUSY"` get a fixed friendly message;
 * other errors use their own message; non-Error values and empty messages fall
 * back to a generic sentence.
 */
function normalizeFreeChatErrorMessage(error: unknown): string {
	const fallback = "An unexpected error occurred";
	if (error instanceof Error) {
		const { errorCode } = error as Error & { errorCode?: string };
		return errorCode === "THREAD_BUSY"
			? "A previous response is still stopping. Please try again in a moment."
			: error.message || fallback;
	}
	return fallback;
}
/**
 * Build an Error (tagged with `errorCode`) from a failed HTTP response.
 *
 * When the body is JSON, the code is read from `detail.error_code`,
 * `detail.errorCode`, `error_code` or `errorCode` (first string wins) and the
 * message from `detail.message`, `message` or a string `detail`. If no code is
 * found, one is derived from the status: 409, 429, 401/403, otherwise server error.
 */
function toFreeChatHttpError(status: number, body: string): Error & { errorCode?: string } {
	// First argument that is a string (an empty string counts), else undefined.
	const firstString = (...values: unknown[]): string | undefined =>
		values.find((value): value is string => typeof value === "string");

	let errorCode: string | undefined;
	let message = body || `Server error: ${status}`;

	try {
		const payload = JSON.parse(body) as Record<string, unknown>;
		const rawDetail = payload.detail;
		const detail =
			typeof rawDetail === "object" && rawDetail !== null
				? (rawDetail as Record<string, unknown>)
				: null;
		errorCode = firstString(
			detail?.error_code,
			detail?.errorCode,
			payload.error_code,
			payload.errorCode
		);
		message = firstString(detail?.message, payload.message, rawDetail) ?? message;
	} catch {
		// non-json response
	}

	if (!errorCode) {
		switch (status) {
			case 409:
				errorCode = "THREAD_BUSY";
				break;
			case 429:
				errorCode = "RATE_LIMITED";
				break;
			case 401:
			case 403:
				errorCode = "AUTH_EXPIRED";
				break;
			default:
				errorCode = "SERVER_ERROR";
		}
	}

	return Object.assign(new Error(message), { errorCode });
}
export function FreeChatPage() { export function FreeChatPage() {
const anonMode = useAnonymousMode(); const anonMode = useAnonymousMode();
const modelSlug = anonMode.isAnonymous ? anonMode.modelSlug : ""; const modelSlug = anonMode.isAnonymous ? anonMode.modelSlug : "";
@ -124,7 +166,7 @@ export function FreeChatPage() {
const body = await response.text().catch(() => ""); const body = await response.text().catch(() => "");
const captchaCode = parseCaptchaError(response.status, body); const captchaCode = parseCaptchaError(response.status, body);
if (captchaCode) return "captcha"; if (captchaCode) return "captcha";
throw new Error(body || `Server error: ${response.status}`); throw toFreeChatHttpError(response.status, body);
} }
const currentThinkingSteps = new Map<string, ThinkingStepData>(); const currentThinkingSteps = new Map<string, ThinkingStepData>();
@ -244,7 +286,9 @@ export function FreeChatPage() {
break; break;
case "error": case "error":
throw new Error(parsed.errorText || "Server error"); throw Object.assign(new Error(parsed.errorText || "Server error"), {
errorCode: parsed.errorCode,
});
} }
} }
batcher.flush(); batcher.flush();
@ -334,7 +378,7 @@ export function FreeChatPage() {
} catch (error) { } catch (error) {
if (error instanceof Error && error.name === "AbortError") return; if (error instanceof Error && error.name === "AbortError") return;
console.error("[FreeChatPage] Chat error:", error); console.error("[FreeChatPage] Chat error:", error);
const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; const errorText = normalizeFreeChatErrorMessage(error);
setMessages((prev) => setMessages((prev) =>
prev.map((m) => prev.map((m) =>
m.id === assistantMsgId m.id === assistantMsgId
@ -393,7 +437,7 @@ export function FreeChatPage() {
} catch (error) { } catch (error) {
if (error instanceof Error && error.name === "AbortError") return; if (error instanceof Error && error.name === "AbortError") return;
console.error("[FreeChatPage] Retry error:", error); console.error("[FreeChatPage] Retry error:", error);
const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; const errorText = normalizeFreeChatErrorMessage(error);
setMessages((prev) => setMessages((prev) =>
prev.map((m) => prev.map((m) =>
m.id === assistantMsgId m.id === assistantMsgId

View file

@ -0,0 +1,304 @@
/** Which chat flow the failing request belonged to. */
export type ChatFlow = "new" | "resume" | "regenerate";
/** Normalized error category assigned by `classifyChatError`. */
export type ChatErrorKind =
| "premium_quota_exhausted"
| "thread_busy"
| "send_failed_pre_accept"
| "auth_expired"
| "rate_limited"
| "network_offline"
| "stream_interrupted"
| "stream_parse_error"
| "tool_execution_error"
| "persist_message_failed"
| "server_error"
| "unknown";
/** Presentation channel chosen for an error (e.g. aborts are classified as "silent"). */
export type ChatErrorChannel = "pinned_inline" | "toast" | "silent";
/** Telemetry event name attached to a classified error. */
export type ChatTelemetryEvent = "chat_blocked" | "chat_error";
/** Severity level attached to a classified error. */
export type ChatErrorSeverity = "info" | "warn" | "error";
/** Result of classifying a raw chat error into a UI- and telemetry-ready shape. */
export interface NormalizedChatError {
kind: ChatErrorKind;
channel: ChatErrorChannel;
severity: ChatErrorSeverity;
telemetryEvent: ChatTelemetryEvent;
isExpected: boolean;
/** Text intended for the user. */
userMessage: string;
/** Optional assistant-voiced text; the premium-quota branch sets it. */
assistantMessage?: string;
/** Message extracted from the original error value. */
rawMessage?: string;
/** Machine-readable code, when one could be determined. */
errorCode?: string;
/** Extra context; the classifier stores the flow here. */
details?: Record<string, unknown>;
}
/** Input accepted by `classifyChatError`. */
export interface RawChatErrorInput {
/** The caught value; may be an Error, a string, or any other value. */
error: unknown;
/** Flow in which the error occurred. */
flow: ChatFlow;
/** Optional identifiers describing where the error happened. */
context?: {
searchSpaceId?: number;
threadId?: number | null;
};
}
/** Assistant-voiced message used when premium tokens are exhausted. */
export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
	"I can't continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue.";
/**
 * Extract a printable message from any thrown value.
 *
 * - `Error` instances yield their `message`.
 * - Strings are returned unchanged.
 * - Anything else is JSON-serialized.
 *
 * Always returns a string: `JSON.stringify` yields `undefined` for values such
 * as `undefined`, functions and symbols, and throws for circular structures or
 * BigInt; all of those cases map to "Unknown error".
 */
function getErrorMessage(error: unknown): string {
	if (error instanceof Error) return error.message;
	if (typeof error === "string") return error;
	try {
		// Typed as string by the lib, but it is undefined at runtime for non-serializable inputs.
		const serialized: string | undefined = JSON.stringify(error);
		return serialized ?? "Unknown error";
	} catch {
		return "Unknown error";
	}
}
/**
 * Resolve a machine-readable error code for a thrown value.
 *
 * Lookup order:
 * 1. `errorCode`, then `code`, on an `Error` instance;
 * 2. `errorCode` on any non-null object;
 * 3. top-level `errorCode` of JSON parsed from the error message.
 *
 * Only non-empty strings are accepted. Some errors carry a non-string `code`
 * (for example the legacy numeric `DOMException.code`); returning it would
 * violate the declared `string | undefined` result, so it is ignored.
 */
function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined {
	if (error instanceof Error) {
		const withCode = error as Error & { errorCode?: string; code?: string };
		if (typeof withCode.errorCode === "string" && withCode.errorCode) {
			return withCode.errorCode;
		}
		if (typeof withCode.code === "string") {
			if (withCode.code) return withCode.code;
		}
	}

	if (typeof error === "object" && error !== null) {
		const withCode = error as { errorCode?: unknown };
		if (typeof withCode.errorCode === "string" && withCode.errorCode) {
			return withCode.errorCode;
		}
	}

	if (parsedJson) {
		const topLevelCode = parsedJson.errorCode;
		if (typeof topLevelCode === "string" && topLevelCode) {
			return topLevelCode;
		}
	}

	return undefined;
}
/**
 * Try to recover a JSON object embedded in an error message.
 *
 * Candidates are tried in order:
 * 1. the whole text;
 * 2. the text from the first `{` to the end (handles a leading prefix);
 * 3. the text from the first `{` to the last `}` (handles trailing text or an
 *    object wrapped in an array).
 *
 * Only plain objects are returned. Arrays and primitives are skipped so the
 * result really is a key/value record.
 *
 * @returns The first candidate that parses to a non-array object, or `null`.
 */
function parseEmbeddedJson(text: string): Record<string, unknown> | null {
	const candidates = [text];
	const firstBraceIdx = text.indexOf("{");
	if (firstBraceIdx >= 0) {
		candidates.push(text.slice(firstBraceIdx));
		const lastBraceIdx = text.lastIndexOf("}");
		// Only worth trying when something actually follows the final brace.
		if (lastBraceIdx > firstBraceIdx && lastBraceIdx < text.length - 1) {
			candidates.push(text.slice(firstBraceIdx, lastBraceIdx + 1));
		}
	}

	for (const candidate of candidates) {
		try {
			const parsed: unknown = JSON.parse(candidate);
			if (typeof parsed === "object" && parsed !== null && !Array.isArray(parsed)) {
				return parsed as Record<string, unknown>;
			}
		} catch {
			// noop
		}
	}
	return null;
}
/**
 * Read a provider error type from parsed JSON.
 *
 * Prefers a non-empty top-level `type`; otherwise falls back to a non-empty
 * `error.type` when `error` is an object. Returns `undefined` if neither exists.
 */
function inferProviderErrorType(parsedJson: Record<string, unknown> | null): string | undefined {
	const nonEmptyString = (value: unknown): string | undefined =>
		typeof value === "string" && value ? value : undefined;

	if (!parsedJson) return undefined;

	const direct = nonEmptyString(parsedJson.type);
	if (direct !== undefined) return direct;

	const inner = parsedJson.error;
	return typeof inner === "object" && inner !== null
		? nonEmptyString((inner as Record<string, unknown>).type)
		: undefined;
}
/**
 * Maps a raw chat failure onto a normalized descriptor that tells the UI how
 * to surface it (channel, severity, user-facing copy) and which telemetry
 * event to emit.
 *
 * Classification precedence:
 * 1. `AbortError` (by `Error.name`) is a silent, expected stream interruption.
 * 2. Known `errorCode` values: premium quota, thread busy, pre-accept send
 *    failure, auth expiry.
 * 3. Rate limiting, matched by `errorCode` OR by the provider error type
 *    embedded in the message JSON (`rate_limit_error`, case-insensitive).
 * 4. Remaining known codes: network, stream parse, tool execution, persist
 *    failure, server error.
 * 5. Anything else is reported as `unknown`.
 *
 * @param input - Raw failure; `input.error` is the thrown value and
 *   `input.flow` is copied into `details` of the result.
 * @returns The normalized error descriptor.
 */
export function classifyChatError(input: RawChatErrorInput): NormalizedChatError {
	const { error } = input;
	const rawMessage = getErrorMessage(error);
	const parsedJson = parseEmbeddedJson(rawMessage);
	const errorCode = getErrorCode(error, parsedJson);
	const providerErrorType = inferProviderErrorType(parsedJson);
	const providerTypeNormalized = providerErrorType?.toLowerCase() ?? "";
	const errorName = error instanceof Error ? error.name : undefined;

	// Aborts are checked first so that they are never surfaced as failures.
	if (errorName === "AbortError") {
		return {
			kind: "stream_interrupted",
			channel: "silent",
			severity: "info",
			telemetryEvent: "chat_error",
			isExpected: true,
			userMessage: "Request canceled.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {
		return {
			kind: "premium_quota_exhausted",
			channel: "pinned_inline",
			severity: "info",
			telemetryEvent: "chat_blocked",
			isExpected: true,
			userMessage: "Buy more tokens to continue with this model, or switch to a free model.",
			assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE,
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "THREAD_BUSY") {
		return {
			kind: "thread_busy",
			channel: "toast",
			severity: "warn",
			telemetryEvent: "chat_blocked",
			isExpected: true,
			userMessage: "A previous response is still stopping. Please try again in a moment.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "SEND_FAILED_PRE_ACCEPT") {
		return {
			kind: "send_failed_pre_accept",
			channel: "toast",
			severity: "warn",
			telemetryEvent: "chat_blocked",
			isExpected: true,
			userMessage: "Message not sent. Please retry.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED") {
		return {
			kind: "auth_expired",
			channel: "toast",
			severity: "warn",
			telemetryEvent: "chat_error",
			isExpected: true,
			userMessage: "Your session expired. Please sign in again.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "RATE_LIMITED" || providerTypeNormalized === "rate_limit_error") {
		return {
			kind: "rate_limited",
			channel: "toast",
			severity: "warn",
			telemetryEvent: "chat_blocked",
			isExpected: true,
			userMessage:
				"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
			rawMessage,
			// The provider-type match can fire without any error code, hence the fallback.
			errorCode: errorCode ?? "RATE_LIMITED",
			details: { flow: input.flow, providerErrorType },
		};
	}

	if (errorCode === "NETWORK_ERROR") {
		return {
			kind: "network_offline",
			channel: "toast",
			severity: "warn",
			telemetryEvent: "chat_error",
			isExpected: true,
			userMessage: "Connection issue. Please try again.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "STREAM_PARSE_ERROR") {
		return {
			kind: "stream_parse_error",
			channel: "toast",
			severity: "error",
			telemetryEvent: "chat_error",
			isExpected: false,
			userMessage: "We hit a response formatting issue. Please try again.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "TOOL_EXECUTION_ERROR") {
		return {
			kind: "tool_execution_error",
			channel: "toast",
			severity: "error",
			telemetryEvent: "chat_error",
			isExpected: false,
			userMessage: "A tool failed while processing your request. Please try again.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "PERSIST_MESSAGE_FAILED") {
		return {
			kind: "persist_message_failed",
			channel: "toast",
			severity: "error",
			telemetryEvent: "chat_error",
			isExpected: false,
			userMessage: "Response generated, but saving failed. Please retry once.",
			rawMessage,
			errorCode,
			details: { flow: input.flow },
		};
	}

	if (errorCode === "SERVER_ERROR") {
		return {
			kind: "server_error",
			channel: "toast",
			severity: "error",
			telemetryEvent: "chat_error",
			isExpected: false,
			userMessage: "We couldn't complete this response right now. Please try again.",
			rawMessage,
			errorCode,
			details: { flow: input.flow, providerErrorType },
		};
	}

	// Fallback: nothing matched, so report as an unexpected, unknown failure.
	return {
		kind: "unknown",
		channel: "toast",
		severity: "error",
		telemetryEvent: "chat_error",
		isExpected: false,
		userMessage: "We couldn't complete this response right now. Please try again.",
		rawMessage,
		errorCode,
		details: { flow: input.flow, providerErrorType },
	};
}

View file

@ -546,7 +546,7 @@ export type SSEEvent =
}>; }>;
}; };
} }
| { type: "error"; errorText: string }; | { type: "error"; errorText: string; errorCode?: string };
/** /**
* Async generator that reads an SSE stream and yields parsed JSON objects. * Async generator that reads an SSE stream and yields parsed JSON objects.

View file

@ -1,5 +1,6 @@
import posthog from "posthog-js"; import posthog from "posthog-js";
import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants"; import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants";
import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier";
/** /**
* PostHog Analytics Event Definitions * PostHog Analytics Event Definitions
@ -139,6 +140,55 @@ export function trackChatError(searchSpaceId: number, chatId: number, error?: st
}); });
} }
/**
 * Payload describing a chat failure for analytics.
 *
 * Shared by the `chat_blocked` and `chat_error` tracking helpers, which copy
 * every field into the captured event properties under the same snake_case
 * names.
 */
export interface ChatFailureTelemetry {
	/** Classified failure kind. */
	kind: ChatErrorKind;
	/** Chat flow associated with the failure. */
	flow: ChatFlow;
	/** Severity assigned to the failure. */
	severity: ChatErrorSeverity;
	/** Whether the failure is an expected condition rather than a fault. */
	is_expected: boolean;
	/** Machine-readable error code, when one is known. */
	error_code?: string;
	/** Optional free-form message to attach to the event. */
	message?: string;
}
/**
 * Captures a chat failure event with the property set shared by
 * `chat_blocked` and `chat_error`, so the two events cannot drift apart.
 *
 * @param eventName - Analytics event to emit.
 * @param searchSpaceId - Search space the chat belongs to.
 * @param chatId - Chat/thread id, or `null` when there is none.
 * @param payload - Classified failure details.
 */
function captureChatFailure(
	eventName: "chat_blocked" | "chat_error",
	searchSpaceId: number,
	chatId: number | null,
	payload: ChatFailureTelemetry
) {
	safeCapture(
		eventName,
		compact({
			search_space_id: searchSpaceId,
			// A null chat id is normalized to undefined before being handed to compact().
			chat_id: chatId ?? undefined,
			flow: payload.flow,
			kind: payload.kind,
			error_code: payload.error_code,
			severity: payload.severity,
			is_expected: payload.is_expected,
			message: payload.message,
		})
	);
}

/**
 * Tracks a chat request under the `chat_blocked` event.
 *
 * @param searchSpaceId - Search space the chat belongs to.
 * @param chatId - Chat/thread id, or `null` when there is none.
 * @param payload - Classified failure details.
 */
export function trackChatBlocked(
	searchSpaceId: number,
	chatId: number | null,
	payload: ChatFailureTelemetry
) {
	captureChatFailure("chat_blocked", searchSpaceId, chatId, payload);
}

/**
 * Tracks a chat failure under the `chat_error` event, carrying the structured
 * classification fields of {@link ChatFailureTelemetry}.
 *
 * @param searchSpaceId - Search space the chat belongs to.
 * @param chatId - Chat/thread id, or `null` when there is none.
 * @param payload - Classified failure details.
 */
export function trackChatErrorDetailed(
	searchSpaceId: number,
	chatId: number | null,
	payload: ChatFailureTelemetry
) {
	captureChatFailure("chat_error", searchSpaceId, chatId, payload);
}
/** /**
* Track a message sent from the unauthenticated "free" / anonymous chat * Track a message sent from the unauthenticated "free" / anonymous chat
* flow. This is intentionally a separate event from `chat_message_sent` * flow. This is intentionally a separate event from `chat_message_sent`