mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-31 19:45:15 +02:00
Merge pull request #1325 from AnishSarkar22/feat/split-auto-free-premium
feat(chat): thread-level auto model pinning & structured chat errors
This commit is contained in:
commit
f129207c50
17 changed files with 2501 additions and 353 deletions
|
|
@ -0,0 +1,65 @@
|
|||
"""138_add_thread_auto_model_pinning_fields
|
||||
|
||||
Revision ID: 138
|
||||
Revises: 137
|
||||
Create Date: 2026-04-30
|
||||
|
||||
Add thread-level fields to persist Auto (Fastest) model pinning metadata:
|
||||
- pinned_llm_config_id: concrete resolved config id used for this thread
|
||||
- pinned_auto_mode: auto policy identifier (currently "auto_fastest")
|
||||
- pinned_at: timestamp when the pin was created/refreshed
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "138"
|
||||
down_revision: str | None = "137"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads "
|
||||
"ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads "
|
||||
"ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads "
|
||||
"ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id "
|
||||
"ON new_chat_threads (pinned_llm_config_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode "
|
||||
"ON new_chat_threads (pinned_auto_mode)"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode"
|
||||
)
|
||||
op.execute(
|
||||
"DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id"
|
||||
)
|
||||
|
|
@ -638,6 +638,13 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
default=False,
|
||||
server_default="false",
|
||||
)
|
||||
# Auto model pinning metadata:
|
||||
# - pinned_llm_config_id stores the concrete resolved model config id.
|
||||
# - pinned_auto_mode indicates which auto policy produced the pin.
|
||||
# This allows Auto (Fastest) to resolve once per thread and stay stable.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True, index=True)
|
||||
pinned_auto_mode = Column(String(32), nullable=True, index=True)
|
||||
pinned_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||
|
|
|
|||
|
|
@ -1924,6 +1924,7 @@ async def regenerate_response(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=regenerate_image_urls or None,
|
||||
flow="regenerate",
|
||||
):
|
||||
yield chunk
|
||||
streaming_completed = True
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
|
|||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewChatThread,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -790,9 +791,31 @@ async def update_llm_preferences(
|
|||
|
||||
# Update preferences
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
previous_agent_llm_id = search_space.agent_llm_id
|
||||
for key, value in update_data.items():
|
||||
setattr(search_space, key, value)
|
||||
|
||||
agent_llm_changed = (
|
||||
"agent_llm_id" in update_data
|
||||
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||
)
|
||||
if agent_llm_changed:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(
|
||||
pinned_llm_config_id=None,
|
||||
pinned_auto_mode=None,
|
||||
pinned_at=None,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_agent_llm_id,
|
||||
update_data["agent_llm_id"],
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
|
||||
|
|
|
|||
218
surfsense_backend/app/services/auto_model_pin_service.py
Normal file
218
surfsense_backend/app/services/auto_model_pin_service.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""Resolve and persist Auto (Fastest) model pins per chat thread.
|
||||
|
||||
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||
persist the chosen config id on ``new_chat_threads`` so subsequent turns are
|
||||
stable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewChatThread
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_FASTEST_ID = 0
|
||||
AUTO_FASTEST_MODE = "auto_fastest"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoPinResolution:
|
||||
resolved_llm_config_id: int
|
||||
resolved_tier: str
|
||||
from_existing_pin: bool
|
||||
|
||||
|
||||
def _is_usable_global_config(cfg: dict) -> bool:
|
||||
return bool(
|
||||
cfg.get("id") is not None
|
||||
and cfg.get("model_name")
|
||||
and cfg.get("provider")
|
||||
and cfg.get("api_key")
|
||||
)
|
||||
|
||||
|
||||
def _global_candidates() -> list[dict]:
|
||||
candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)]
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
||||
def _tier_of(cfg: dict) -> str:
|
||||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
||||
def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
|
||||
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
|
||||
idx = int.from_bytes(digest[:8], "big") % len(candidates)
|
||||
return candidates[idx]
|
||||
|
||||
|
||||
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||
if user_id is None:
|
||||
return None
|
||||
if isinstance(user_id, UUID):
|
||||
return user_id
|
||||
try:
|
||||
return UUID(str(user_id))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool:
|
||||
parsed = _to_uuid(user_id)
|
||||
if parsed is None:
|
||||
return False
|
||||
usage = await TokenQuotaService.premium_get_usage(session, parsed)
|
||||
return bool(usage.allowed)
|
||||
|
||||
|
||||
async def resolve_or_get_pinned_llm_config_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
thread_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | UUID | None,
|
||||
selected_llm_config_id: int,
|
||||
force_repin_free: bool = False,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
|
||||
|
||||
For non-auto selections, this function clears existing auto pin metadata and
|
||||
returns the selected id as-is.
|
||||
"""
|
||||
thread = (
|
||||
(
|
||||
await session.execute(
|
||||
select(NewChatThread)
|
||||
.where(NewChatThread.id == thread_id)
|
||||
.with_for_update(of=NewChatThread)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if thread is None:
|
||||
raise ValueError(f"Thread {thread_id} not found")
|
||||
if thread.search_space_id != search_space_id:
|
||||
raise ValueError(
|
||||
f"Thread {thread_id} does not belong to search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Explicit model selected: clear stale auto pin metadata.
|
||||
if selected_llm_config_id != AUTO_FASTEST_ID:
|
||||
if (
|
||||
thread.pinned_llm_config_id is not None
|
||||
or thread.pinned_auto_mode is not None
|
||||
or thread.pinned_at is not None
|
||||
):
|
||||
thread.pinned_llm_config_id = None
|
||||
thread.pinned_auto_mode = None
|
||||
thread.pinned_at = None
|
||||
await session.commit()
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=selected_llm_config_id,
|
||||
resolved_tier="explicit",
|
||||
from_existing_pin=False,
|
||||
)
|
||||
|
||||
candidates = _global_candidates()
|
||||
if not candidates:
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse existing valid pin without re-checking current quota (no silent tier switch),
|
||||
# unless the caller explicitly requests a forced repin to free.
|
||||
pinned_id = thread.pinned_llm_config_id
|
||||
if (
|
||||
not force_repin_free
|
||||
and
|
||||
thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
and pinned_id is not None
|
||||
and int(pinned_id) in candidate_by_id
|
||||
):
|
||||
pinned_cfg = candidate_by_id[int(pinned_id)]
|
||||
logger.info(
|
||||
"auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
_tier_of(pinned_cfg),
|
||||
)
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=int(pinned_id),
|
||||
resolved_tier=_tier_of(pinned_cfg),
|
||||
from_existing_pin=True,
|
||||
)
|
||||
if pinned_id is not None:
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
thread.pinned_auto_mode,
|
||||
)
|
||||
|
||||
premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id)
|
||||
if premium_eligible:
|
||||
eligible = candidates
|
||||
else:
|
||||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||
|
||||
if not eligible:
|
||||
raise ValueError(
|
||||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||
)
|
||||
|
||||
selected_cfg = _deterministic_pick(eligible, thread_id)
|
||||
selected_id = int(selected_cfg["id"])
|
||||
selected_tier = _tier_of(selected_cfg)
|
||||
|
||||
thread.pinned_llm_config_id = selected_id
|
||||
thread.pinned_auto_mode = AUTO_FASTEST_MODE
|
||||
thread.pinned_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
if force_repin_free:
|
||||
logger.info(
|
||||
"auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
selected_id,
|
||||
)
|
||||
|
||||
if pinned_id is None:
|
||||
logger.info(
|
||||
"auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
selected_id,
|
||||
selected_tier,
|
||||
premium_eligible,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
selected_id,
|
||||
selected_tier,
|
||||
premium_eligible,
|
||||
)
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=selected_id,
|
||||
resolved_tier=selected_tier,
|
||||
from_existing_pin=False,
|
||||
)
|
||||
|
|
@ -565,20 +565,24 @@ class VercelStreamingService:
|
|||
# Error Part
|
||||
# =========================================================================
|
||||
|
||||
def format_error(self, error_text: str) -> str:
|
||||
def format_error(self, error_text: str, error_code: str | None = None) -> str:
|
||||
"""
|
||||
Format an error message.
|
||||
|
||||
Args:
|
||||
error_text: The error message text
|
||||
error_code: Optional machine-readable error code for frontend branching
|
||||
|
||||
Returns:
|
||||
str: SSE formatted error part
|
||||
|
||||
Example output:
|
||||
data: {"type":"error","errorText":"Something went wrong"}
|
||||
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||
"""
|
||||
return self._format_sse({"type": "error", "errorText": error_text})
|
||||
payload: dict[str, str] = {"type": "error", "errorText": error_text}
|
||||
if error_code:
|
||||
payload["errorCode"] = error_code
|
||||
return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
# Tool Parts
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ import re
|
|||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
|
|
@ -30,6 +31,7 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import (
|
||||
|
|
@ -57,6 +59,7 @@ from app.db import (
|
|||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -338,6 +341,138 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _log_chat_stream_error(
|
||||
*,
|
||||
flow: Literal["new", "resume", "regenerate"],
|
||||
error_kind: str,
|
||||
error_code: str | None,
|
||||
severity: Literal["info", "warn", "error"],
|
||||
is_expected: bool,
|
||||
request_id: str | None,
|
||||
thread_id: int | None,
|
||||
search_space_id: int | None,
|
||||
user_id: str | None,
|
||||
message: str,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"event": "chat_stream_error",
|
||||
"flow": flow,
|
||||
"error_kind": error_kind,
|
||||
"error_code": error_code,
|
||||
"severity": severity,
|
||||
"is_expected": is_expected,
|
||||
"request_id": request_id or "unknown",
|
||||
"thread_id": thread_id,
|
||||
"search_space_id": search_space_id,
|
||||
"user_id": user_id,
|
||||
"message": message,
|
||||
}
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
rendered = json.dumps(payload, ensure_ascii=False)
|
||||
if severity == "error":
|
||||
logger.error("[chat_stream_error] %s", rendered)
|
||||
elif severity == "warn":
|
||||
logger.warning("[chat_stream_error] %s", rendered)
|
||||
else:
|
||||
logger.info("[chat_stream_error] %s", rendered)
|
||||
|
||||
|
||||
def _parse_error_payload(message: str) -> dict[str, Any] | None:
|
||||
candidates = [message]
|
||||
first_brace_idx = message.find("{")
|
||||
if first_brace_idx >= 0:
|
||||
candidates.append(message[first_brace_idx:])
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _classify_stream_exception(
|
||||
exc: Exception,
|
||||
*,
|
||||
flow_label: str,
|
||||
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
|
||||
raw = str(exc)
|
||||
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
||||
return (
|
||||
"thread_busy",
|
||||
"THREAD_BUSY",
|
||||
"warn",
|
||||
True,
|
||||
"Another response is still finishing for this thread. Please try again in a moment.",
|
||||
)
|
||||
|
||||
parsed = _parse_error_payload(raw)
|
||||
provider_error_type = ""
|
||||
if parsed:
|
||||
top_type = parsed.get("type")
|
||||
if isinstance(top_type, str):
|
||||
provider_error_type = top_type.lower()
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
nested_type = nested.get("type")
|
||||
if isinstance(nested_type, str):
|
||||
provider_error_type = nested_type.lower()
|
||||
|
||||
if provider_error_type == "rate_limit_error":
|
||||
return (
|
||||
"rate_limited",
|
||||
"RATE_LIMITED",
|
||||
"warn",
|
||||
True,
|
||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||
)
|
||||
|
||||
return (
|
||||
"server_error",
|
||||
"SERVER_ERROR",
|
||||
"error",
|
||||
False,
|
||||
f"Error during {flow_label}: {raw}",
|
||||
)
|
||||
|
||||
|
||||
def _emit_stream_terminal_error(
|
||||
*,
|
||||
streaming_service: VercelStreamingService,
|
||||
flow: str,
|
||||
request_id: str | None,
|
||||
thread_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
message: str,
|
||||
error_kind: str = "server_error",
|
||||
error_code: str = "SERVER_ERROR",
|
||||
severity: Literal["info", "warn", "error"] = "error",
|
||||
is_expected: bool = False,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
request_id=request_id,
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
extra=extra,
|
||||
)
|
||||
return streaming_service.format_error(message, error_code=error_code)
|
||||
|
||||
|
||||
def _legacy_match_lc_id(
|
||||
pending_tool_call_chunks: list[dict[str, Any]],
|
||||
tool_name: str,
|
||||
|
|
@ -1913,6 +2048,7 @@ async def stream_new_chat(
|
|||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
user_image_data_urls: list[str] | None = None,
|
||||
flow: Literal["new", "regenerate"] = "new",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -1964,6 +2100,16 @@ async def stream_new_chat(
|
|||
_premium_reserved = 0
|
||||
_premium_request_id: str | None = None
|
||||
|
||||
_emit_stream_error = partial(
|
||||
_emit_stream_terminal_error,
|
||||
streaming_service=streaming_service,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
|
|
@ -1971,49 +2117,78 @@ async def stream_new_chat(
|
|||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
|
||||
agent_config: AgentConfig | None = None
|
||||
requested_llm_config_id = llm_config_id
|
||||
|
||||
async def _load_llm_bundle(
|
||||
config_id: int,
|
||||
) -> tuple[Any, AgentConfig | None, str | None]:
|
||||
if config_id >= 0:
|
||||
loaded_agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if not loaded_agent_config:
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
f"Failed to load NewLLMConfig with id {config_id}",
|
||||
)
|
||||
return (
|
||||
create_chat_litellm_from_agent_config(loaded_agent_config),
|
||||
loaded_agent_config,
|
||||
None,
|
||||
)
|
||||
|
||||
loaded_llm_config = load_global_llm_config_by_id(config_id)
|
||||
if not loaded_llm_config:
|
||||
return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
return (
|
||||
create_chat_litellm_from_config(loaded_llm_config),
|
||||
AgentConfig.from_yaml_config(loaded_llm_config),
|
||||
None,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
if llm_config_id >= 0:
|
||||
# Positive ID: Load from NewLLMConfig database table
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
if not agent_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load NewLLMConfig with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Create ChatLiteLLM from AgentConfig
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
else:
|
||||
# Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models)
|
||||
llm_config = load_global_llm_config_by_id(llm_config_id)
|
||||
if not llm_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load LLM config with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Create ChatLiteLLM from global config dict
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield _emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
llm_config_id,
|
||||
)
|
||||
|
||||
# Premium quota reservation — applies to explicitly premium configs
|
||||
# AND Auto mode (which may route to premium models).
|
||||
# Premium quota reservation for pinned premium model only.
|
||||
_needs_premium_quota = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
and agent_config.is_premium
|
||||
)
|
||||
if _needs_premium_quota:
|
||||
import uuid as _uuid
|
||||
|
|
@ -2036,19 +2211,79 @@ async def stream_new_chat(
|
|||
)
|
||||
_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield _emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Premium quota exhausted on pinned model; auto-fallback switched to a free model"
|
||||
),
|
||||
extra={
|
||||
"fallback_config_id": llm_config_id,
|
||||
"auto_fallback": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield _emit_stream_error(
|
||||
message=(
|
||||
"Buy more tokens to continue with this model, or switch to a free model"
|
||||
),
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
extra={
|
||||
"resolved_config_id": llm_config_id,
|
||||
"auto_fallback": False,
|
||||
},
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
# Auto mode: quota exhausted but we can still proceed
|
||||
# (the router may pick a free model). Reset reservation.
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
yield _emit_stream_error(
|
||||
message="Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
|
|
@ -2499,28 +2734,20 @@ async def stream_new_chat(
|
|||
)
|
||||
|
||||
# Finalize premium quota with actual tokens.
|
||||
# For Auto mode, only count tokens from calls that used premium models.
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_premium_reserved,
|
||||
)
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
|
|
@ -2586,12 +2813,25 @@ async def stream_new_chat(
|
|||
# Handle any errors
|
||||
import traceback
|
||||
|
||||
(
|
||||
error_kind,
|
||||
error_code,
|
||||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
) = _classify_stream_exception(e, flow_label="chat")
|
||||
error_message = f"Error during chat: {e!s}"
|
||||
print(f"[stream_new_chat] {error_message}")
|
||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||
|
||||
yield streaming_service.format_error(error_message)
|
||||
yield _emit_stream_error(
|
||||
message=user_message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
)
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -2706,36 +2946,83 @@ async def stream_resume_chat(
|
|||
|
||||
accumulator = start_turn()
|
||||
|
||||
_emit_stream_error = partial(
|
||||
_emit_stream_terminal_error,
|
||||
streaming_service=streaming_service,
|
||||
flow="resume",
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
if user_id:
|
||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
||||
agent_config: AgentConfig | None = None
|
||||
_t0 = time.perf_counter()
|
||||
if llm_config_id >= 0:
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
requested_llm_config_id = llm_config_id
|
||||
|
||||
async def _load_llm_bundle(
|
||||
config_id: int,
|
||||
) -> tuple[Any, AgentConfig | None, str | None]:
|
||||
if config_id >= 0:
|
||||
loaded_agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if not loaded_agent_config:
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
f"Failed to load NewLLMConfig with id {config_id}",
|
||||
)
|
||||
return (
|
||||
create_chat_litellm_from_agent_config(loaded_agent_config),
|
||||
loaded_agent_config,
|
||||
None,
|
||||
)
|
||||
|
||||
loaded_llm_config = load_global_llm_config_by_id(config_id)
|
||||
if not loaded_llm_config:
|
||||
return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
return (
|
||||
create_chat_litellm_from_config(loaded_llm_config),
|
||||
AgentConfig.from_yaml_config(loaded_llm_config),
|
||||
None,
|
||||
)
|
||||
if not agent_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load NewLLMConfig with id {llm_config_id}"
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
else:
|
||||
llm_config = load_global_llm_config_by_id(llm_config_id)
|
||||
if not llm_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load LLM config with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield _emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
|
@ -2746,7 +3033,7 @@ async def stream_resume_chat(
|
|||
_resume_needs_premium = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
and agent_config.is_premium
|
||||
)
|
||||
if _resume_needs_premium:
|
||||
import uuid as _uuid
|
||||
|
|
@ -2769,17 +3056,79 @@ async def stream_resume_chat(
|
|||
)
|
||||
_resume_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield _emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Premium quota exhausted on pinned model; auto-fallback switched to a free model"
|
||||
),
|
||||
extra={
|
||||
"fallback_config_id": llm_config_id,
|
||||
"auto_fallback": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield _emit_stream_error(
|
||||
message=(
|
||||
"Buy more tokens to continue with this model, or switch to a free model"
|
||||
),
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
extra={
|
||||
"resolved_config_id": llm_config_id,
|
||||
"auto_fallback": False,
|
||||
},
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
yield _emit_stream_error(
|
||||
message="Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
|
|
@ -2920,23 +3269,16 @@ async def stream_resume_chat(
|
|||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
)
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s (resume)",
|
||||
|
|
@ -2970,10 +3312,23 @@ async def stream_resume_chat(
|
|||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
(
|
||||
error_kind,
|
||||
error_code,
|
||||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
) = _classify_stream_exception(e, flow_label="resume")
|
||||
error_message = f"Error during resume: {e!s}"
|
||||
print(f"[stream_resume_chat] {error_message}")
|
||||
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
||||
yield streaming_service.format_error(error_message)
|
||||
yield _emit_stream_error(
|
||||
message=user_message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
)
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,329 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
AUTO_FASTEST_MODE,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeQuotaResult:
|
||||
allowed: bool
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
self._thread = thread
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
self.thread = thread
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
|
||||
def _thread(
|
||||
*,
|
||||
search_space_id: int = 10,
|
||||
pinned_llm_config_id: int | None = None,
|
||||
pinned_auto_mode: str | None = None,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
search_space_id=search_space_id,
|
||||
pinned_llm_config_id=pinned_llm_config_id,
|
||||
pinned_auto_mode=pinned_auto_mode,
|
||||
pinned_at=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_first_turn_pins_one_model(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id in {-1, -2}
|
||||
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
|
||||
assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
assert session.thread.pinned_at is not None
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_turn_reuses_existing_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _must_not_call(*_args, **_kwargs):
|
||||
raise AssertionError("premium_get_usage should not be called for valid pin reuse")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_must_not_call,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
assert session.commit_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.resolved_tier == "free"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.resolved_tier == "free"
|
||||
assert result.from_existing_pin is False
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=7,
|
||||
)
|
||||
assert result.resolved_llm_config_id == 7
|
||||
assert session.thread.pinned_llm_config_id is None
|
||||
assert session.thread.pinned_auto_mode is None
|
||||
assert session.thread.pinned_at is None
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
assert session.commit_count == 1
|
||||
|
|
@ -1,9 +1,19 @@
|
|||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_classify_stream_exception,
|
||||
_contract_enforcement_active,
|
||||
_evaluate_file_contract_outcome,
|
||||
_log_chat_stream_error,
|
||||
_tool_output_has_error,
|
||||
)
|
||||
|
||||
|
|
@ -45,3 +55,217 @@ def test_contract_enforcement_local_only():
|
|||
|
||||
result.filesystem_mode = "cloud"
|
||||
assert not _contract_enforcement_active(result)
|
||||
|
||||
|
||||
def _extract_chat_stream_payload(record_message: str) -> dict:
|
||||
prefix = "[chat_stream_error] "
|
||||
assert record_message.startswith(prefix)
|
||||
return json.loads(record_message[len(prefix) :])
|
||||
|
||||
|
||||
def test_unified_chat_stream_error_log_schema(caplog):
|
||||
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
|
||||
_log_chat_stream_error(
|
||||
flow="new",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
severity="warn",
|
||||
is_expected=False,
|
||||
request_id="req-123",
|
||||
thread_id=101,
|
||||
search_space_id=202,
|
||||
user_id="user-1",
|
||||
message="Error during chat: boom",
|
||||
)
|
||||
|
||||
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
|
||||
payload = _extract_chat_stream_payload(record.message)
|
||||
|
||||
required_keys = {
|
||||
"event",
|
||||
"flow",
|
||||
"error_kind",
|
||||
"error_code",
|
||||
"severity",
|
||||
"is_expected",
|
||||
"request_id",
|
||||
"thread_id",
|
||||
"search_space_id",
|
||||
"user_id",
|
||||
"message",
|
||||
}
|
||||
assert required_keys.issubset(payload.keys())
|
||||
assert payload["event"] == "chat_stream_error"
|
||||
assert payload["flow"] == "new"
|
||||
assert payload["error_code"] == "SERVER_ERROR"
|
||||
|
||||
|
||||
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
|
||||
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id="req-premium",
|
||||
thread_id=303,
|
||||
search_space_id=404,
|
||||
user_id="user-2",
|
||||
message="Buy more tokens to continue with this model, or switch to a free model",
|
||||
extra={"auto_fallback": False},
|
||||
)
|
||||
|
||||
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
|
||||
payload = _extract_chat_stream_payload(record.message)
|
||||
assert payload["event"] == "chat_stream_error"
|
||||
assert payload["error_kind"] == "premium_quota_exhausted"
|
||||
assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
|
||||
assert payload["flow"] == "resume"
|
||||
assert payload["is_expected"] is True
|
||||
assert payload["auto_fallback"] is False
|
||||
|
||||
|
||||
def test_stream_error_emission_keeps_machine_error_codes():
|
||||
source = inspect.getsource(stream_new_chat_module)
|
||||
format_error_calls = re.findall(r"format_error\(", source)
|
||||
emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
|
||||
|
||||
# All stream paths should route through one shared terminal error emitter.
|
||||
assert len(format_error_calls) == 1
|
||||
assert {
|
||||
"PREMIUM_QUOTA_EXHAUSTED",
|
||||
"SERVER_ERROR",
|
||||
}.issubset(emitted_error_codes)
|
||||
assert 'flow: Literal["new", "regenerate"] = "new"' in source
|
||||
assert "_emit_stream_terminal_error" in source
|
||||
assert "flow=flow" in source
|
||||
assert 'flow="resume"' in source
|
||||
|
||||
|
||||
def test_stream_exception_classifies_rate_limited():
|
||||
exc = Exception(
|
||||
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
||||
)
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "rate_limited"
|
||||
assert code == "RATE_LIMITED"
|
||||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "temporarily rate-limited" in user_message
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy():
|
||||
exc = BusyError(request_id="thread-123")
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
assert code == "THREAD_BUSY"
|
||||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "still finishing for this thread" in user_message
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy_from_message():
|
||||
exc = Exception("Thread is busy with another request")
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
assert code == "THREAD_BUSY"
|
||||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "still finishing for this thread" in user_message
|
||||
|
||||
|
||||
def test_premium_classification_is_error_code_driven():
|
||||
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||
source = classifier_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "PREMIUM_KEYWORDS" not in source
|
||||
assert "RATE_LIMIT_KEYWORDS" not in source
|
||||
assert "normalized.includes(" not in source
|
||||
assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
|
||||
|
||||
|
||||
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
|
||||
page_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
||||
)
|
||||
source = page_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "onPreAcceptFailure?: () => Promise<void>;" in source
|
||||
assert "if (!accepted) {" in source
|
||||
assert "await onPreAcceptFailure?.();" in source
|
||||
assert "await onAcceptedStreamError?.();" in source
|
||||
assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
|
||||
assert "setMessageDocumentsMap((prev) => {" in source
|
||||
|
||||
|
||||
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
||||
user_message_path = (
|
||||
Path(__file__).resolve().parents[3] / "surfsense_web/components/assistant-ui/user-message.tsx"
|
||||
)
|
||||
source = user_message_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "Not sent. Edit and retry." not in source
|
||||
assert "failed_pre_accept" not in source
|
||||
|
||||
|
||||
def test_network_send_failures_use_unified_retry_toast_message():
|
||||
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||
classifier_source = classifier_path.read_text(encoding="utf-8")
|
||||
page_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
||||
)
|
||||
page_source = page_path.read_text(encoding="utf-8")
|
||||
|
||||
assert '"send_failed_pre_accept"' in classifier_source
|
||||
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
||||
assert "if (withCode.code) return withCode.code;" in classifier_source
|
||||
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
||||
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
||||
assert "tagPreAcceptSendFailure(error)" in page_source
|
||||
assert "const passthroughCodes = new Set([" in page_source
|
||||
assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source
|
||||
assert '"THREAD_BUSY"' in page_source
|
||||
assert '"AUTH_EXPIRED"' in page_source
|
||||
assert '"UNAUTHORIZED"' in page_source
|
||||
assert '"RATE_LIMITED"' in page_source
|
||||
assert '"NETWORK_ERROR"' in page_source
|
||||
assert '"STREAM_PARSE_ERROR"' in page_source
|
||||
assert '"TOOL_EXECUTION_ERROR"' in page_source
|
||||
assert '"PERSIST_MESSAGE_FAILED"' in page_source
|
||||
assert '"SERVER_ERROR"' in page_source
|
||||
assert "passthroughCodes.has(existingCode)" in page_source
|
||||
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source
|
||||
assert 'errorCode: "NETWORK_ERROR"' not in page_source
|
||||
assert "Failed to start chat. Please try again." not in page_source
|
||||
|
||||
|
||||
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
||||
page_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
||||
)
|
||||
source = page_path.read_text(encoding="utf-8")
|
||||
|
||||
# Each flow tracks accepted boundary and passes it into shared terminal handling.
|
||||
assert "let newAccepted = false;" in source
|
||||
assert "let resumeAccepted = false;" in source
|
||||
assert "let regenerateAccepted = false;" in source
|
||||
assert "accepted: newAccepted," in source
|
||||
assert "accepted: resumeAccepted," in source
|
||||
assert "accepted: regenerateAccepted," in source
|
||||
|
||||
# Pre-accept abort in resume/regenerate exits without persistence.
|
||||
assert "if (!resumeAccepted) return;" in source
|
||||
assert "if (!regenerateAccepted) return;" in source
|
||||
|
||||
# New flow persists only when accepted and not already persisted.
|
||||
assert "if (newAccepted && !userPersisted) {" in source
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
45
surfsense_web/atoms/chat/premium-alert.atom.ts
Normal file
45
surfsense_web/atoms/chat/premium-alert.atom.ts
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import { atom } from "jotai";
|
||||
|
||||
export type PremiumAlertState = {
|
||||
message: string;
|
||||
};
|
||||
|
||||
export const premiumAlertByThreadAtom = atom<Record<number, PremiumAlertState>>({});
|
||||
|
||||
export const setPremiumAlertForThreadAtom = atom(
|
||||
null,
|
||||
(
|
||||
get,
|
||||
set,
|
||||
payload: {
|
||||
threadId: number;
|
||||
message: string;
|
||||
userId?: string | null;
|
||||
}
|
||||
) => {
|
||||
const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`;
|
||||
|
||||
if (typeof window !== "undefined") {
|
||||
const hasSeen = localStorage.getItem(storageKey) === "true";
|
||||
if (hasSeen) return;
|
||||
}
|
||||
|
||||
const current = get(premiumAlertByThreadAtom);
|
||||
set(premiumAlertByThreadAtom, {
|
||||
...current,
|
||||
[payload.threadId]: { message: payload.message },
|
||||
});
|
||||
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.setItem(storageKey, "true");
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => {
|
||||
const current = get(premiumAlertByThreadAtom);
|
||||
if (!(threadId in current)) return;
|
||||
const next = { ...current };
|
||||
delete next[threadId];
|
||||
set(premiumAlertByThreadAtom, next);
|
||||
});
|
||||
|
|
@ -37,10 +37,13 @@ import {
|
|||
toggleToolAtom,
|
||||
} from "@/atoms/agent-tools/agent-tools.atoms";
|
||||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||
import {
|
||||
mentionedDocumentsAtom,
|
||||
} from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||
import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
|
||||
import {
|
||||
clearPremiumAlertForThreadAtom,
|
||||
premiumAlertByThreadAtom,
|
||||
} from "@/atoms/chat/premium-alert.atom";
|
||||
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||
|
|
@ -135,6 +138,9 @@ const ThreadContent: FC = () => {
|
|||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<ThreadScrollToBottom />
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<PremiumQuotaPinnedAlert />
|
||||
</AuiIf>
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<Composer />
|
||||
</AuiIf>
|
||||
|
|
@ -144,6 +150,37 @@ const ThreadContent: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const PremiumQuotaPinnedAlert: FC = () => {
|
||||
const currentThreadState = useAtomValue(currentThreadAtom);
|
||||
const alertsByThread = useAtomValue(premiumAlertByThreadAtom);
|
||||
const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom);
|
||||
|
||||
const currentThreadId = currentThreadState?.id;
|
||||
if (!currentThreadId) return null;
|
||||
|
||||
const alert = alertsByThread[currentThreadId];
|
||||
if (!alert) return null;
|
||||
|
||||
return (
|
||||
<div className="mx-0 overflow-hidden rounded-2xl border-input bg-muted px-4 py-4 text-foreground select-none">
|
||||
<div className="flex items-center gap-2">
|
||||
<AlertCircle className="size-4 shrink-0 text-muted-foreground" />
|
||||
<div className="min-w-0 flex-1">
|
||||
<p className="text-sm">{alert.message}</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
className="inline-flex size-6 items-center justify-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
aria-label="Dismiss premium quota alert"
|
||||
onClick={() => clearPremiumAlertForThread(currentThreadId)}
|
||||
>
|
||||
<X className="size-4" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
|
|
|
|||
|
|
@ -104,7 +104,13 @@ export function AnonymousChat({ model }: AnonymousChatProps) {
|
|||
setMessages((prev) => prev.filter((m) => m.id !== assistantId));
|
||||
return;
|
||||
}
|
||||
throw new Error(`Stream error: ${response.status}`);
|
||||
const body = await response.text().catch(() => "");
|
||||
const errorCode = response.status === 409 ? "THREAD_BUSY" : "SERVER_ERROR";
|
||||
const message =
|
||||
errorCode === "THREAD_BUSY"
|
||||
? "A previous response is still stopping. Please try again in a moment."
|
||||
: `Stream error: ${response.status}`;
|
||||
throw Object.assign(new Error(body || message), { errorCode });
|
||||
}
|
||||
|
||||
for await (const event of readSSEStream(response)) {
|
||||
|
|
@ -115,10 +121,12 @@ export function AnonymousChat({ model }: AnonymousChatProps) {
|
|||
prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + event.delta } : m))
|
||||
);
|
||||
} else if (event.type === "error") {
|
||||
const message =
|
||||
event.errorCode === "THREAD_BUSY"
|
||||
? "A previous response is still stopping. Please try again in a moment."
|
||||
: event.errorText;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantId ? { ...m, content: m.content || event.errorText } : m
|
||||
)
|
||||
prev.map((m) => (m.id === assistantId ? { ...m, content: m.content || message } : m))
|
||||
);
|
||||
} else if ("type" in event && event.type === "data-token-usage") {
|
||||
// After streaming completes, refresh quota
|
||||
|
|
|
|||
|
|
@ -55,6 +55,48 @@ function parseCaptchaError(status: number, body: string): string | null {
|
|||
return null;
|
||||
}
|
||||
|
||||
function normalizeFreeChatErrorMessage(error: unknown): string {
|
||||
if (!(error instanceof Error)) return "An unexpected error occurred";
|
||||
const code = (error as Error & { errorCode?: string }).errorCode;
|
||||
if (code === "THREAD_BUSY") {
|
||||
return "A previous response is still stopping. Please try again in a moment.";
|
||||
}
|
||||
return error.message || "An unexpected error occurred";
|
||||
}
|
||||
|
||||
function toFreeChatHttpError(status: number, body: string): Error & { errorCode?: string } {
|
||||
let errorCode: string | undefined;
|
||||
let message = body || `Server error: ${status}`;
|
||||
try {
|
||||
const parsed = JSON.parse(body) as Record<string, unknown>;
|
||||
const detail =
|
||||
typeof parsed.detail === "object" && parsed.detail !== null
|
||||
? (parsed.detail as Record<string, unknown>)
|
||||
: null;
|
||||
errorCode =
|
||||
(typeof detail?.error_code === "string" ? detail.error_code : undefined) ??
|
||||
(typeof detail?.errorCode === "string" ? detail.errorCode : undefined) ??
|
||||
(typeof parsed.error_code === "string" ? parsed.error_code : undefined) ??
|
||||
(typeof parsed.errorCode === "string" ? parsed.errorCode : undefined);
|
||||
message =
|
||||
(typeof detail?.message === "string" ? detail.message : undefined) ??
|
||||
(typeof parsed.message === "string" ? parsed.message : undefined) ??
|
||||
(typeof parsed.detail === "string" ? parsed.detail : undefined) ??
|
||||
message;
|
||||
} catch {
|
||||
// non-json response
|
||||
}
|
||||
|
||||
if (!errorCode) {
|
||||
if (status === 409) errorCode = "THREAD_BUSY";
|
||||
else if (status === 429) errorCode = "RATE_LIMITED";
|
||||
else if (status === 401 || status === 403) errorCode = "AUTH_EXPIRED";
|
||||
else errorCode = "SERVER_ERROR";
|
||||
}
|
||||
|
||||
return Object.assign(new Error(message), { errorCode });
|
||||
}
|
||||
|
||||
export function FreeChatPage() {
|
||||
const anonMode = useAnonymousMode();
|
||||
const modelSlug = anonMode.isAnonymous ? anonMode.modelSlug : "";
|
||||
|
|
@ -124,7 +166,7 @@ export function FreeChatPage() {
|
|||
const body = await response.text().catch(() => "");
|
||||
const captchaCode = parseCaptchaError(response.status, body);
|
||||
if (captchaCode) return "captcha";
|
||||
throw new Error(body || `Server error: ${response.status}`);
|
||||
throw toFreeChatHttpError(response.status, body);
|
||||
}
|
||||
|
||||
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
||||
|
|
@ -244,7 +286,9 @@ export function FreeChatPage() {
|
|||
break;
|
||||
|
||||
case "error":
|
||||
throw new Error(parsed.errorText || "Server error");
|
||||
throw Object.assign(new Error(parsed.errorText || "Server error"), {
|
||||
errorCode: parsed.errorCode,
|
||||
});
|
||||
}
|
||||
}
|
||||
batcher.flush();
|
||||
|
|
@ -334,7 +378,7 @@ export function FreeChatPage() {
|
|||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") return;
|
||||
console.error("[FreeChatPage] Chat error:", error);
|
||||
const errorText = error instanceof Error ? error.message : "An unexpected error occurred";
|
||||
const errorText = normalizeFreeChatErrorMessage(error);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
|
|
@ -393,7 +437,7 @@ export function FreeChatPage() {
|
|||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") return;
|
||||
console.error("[FreeChatPage] Retry error:", error);
|
||||
const errorText = error instanceof Error ? error.message : "An unexpected error occurred";
|
||||
const errorText = normalizeFreeChatErrorMessage(error);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
|
|
|
|||
304
surfsense_web/lib/chat/chat-error-classifier.ts
Normal file
304
surfsense_web/lib/chat/chat-error-classifier.ts
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
export type ChatFlow = "new" | "resume" | "regenerate";
|
||||
|
||||
export type ChatErrorKind =
|
||||
| "premium_quota_exhausted"
|
||||
| "thread_busy"
|
||||
| "send_failed_pre_accept"
|
||||
| "auth_expired"
|
||||
| "rate_limited"
|
||||
| "network_offline"
|
||||
| "stream_interrupted"
|
||||
| "stream_parse_error"
|
||||
| "tool_execution_error"
|
||||
| "persist_message_failed"
|
||||
| "server_error"
|
||||
| "unknown";
|
||||
|
||||
export type ChatErrorChannel = "pinned_inline" | "toast" | "silent";
|
||||
export type ChatTelemetryEvent = "chat_blocked" | "chat_error";
|
||||
export type ChatErrorSeverity = "info" | "warn" | "error";
|
||||
|
||||
export interface NormalizedChatError {
|
||||
kind: ChatErrorKind;
|
||||
channel: ChatErrorChannel;
|
||||
severity: ChatErrorSeverity;
|
||||
telemetryEvent: ChatTelemetryEvent;
|
||||
isExpected: boolean;
|
||||
userMessage: string;
|
||||
assistantMessage?: string;
|
||||
rawMessage?: string;
|
||||
errorCode?: string;
|
||||
details?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface RawChatErrorInput {
|
||||
error: unknown;
|
||||
flow: ChatFlow;
|
||||
context?: {
|
||||
searchSpaceId?: number;
|
||||
threadId?: number | null;
|
||||
};
|
||||
}
|
||||
|
||||
export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
|
||||
"I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue.";
|
||||
|
||||
function getErrorMessage(error: unknown): string {
|
||||
if (error instanceof Error) return error.message;
|
||||
if (typeof error === "string") return error;
|
||||
try {
|
||||
return JSON.stringify(error);
|
||||
} catch {
|
||||
return "Unknown error";
|
||||
}
|
||||
}
|
||||
|
||||
function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined {
|
||||
if (error instanceof Error) {
|
||||
const withCode = error as Error & { errorCode?: string; code?: string };
|
||||
if (withCode.errorCode) return withCode.errorCode;
|
||||
if (withCode.code) return withCode.code;
|
||||
}
|
||||
|
||||
if (typeof error === "object" && error !== null) {
|
||||
const withCode = error as { errorCode?: unknown };
|
||||
if (typeof withCode.errorCode === "string" && withCode.errorCode) {
|
||||
return withCode.errorCode;
|
||||
}
|
||||
}
|
||||
|
||||
if (parsedJson) {
|
||||
const topLevelCode = parsedJson.errorCode;
|
||||
if (typeof topLevelCode === "string" && topLevelCode) {
|
||||
return topLevelCode;
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function parseEmbeddedJson(text: string): Record<string, unknown> | null {
|
||||
const candidates = [text];
|
||||
const firstBraceIdx = text.indexOf("{");
|
||||
if (firstBraceIdx >= 0) {
|
||||
candidates.push(text.slice(firstBraceIdx));
|
||||
}
|
||||
for (const candidate of candidates) {
|
||||
try {
|
||||
const parsed = JSON.parse(candidate);
|
||||
if (typeof parsed === "object" && parsed !== null) {
|
||||
return parsed as Record<string, unknown>;
|
||||
}
|
||||
} catch {
|
||||
// noop
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function inferProviderErrorType(parsedJson: Record<string, unknown> | null): string | undefined {
|
||||
if (!parsedJson) return undefined;
|
||||
const topLevelType = parsedJson.type;
|
||||
if (typeof topLevelType === "string" && topLevelType) return topLevelType;
|
||||
const nestedError = parsedJson.error;
|
||||
if (typeof nestedError === "object" && nestedError !== null) {
|
||||
const nestedType = (nestedError as Record<string, unknown>).type;
|
||||
if (typeof nestedType === "string" && nestedType) return nestedType;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function classifyChatError(input: RawChatErrorInput): NormalizedChatError {
|
||||
const { error } = input;
|
||||
const rawMessage = getErrorMessage(error);
|
||||
const parsedJson = parseEmbeddedJson(rawMessage);
|
||||
const errorCode = getErrorCode(error, parsedJson);
|
||||
const providerErrorType = inferProviderErrorType(parsedJson);
|
||||
const providerTypeNormalized = providerErrorType?.toLowerCase() ?? "";
|
||||
const errorName = error instanceof Error ? error.name : undefined;
|
||||
|
||||
if (errorName === "AbortError") {
|
||||
return {
|
||||
kind: "stream_interrupted",
|
||||
channel: "silent",
|
||||
severity: "info",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: true,
|
||||
userMessage: "Request canceled.",
|
||||
rawMessage,
|
||||
errorCode,
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {
|
||||
return {
|
||||
kind: "premium_quota_exhausted",
|
||||
channel: "pinned_inline",
|
||||
severity: "info",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage:
|
||||
"Buy more tokens to continue with this model, or switch to a free model.",
|
||||
assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE,
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "PREMIUM_QUOTA_EXHAUSTED",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "THREAD_BUSY"
|
||||
) {
|
||||
return {
|
||||
kind: "thread_busy",
|
||||
channel: "toast",
|
||||
severity: "warn",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "THREAD_BUSY",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (errorCode === "SEND_FAILED_PRE_ACCEPT") {
|
||||
return {
|
||||
kind: "send_failed_pre_accept",
|
||||
channel: "toast",
|
||||
severity: "warn",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage: "Message not sent. Please retry.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "SEND_FAILED_PRE_ACCEPT",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "AUTH_EXPIRED" ||
|
||||
errorCode === "UNAUTHORIZED"
|
||||
) {
|
||||
return {
|
||||
kind: "auth_expired",
|
||||
channel: "toast",
|
||||
severity: "warn",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: true,
|
||||
userMessage: "Your session expired. Please sign in again.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "AUTH_EXPIRED",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "RATE_LIMITED" ||
|
||||
providerTypeNormalized === "rate_limit_error"
|
||||
) {
|
||||
return {
|
||||
kind: "rate_limited",
|
||||
channel: "toast",
|
||||
severity: "warn",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage:
|
||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "RATE_LIMITED",
|
||||
details: { flow: input.flow, providerErrorType },
|
||||
};
|
||||
}
|
||||
|
||||
if (errorCode === "NETWORK_ERROR") {
|
||||
return {
|
||||
kind: "network_offline",
|
||||
channel: "toast",
|
||||
severity: "warn",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: true,
|
||||
userMessage: "Connection issue. Please try again.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "NETWORK_ERROR",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "STREAM_PARSE_ERROR"
|
||||
) {
|
||||
return {
|
||||
kind: "stream_parse_error",
|
||||
channel: "toast",
|
||||
severity: "error",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: false,
|
||||
userMessage: "We hit a response formatting issue. Please try again.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "STREAM_PARSE_ERROR",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "TOOL_EXECUTION_ERROR"
|
||||
) {
|
||||
return {
|
||||
kind: "tool_execution_error",
|
||||
channel: "toast",
|
||||
severity: "error",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: false,
|
||||
userMessage: "A tool failed while processing your request. Please try again.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "TOOL_EXECUTION_ERROR",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "PERSIST_MESSAGE_FAILED"
|
||||
) {
|
||||
return {
|
||||
kind: "persist_message_failed",
|
||||
channel: "toast",
|
||||
severity: "error",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: false,
|
||||
userMessage: "Response generated, but saving failed. Please retry once.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "PERSIST_MESSAGE_FAILED",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "SERVER_ERROR"
|
||||
) {
|
||||
return {
|
||||
kind: "server_error",
|
||||
channel: "toast",
|
||||
severity: "error",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: false,
|
||||
userMessage: "We couldn’t complete this response right now. Please try again.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "SERVER_ERROR",
|
||||
details: { flow: input.flow, providerErrorType },
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
kind: "unknown",
|
||||
channel: "toast",
|
||||
severity: "error",
|
||||
telemetryEvent: "chat_error",
|
||||
isExpected: false,
|
||||
userMessage: "We couldn’t complete this response right now. Please try again.",
|
||||
rawMessage,
|
||||
errorCode,
|
||||
details: { flow: input.flow, providerErrorType },
|
||||
};
|
||||
}
|
||||
|
|
@ -546,7 +546,7 @@ export type SSEEvent =
|
|||
}>;
|
||||
};
|
||||
}
|
||||
| { type: "error"; errorText: string };
|
||||
| { type: "error"; errorText: string; errorCode?: string };
|
||||
|
||||
/**
|
||||
* Async generator that reads an SSE stream and yields parsed JSON objects.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import posthog from "posthog-js";
|
||||
import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants";
|
||||
import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier";
|
||||
|
||||
/**
|
||||
* PostHog Analytics Event Definitions
|
||||
|
|
@ -139,6 +140,55 @@ export function trackChatError(searchSpaceId: number, chatId: number, error?: st
|
|||
});
|
||||
}
|
||||
|
||||
export interface ChatFailureTelemetry {
|
||||
flow: ChatFlow;
|
||||
kind: ChatErrorKind;
|
||||
error_code?: string;
|
||||
severity: ChatErrorSeverity;
|
||||
is_expected: boolean;
|
||||
message?: string;
|
||||
}
|
||||
|
||||
export function trackChatBlocked(
|
||||
searchSpaceId: number,
|
||||
chatId: number | null,
|
||||
payload: ChatFailureTelemetry
|
||||
) {
|
||||
safeCapture(
|
||||
"chat_blocked",
|
||||
compact({
|
||||
search_space_id: searchSpaceId,
|
||||
chat_id: chatId ?? undefined,
|
||||
flow: payload.flow,
|
||||
kind: payload.kind,
|
||||
error_code: payload.error_code,
|
||||
severity: payload.severity,
|
||||
is_expected: payload.is_expected,
|
||||
message: payload.message,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
export function trackChatErrorDetailed(
|
||||
searchSpaceId: number,
|
||||
chatId: number | null,
|
||||
payload: ChatFailureTelemetry
|
||||
) {
|
||||
safeCapture(
|
||||
"chat_error",
|
||||
compact({
|
||||
search_space_id: searchSpaceId,
|
||||
chat_id: chatId ?? undefined,
|
||||
flow: payload.flow,
|
||||
kind: payload.kind,
|
||||
error_code: payload.error_code,
|
||||
severity: payload.severity,
|
||||
is_expected: payload.is_expected,
|
||||
message: payload.message,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Track a message sent from the unauthenticated "free" / anonymous chat
|
||||
* flow. This is intentionally a separate event from `chat_message_sent`
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue