diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py
index 3972b84b9..fba621a0c 100644
--- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py
+++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py
@@ -4,10 +4,12 @@
 Revision ID: 138
 Revises: 137
 Create Date: 2026-04-30
 
-Add thread-level fields to persist Auto (Fastest) model pinning metadata:
-- pinned_llm_config_id: concrete resolved config id used for this thread
-- pinned_auto_mode: auto policy identifier (currently "auto_fastest")
-- pinned_at: timestamp when the pin was created/refreshed
+Add a single thread-level column to persist the Auto (Fastest) model pin:
+- pinned_llm_config_id: concrete resolved global LLM config id used for this
+  thread. NULL means "no pin; Auto will resolve on next turn".
+
+The column is unindexed: all reads are by new_chat_threads.id (primary key),
+so a secondary index would be pure write amplification.
 """
 
 from __future__ import annotations
@@ -27,29 +29,14 @@ def upgrade() -> None:
         "ALTER TABLE new_chat_threads "
         "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
     )
-    op.execute(
-        "ALTER TABLE new_chat_threads "
-        "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)"
-    )
-    op.execute(
-        "ALTER TABLE new_chat_threads "
-        "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE"
-    )
-
-    op.execute(
-        "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id "
-        "ON new_chat_threads (pinned_llm_config_id)"
-    )
-    op.execute(
-        "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode "
-        "ON new_chat_threads (pinned_auto_mode)"
-    )
 
 
 def downgrade() -> None:
+    # Drop whichever shape the thread row is carrying. The extra columns and
+    # indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
+    # makes each statement a safe no-op on the lean single-column shape.
     op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
     op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
-    op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
     op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
     op.execute(
diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py
new file mode 100644
index 000000000..83c96a429
--- /dev/null
+++ b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py
@@ -0,0 +1,160 @@
+"""add user table to zero_publication with column list
+
+Adds the "user" table to zero_publication with a column-list publication
+so that only the 5 fields driving the live usage meters are replicated
+through WAL -> zero-cache -> browser IndexedDB:
+
+    id, pages_limit, pages_used,
+    premium_tokens_limit, premium_tokens_used
+
+Sensitive columns (hashed_password, email, oauth_account, display_name,
+avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
+included in the publication, so they never enter WAL replication.
+
+Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
+(it is already DEFAULT today since "user" was never in the
+TABLES_WITH_FULL_IDENTITY list of migration 117).
+
+IMPORTANT - rollout sequence (zero-cache must stay offline around the DDL):
+    1. Stop zero-cache (it holds replication locks that will deadlock DDL)
+    2. Run: alembic upgrade head
+    3. Delete / reset the zero-cache data volume
+    4. Restart zero-cache (it will do a fresh initial sync)
+
+Revision ID: 139
+Revises: 138
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "139"
+down_revision: str | None = "138"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+PUBLICATION_NAME = "zero_publication"
+
+# Document column list as left by migration 117. Must match exactly.
+DOCUMENT_COLS = [
+    "id",
+    "title",
+    "document_type",
+    "search_space_id",
+    "folder_id",
+    "created_by_id",
+    "status",
+    "created_at",
+    "updated_at",
+]
+
+# Five fields needed by the live usage meters (sidebar Tokens/Pages,
+# Buy Tokens content). Keep this list narrow on purpose: anything added
+# here flows into WAL and IndexedDB for every connected browser.
+USER_COLS = [
+    "id",
+    "pages_limit",
+    "pages_used",
+    "premium_tokens_limit",
+    "premium_tokens_used",
+]
+
+
+def _terminate_blocked_pids(conn, table: str) -> None:
+    """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
+    conn.execute(
+        sa.text(
+            "SELECT pg_terminate_backend(l.pid) "
+            "FROM pg_locks l "
+            "JOIN pg_class c ON c.oid = l.relation "
+            "WHERE c.relname = :tbl "
+            "  AND l.pid != pg_backend_pid()"
+        ),
+        {"tbl": table},
+    )
+
+
+def _has_zero_version(conn, table: str) -> bool:
+    return (
+        conn.execute(
+            sa.text(
+                "SELECT 1 FROM information_schema.columns "
+                "WHERE table_name = :tbl AND column_name = '_0_version'"
+            ),
+            {"tbl": table},
+        ).fetchone()
+        is not None
+    )
+
+
+def _build_publication_ddl(
+    documents_has_zero_ver: bool, user_has_zero_ver: bool
+) -> str:
+    doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+    user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
+    doc_col_list = ", ".join(doc_cols)
+    user_col_list = ", ".join(user_cols)
+    return (
+        f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+        f"notifications, "
+        f"documents ({doc_col_list}), "
+        f"folders, "
+        f"search_source_connectors, "
+        f"new_chat_messages, "
+        f"chat_comments, "
+        f"chat_session_state, "
+        f'"user" ({user_col_list})'
+    )
+
+
+def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
+    doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+    doc_col_list = ", ".join(doc_cols)
+    return (
+        f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+        f"notifications, "
+        f"documents ({doc_col_list}), "
+        f"folders, "
+        f"search_source_connectors, "
+        f"new_chat_messages, "
+        f"chat_comments, "
+        f"chat_session_state"
+    )
+
+
+def upgrade() -> None:
+    conn = op.get_bind()
+    # asyncpg requires LOCK TABLE inside a transaction block. Alembic already
+    # opened one via context.begin_transaction(), but the driver still errors
+    # unless we use an explicit SAVEPOINT (nested transaction) for this block.
+    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+    with tx:
+        conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+        _terminate_blocked_pids(conn, "user")
+        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+        # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
+        # migration 117, so this is already DEFAULT. Re-assert anyway so
+        # the column-list publication stays valid (DEFAULT identity only
+        # requires the PK to be in the column list).
+        conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
+
+        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+        documents_has_zero_ver = _has_zero_version(conn, "documents")
+        user_has_zero_ver = _has_zero_version(conn, "user")
+
+        conn.execute(
+            sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
+        )
+
+
+def downgrade() -> None:
+    conn = op.get_bind()
+    conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+    documents_has_zero_ver = _has_zero_version(conn, "documents")
+    conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))
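
A quick way to sanity-check the result after step 4 is to query `pg_publication_tables`, which exposes per-table column lists as `attnames` on PostgreSQL 15+. A minimal sketch — the helper name and DSN wiring are illustrative, not part of this patch:

```python
# Hypothetical post-deploy check: confirm the "user" entry in
# zero_publication carries only the intended column list.
from sqlalchemy import create_engine, text

def assert_user_publication_columns(dsn: str) -> None:
    engine = create_engine(dsn)
    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT attnames FROM pg_publication_tables "
                "WHERE pubname = 'zero_publication' AND tablename = 'user'"
            )
        ).fetchone()
    assert row is not None, '"user" is missing from zero_publication'
    expected = {"id", "pages_limit", "pages_used",
                "premium_tokens_limit", "premium_tokens_used"}
    got = set(row[0])
    # _0_version appears once zero-cache has added its shard column.
    assert expected <= got <= expected | {"_0_version"}, got
```
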
diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
index d61a56533..06a27bc96 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
@@ -61,6 +61,9 @@ class _ThreadLockManager:
         self._cancel_events: dict[str, asyncio.Event] = {}
         self._cancel_requested_at_ms: dict[str, int] = {}
         self._cancel_attempt_count: dict[str, int] = {}
+        # Monotonic per-thread epoch used to prevent stale middleware
+        # teardown from releasing a newer turn's lock.
+        self._turn_epoch: dict[str, int] = {}
 
     def lock_for(self, thread_id: str) -> asyncio.Lock:
         lock = self._locks.get(thread_id)
@@ -107,6 +110,14 @@ class _ThreadLockManager:
         self._cancel_requested_at_ms.pop(thread_id, None)
         self._cancel_attempt_count.pop(thread_id, None)
 
+    def bump_turn_epoch(self, thread_id: str) -> int:
+        epoch = self._turn_epoch.get(thread_id, 0) + 1
+        self._turn_epoch[thread_id] = epoch
+        return epoch
+
+    def current_turn_epoch(self, thread_id: str) -> int:
+        return self._turn_epoch.get(thread_id, 0)
+
     def end_turn(self, thread_id: str) -> None:
         """Best-effort terminal cleanup for a thread turn.
 
@@ -114,6 +125,10 @@ class _ThreadLockManager:
         finally-blocks where middleware teardown might be skipped due to
         abort or disconnect edge-cases.
         """
+        # Invalidate any in-flight middleware holder first. This guarantees a
+        # stale ``aafter_agent`` from an older attempt cannot unlock a newer
+        # retry that already acquired the lock for the same thread.
+        self.bump_turn_epoch(thread_id)
         lock = self._locks.get(thread_id)
         if lock is not None and lock.locked():
             lock.release()
@@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
         super().__init__()
         self._require_thread_id = require_thread_id
         self.tools = []
-        # Per-call locks owned by this middleware. We track them as
-        # an instance attribute so ``aafter_agent`` knows which lock
-        # to release.
-        self._held_locks: dict[str, asyncio.Lock] = {}
+        # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
+        # only releases when its epoch still matches the manager's current
+        # epoch for the thread, preventing stale unlock races.
+        self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
 
     @staticmethod
     def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
             if lock.locked():
                 raise BusyError(request_id=thread_id)
             await lock.acquire()
-            self._held_locks[thread_id] = lock
+            epoch = manager.bump_turn_epoch(thread_id)
+            self._held_locks[thread_id] = (lock, epoch)
             # Reset the cancel event so this turn starts fresh
             reset_cancel(thread_id)
             return None
@@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
         thread_id = self._thread_id(runtime)
         if thread_id is None:
             return None
-        lock = self._held_locks.pop(thread_id, None)
-        if lock is not None and lock.locked():
+        held = self._held_locks.pop(thread_id, None)
+        if held is None:
+            return None
+        lock, held_epoch = held
+        if held_epoch != manager.current_turn_epoch(thread_id):
+            # Stale teardown from an older attempt (e.g. runtime-recovery path
+            # already advanced epoch). Do not touch current lock/cancel state.
+            return None
+        if lock.locked():
             lock.release()
         # Always clear cancel event between turns so a stale signal
         # doesn't leak into the next request.
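
To see what the epoch guard buys, here is a minimal standalone timeline of the race it closes — a toy driver, not part of the patch, assuming `_ThreadLockManager` above is importable and instantiable with no arguments:

```python
import asyncio

async def demo() -> None:
    manager = _ThreadLockManager()       # the class patched above
    tid = "thread-1"
    lock = manager.lock_for(tid)

    # Turn 1: abefore_agent acquires the lock and records its epoch.
    await lock.acquire()
    epoch_1 = manager.bump_turn_epoch(tid)

    # Runtime recovery force-ends the turn: end_turn bumps the epoch
    # and releases the lock on turn 1's behalf.
    manager.end_turn(tid)

    # Turn 2 (a retry) re-acquires the same lock under a fresh epoch.
    await lock.acquire()
    manager.bump_turn_epoch(tid)

    # Turn 1's late aafter_agent now sees a stale epoch and must not
    # release turn 2's lock — exactly the check in aafter_agent above.
    assert epoch_1 != manager.current_turn_epoch(tid)
    assert lock.locked()                 # turn 2 still owns the lock

asyncio.run(demo())
```
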
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index bd97d2bb1..675b05d2c 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -63,6 +63,27 @@ def load_global_llm_configs():
             else:
                 seen_slugs[slug] = cfg.get("id", 0)
 
+        # Stamp Auto (Fastest) ranking metadata. YAML configs are always
+        # Tier A — operator-curated, locked first when premium-eligible.
+        # The OpenRouter refresh tick later re-stamps health for any cfg
+        # whose provider == "OPENROUTER" via _enrich_health.
+        try:
+            from app.services.quality_score import static_score_yaml
+
+            for cfg in configs:
+                cfg["auto_pin_tier"] = "A"
+                static_q = static_score_yaml(cfg)
+                cfg["quality_score_static"] = static_q
+                cfg["quality_score"] = static_q
+                cfg["quality_score_health"] = None
+                # YAML cfgs whose provider is OPENROUTER are also subject
+                # to health gating against their own /endpoints data — a
+                # hand-picked dead OR model is still dead. _enrich_health
+                # re-stamps health_gated for them on the next refresh tick.
+                cfg["health_gated"] = False
+        except Exception as e:
+            print(f"Warning: Failed to score global LLM configs: {e}")
+
         return configs
     except Exception as e:
         print(f"Warning: Failed to load global LLM configs: {e}")
@@ -194,6 +215,9 @@ def load_openrouter_integration_settings() -> dict | None:
     """
     Load OpenRouter integration settings from the YAML config.
 
+    Emits startup warnings for deprecated keys (``billing_tier``,
+    ``anonymous_enabled``) and seeds their replacements for back-compat.
+
     Returns:
         dict with settings if present and enabled, None otherwise
     """
@@ -206,9 +230,31 @@ def load_openrouter_integration_settings() -> dict | None:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
         settings = data.get("openrouter_integration")
-        if settings and settings.get("enabled"):
-            return settings
-        return None
+        if not settings or not settings.get("enabled"):
+            return None
+
+        if "billing_tier" in settings:
+            print(
+                "Warning: openrouter_integration.billing_tier is deprecated; "
+                "tier is now derived per model from OpenRouter data "
+                "(':free' suffix or zero pricing). Remove this key."
+            )
+ ) + + if "anonymous_enabled" in settings: + print( + "Warning: openrouter_integration.anonymous_enabled is " + "deprecated; use anonymous_enabled_paid and/or " + "anonymous_enabled_free instead. Both new flags have been " + "seeded from the legacy value for back-compat." + ) + settings.setdefault( + "anonymous_enabled_paid", settings["anonymous_enabled"] + ) + settings.setdefault( + "anonymous_enabled_free", settings["anonymous_enabled"] + ) + + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") return None @@ -217,9 +263,14 @@ def load_openrouter_integration_settings() -> dict | None: def initialize_openrouter_integration(): """ If enabled, fetch all OpenRouter models and append them to - config.GLOBAL_LLM_CONFIGS as dynamic premium entries. - Should be called BEFORE initialize_llm_router() so the router - correctly excludes premium models from Auto mode. + config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier`` + is derived per-model from OpenRouter's API signals (``:free`` suffix or + zero pricing), so free OpenRouter models correctly skip premium quota. + + Should be called BEFORE initialize_llm_router(). Dynamic entries are + tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used + by title-gen / sub-agent flows) remains scoped to curated YAML configs, + while user-facing Auto-mode thread pinning still considers them. """ settings = load_openrouter_integration_settings() if not settings: @@ -235,9 +286,13 @@ def initialize_openrouter_integration(): if new_configs: config.GLOBAL_LLM_CONFIGS.extend(new_configs) + free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free") + premium_count = sum( + 1 for c in new_configs if c.get("billing_tier") == "premium" + ) print( f"Info: OpenRouter integration added {len(new_configs)} models " - f"(billing_tier={settings.get('billing_tier', 'premium')})" + f"(free={free_count}, premium={premium_count})" ) else: print("Info: OpenRouter integration enabled but no models fetched") diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 9aca0f022..79cbe1e51 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -245,31 +245,53 @@ global_llm_configs: # ============================================================================= # When enabled, dynamically fetches ALL available models from the OpenRouter API # and injects them as global configs. This gives premium users access to any model -# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota. +# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota, +# while free-tier OpenRouter models show up with a green Free badge and do NOT +# consume premium quota. # Models are fetched at startup and refreshed periodically in the background. # All calls go through LiteLLM with the openrouter/ prefix. openrouter_integration: enabled: false api_key: "sk-or-your-openrouter-api-key" - # billing_tier: "premium" or "free". Controls whether users need premium tokens. 
diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml
index 9aca0f022..79cbe1e51 100644
--- a/surfsense_backend/app/config/global_llm_config.example.yaml
+++ b/surfsense_backend/app/config/global_llm_config.example.yaml
@@ -245,31 +245,53 @@
 # =============================================================================
 # When enabled, dynamically fetches ALL available models from the OpenRouter API
 # and injects them as global configs. This gives premium users access to any model
-# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
+# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
+# while free-tier OpenRouter models show up with a green Free badge and do NOT
+# consume premium quota.
 # Models are fetched at startup and refreshed periodically in the background.
 # All calls go through LiteLLM with the openrouter/ prefix.
 openrouter_integration:
   enabled: false
   api_key: "sk-or-your-openrouter-api-key"
-  # billing_tier: "premium" or "free". Controls whether users need premium tokens.
-  billing_tier: "premium"
-  # anonymous_enabled: set true to also show OpenRouter models to no-login users
-  anonymous_enabled: false
+
+  # Tier is derived PER MODEL from OpenRouter's own API signals:
+  #   - id ends with ":free"                         -> billing_tier=free
+  #   - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
+  #   - otherwise                                    -> billing_tier=premium
+  # No global billing_tier knob is honored; any legacy value emits a startup warning.
+
+  # Anonymous access is split by tier so operators can expose only free
+  # models to no-login users without leaking paid inference.
+  anonymous_enabled_paid: false
+  anonymous_enabled_free: false
+
   seo_enabled: false
   # quota_reserve_tokens: tokens reserved per call for quota enforcement
   quota_reserve_tokens: 4000
-  # id_offset: starting negative ID for dynamically generated configs.
-  # Must not overlap with your static global_llm_configs IDs above.
+  # id_offset: base negative ID for dynamically generated configs.
+  # Model IDs are derived deterministically via BLAKE2b so they survive
+  # catalogue churn. Must not overlap with your static global_llm_configs IDs.
   id_offset: -10000
   # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
   refresh_interval_hours: 24
-  # rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
-  # OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
-  # upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
-  # These values only matter if you set billing_tier to "free" (adding them to Auto mode).
-  # For premium-only models they are cosmetic. Set conservatively or match your account tier.
+
+  # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
+  # for per-deployment accounting when OR premium models participate in the
+  # shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
+  # real account limits live at https://openrouter.ai/settings/limits.
   rpm: 200
   tpm: 1000000
+
+  # Rate limits for FREE OpenRouter models. Informational only: free OR
+  # models are intentionally kept OUT of the LiteLLM Router pool, because
+  # OpenRouter enforces free-tier limits globally per account (~20 RPM +
+  # 50-1000 daily requests across every ":free" model combined) —
+  # per-deployment router accounting can't represent a shared bucket
+  # correctly. Free OR models stay fully available in the model selector
+  # and for user-facing Auto thread pinning.
+  free_rpm: 20
+  free_tpm: 100000
+
   litellm_params:
     max_tokens: 16384
   system_instructions: ""
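
For reference, this is roughly how `_generate_configs` (in the openrouter_integration_service diff below) consumes these knobs once a model's tier is known — a sketch, with the helper name invented for illustration:

```python
# Hypothetical helper mirroring the tier-aware limit selection below.
def effective_limits(settings: dict, tier: str) -> tuple[int, int]:
    if tier == "free":
        # Informational only: free OR models never enter the router pool.
        return settings.get("free_rpm", 20), settings.get("free_tpm", 100_000)
    return settings.get("rpm", 200), settings.get("tpm", 1_000_000)

assert effective_limits({"rpm": 200, "free_rpm": 20}, "free") == (20, 100_000)
assert effective_limits({"rpm": 200, "free_rpm": 20}, "premium") == (200, 1_000_000)
```
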
diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py
index ca3334f8b..2fe478d9b 100644
--- a/surfsense_backend/app/db.py
+++ b/surfsense_backend/app/db.py
@@ -638,13 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
         default=False,
         server_default="false",
     )
-    # Auto model pinning metadata:
-    # - pinned_llm_config_id stores the concrete resolved model config id.
-    # - pinned_auto_mode indicates which auto policy produced the pin.
-    # This allows Auto (Fastest) to resolve once per thread and stay stable.
-    pinned_llm_config_id = Column(Integer, nullable=True, index=True)
-    pinned_auto_mode = Column(String(32), nullable=True, index=True)
-    pinned_at = Column(TIMESTAMP(timezone=True), nullable=True)
+    # Auto (Fastest) model pin for this thread: concrete resolved global LLM
+    # config id. NULL means no pin; Auto will resolve on the next turn.
+    # Single-writer invariant: only app.services.auto_model_pin_service sets
+    # or clears this column (plus bulk clears when a search space's
+    # agent_llm_id changes). Unindexed: all reads are by primary key.
+    pinned_llm_config_id = Column(Integer, nullable=True)
 
     # Relationships
     search_space = relationship("SearchSpace", back_populates="new_chat_threads")
diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py
index f558481cf..f1ca3b6bf 100644
--- a/surfsense_backend/app/routes/documents_routes.py
+++ b/surfsense_backend/app/routes/documents_routes.py
@@ -745,6 +745,51 @@ async def search_document_titles(
         ) from e
 
 
+@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
+async def get_document_by_virtual_path(
+    search_space_id: int,
+    virtual_path: str,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_active_user),
+):
+    """Resolve a knowledge-base document id by exact virtual path."""
+    try:
+        await check_permission(
+            session,
+            user,
+            search_space_id,
+            Permission.DOCUMENTS_READ.value,
+            "You don't have permission to read documents in this search space",
+        )
+
+        result = await session.execute(
+            select(
+                Document.id,
+                Document.title,
+                Document.document_type,
+            ).filter(
+                Document.search_space_id == search_space_id,
+                Document.document_metadata["virtual_path"].as_string() == virtual_path,
+            )
+        )
+        row = result.first()
+        if row is None:
+            raise HTTPException(status_code=404, detail="Document not found")
+
+        return DocumentTitleRead(
+            id=row.id,
+            title=row.title,
+            document_type=row.document_type,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Failed to resolve document by virtual path: {e!s}",
+        ) from e
+
+
 @router.get("/documents/status", response_model=DocumentStatusBatchResponse)
 async def get_documents_status(
     search_space_id: int,
diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py
index 7944e7d66..72715ea5b 100644
--- a/surfsense_backend/app/routes/search_spaces_routes.py
+++ b/surfsense_backend/app/routes/search_spaces_routes.py
@@ -803,11 +803,7 @@ async def update_llm_preferences(
         await session.execute(
             update(NewChatThread)
             .where(NewChatThread.search_space_id == search_space_id)
-            .values(
-                pinned_llm_config_id=None,
-                pinned_auto_mode=None,
-                pinned_at=None,
-            )
+            .values(pinned_llm_config_id=None)
         )
         logger.info(
             "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
index 6b69c91ea..3a2c681b7 100644
--- a/surfsense_backend/app/services/auto_model_pin_service.py
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -2,16 +2,23 @@
 Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
 resolve that virtual mode to one concrete global LLM config exactly once and
-persist the chosen config id on ``new_chat_threads`` so subsequent turns are
-stable.
+persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
+subsequent turns are stable.
+
+Single-writer invariant: this module is the only writer of
+``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
+``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
+Therefore a non-NULL value unambiguously means "this thread has an
+Auto-resolved pin"; no separate source/policy column is needed.
 """
 
 from __future__ import annotations
 
 import hashlib
 import logging
+import threading
+import time
 from dataclasses import dataclass
-from datetime import UTC, datetime
 from uuid import UUID
 
 from sqlalchemy import select
@@ -19,12 +26,28 @@ from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.config import config
 from app.db import NewChatThread
+from app.services.quality_score import _QUALITY_TOP_K
 from app.services.token_quota_service import TokenQuotaService
 
 logger = logging.getLogger(__name__)
 
 AUTO_FASTEST_ID = 0
 AUTO_FASTEST_MODE = "auto_fastest"
+_RUNTIME_COOLDOWN_SECONDS = 600
+_HEALTHY_TTL_SECONDS = 45
+
+# In-memory runtime cooldown map for configs that recently hard-failed at
+# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
+# the same unhealthy config from being reselected immediately during repair.
+_runtime_cooldown_until: dict[int, float] = {}
+_runtime_cooldown_lock = threading.Lock()
+
+# Short-TTL "recently healthy" cache for configs that just passed a runtime
+# preflight ping. Lets back-to-back turns on the same model skip the probe
+# without eroding correctness — entries auto-expire and are wiped any time
+# the same config is cooled down or the OR catalogue is refreshed.
+_healthy_until: dict[int, float] = {}
+_healthy_lock = threading.Lock()
 
 
 @dataclass
@@ -43,9 +66,117 @@ def _is_usable_global_config(cfg: dict) -> bool:
     )
 
 
+def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
+    now = time.time() if now_ts is None else now_ts
+    stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
+    for cid in stale:
+        _runtime_cooldown_until.pop(cid, None)
+
+
+def _is_runtime_cooled_down(config_id: int) -> bool:
+    with _runtime_cooldown_lock:
+        _prune_runtime_cooldowns()
+        return config_id in _runtime_cooldown_until
+
+
+def mark_runtime_cooldown(
+    config_id: int,
+    *,
+    reason: str = "rate_limited",
+    cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
+) -> None:
+    """Temporarily suppress a config from Auto selection.
+
+    Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
+    config that is currently unhealthy does not get immediately reused on the
+    same thread during repair.
+    """
+    if cooldown_seconds <= 0:
+        cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
+    until = time.time() + int(cooldown_seconds)
+    with _runtime_cooldown_lock:
+        _runtime_cooldown_until[int(config_id)] = until
+        _prune_runtime_cooldowns()
+    # A cooled cfg can never be "recently healthy"; drop any stale credit so
+    # the next turn that resolves to it (after cooldown) re-runs preflight.
+    clear_healthy(int(config_id))
+    logger.info(
+        "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
+        config_id,
+        reason,
+        cooldown_seconds,
+    )
+
+
+def clear_runtime_cooldown(config_id: int | None = None) -> None:
+    """Test/ops helper to clear runtime cooldown entries."""
+    with _runtime_cooldown_lock:
+        if config_id is None:
+            _runtime_cooldown_until.clear()
+            return
+        _runtime_cooldown_until.pop(int(config_id), None)
+
+
+def _prune_healthy(now_ts: float | None = None) -> None:
+    now = time.time() if now_ts is None else now_ts
+    stale = [cid for cid, until in _healthy_until.items() if until <= now]
+    for cid in stale:
+        _healthy_until.pop(cid, None)
+
+
+def is_recently_healthy(config_id: int) -> bool:
+    """Return True if ``config_id`` passed preflight within the TTL window."""
+    with _healthy_lock:
+        _prune_healthy()
+        return int(config_id) in _healthy_until
+
+
+def mark_healthy(
+    config_id: int,
+    *,
+    ttl_seconds: int = _HEALTHY_TTL_SECONDS,
+) -> None:
+    """Record that ``config_id`` just passed a preflight probe.
+
+    Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The
+    healthy state is intentionally process-local — it's a latency hint, not a
+    correctness primitive — so multi-worker drift is acceptable.
+    """
+    if ttl_seconds <= 0:
+        ttl_seconds = _HEALTHY_TTL_SECONDS
+    until = time.time() + int(ttl_seconds)
+    with _healthy_lock:
+        _healthy_until[int(config_id)] = until
+        _prune_healthy()
+
+
+def clear_healthy(config_id: int | None = None) -> None:
+    """Drop one (or all) healthy-cache entries.
+
+    Called from runtime cooldown and OR catalogue refresh so a freshly cooled
+    or replaced config never carries stale "healthy" credit.
+    """
+    with _healthy_lock:
+        if config_id is None:
+            _healthy_until.clear()
+            return
+        _healthy_until.pop(int(config_id), None)
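
A toy timeline showing how the two caches interact (uses the helpers above; `_is_runtime_cooled_down` is module-private and touched here only for illustration):

```python
mark_healthy(42, ttl_seconds=45)
assert is_recently_healthy(42)          # next turn may skip the preflight ping

mark_runtime_cooldown(42, reason="rate_limited", cooldown_seconds=600)
assert not is_recently_healthy(42)      # healthy credit dropped with the cooldown
assert _is_runtime_cooled_down(42)      # Auto selection now skips cfg 42

clear_runtime_cooldown(42)              # test/ops reset
assert not _is_runtime_cooled_down(42)
```
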
+
+
 def _global_candidates() -> list[dict]:
+    """Return Auto-eligible global cfgs.
+
+    Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
+    below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
+    can't be picked as the thread's pin. Also excludes configs currently
+    in runtime cooldown (e.g. temporary 429 bursts).
+    """
     candidates = [
-        cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)
+        cfg
+        for cfg in config.GLOBAL_LLM_CONFIGS
+        if _is_usable_global_config(cfg)
+        and not cfg.get("health_gated")
+        and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
     ]
     return sorted(candidates, key=lambda c: int(c.get("id", 0)))
 
@@ -54,10 +185,26 @@ def _tier_of(cfg: dict) -> str:
     return str(cfg.get("billing_tier", "free")).lower()
 
 
-def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
+def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
+    """Pick a config with quality-first ranking + deterministic spread.
+
+    Tier policy is lock-first: prefer Tier A (operator-curated YAML)
+    cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
+    Tier A cfg is eligible after upstream filters. Within the locked
+    pool, sort by ``quality_score`` and pick from the top-K via
+    ``SHA256("auto_fastest:<thread_id>")`` so different new threads spread
+    across the best models without ever picking a low-ranked one.
+
+    Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
+    structured logging in the caller.
+    """
+    tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
+    pool = tier_a if tier_a else eligible
+    pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
+    top_k = pool[:_QUALITY_TOP_K]
     digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
-    idx = int.from_bytes(digest[:8], "big") % len(candidates)
-    return candidates[idx]
+    idx = int.from_bytes(digest[:8], "big") % len(top_k)
+    return top_k[idx], len(top_k)
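
The deterministic spread is easy to verify in isolation; this standalone sketch mirrors the hashing in `_select_pin` above with toy config ids:

```python
import hashlib

def pick(top_k: list[int], thread_id: int) -> int:
    # Same digest input shape as _select_pin: "auto_fastest:<thread_id>".
    digest = hashlib.sha256(f"auto_fastest:{thread_id}".encode()).digest()
    return top_k[int.from_bytes(digest[:8], "big") % len(top_k)]

top_k = [-10, -20, -30]                       # toy ids, best-first
assert pick(top_k, 7) == pick(top_k, 7)       # stable per thread
spread = {pick(top_k, tid) for tid in range(50)}
assert len(spread) > 1                        # different threads spread
```
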
 
 
 def _to_uuid(user_id: str | UUID | None) -> UUID | None:
@@ -89,11 +236,12 @@ async def resolve_or_get_pinned_llm_config_id(
     user_id: str | UUID | None,
     selected_llm_config_id: int,
     force_repin_free: bool = False,
+    exclude_config_ids: set[int] | None = None,
 ) -> AutoPinResolution:
-    """Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
+    """Resolve Auto (Fastest) to one concrete config id and persist the pin.
 
-    For non-auto selections, this function clears existing auto pin metadata and
-    returns the selected id as-is.
+    For non-auto selections, this function clears any existing pin and returns
+    the selected id as-is.
     """
     thread = (
         (
@@ -113,16 +261,10 @@ async def resolve_or_get_pinned_llm_config_id(
             f"Thread {thread_id} does not belong to search space {search_space_id}"
         )
 
-    # Explicit model selected: clear stale auto pin metadata.
+    # Explicit model selected: clear any stale pin.
     if selected_llm_config_id != AUTO_FASTEST_ID:
-        if (
-            thread.pinned_llm_config_id is not None
-            or thread.pinned_auto_mode is not None
-            or thread.pinned_at is not None
-        ):
+        if thread.pinned_llm_config_id is not None:
             thread.pinned_llm_config_id = None
-            thread.pinned_auto_mode = None
-            thread.pinned_at = None
             await session.commit()
         return AutoPinResolution(
             resolved_llm_config_id=selected_llm_config_id,
@@ -130,17 +272,19 @@ async def resolve_or_get_pinned_llm_config_id(
             from_existing_pin=False,
         )
 
-    candidates = _global_candidates()
+    excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
+    candidates = [
+        c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
+    ]
     if not candidates:
         raise ValueError("No usable global LLM configs are available for Auto mode")
 
     candidate_by_id = {int(c["id"]): c for c in candidates}
 
-    # Reuse existing valid pin without re-checking current quota (no silent tier switch),
-    # unless the caller explicitly requests a forced repin to free.
+    # Reuse an existing valid pin without re-checking current quota (no silent
+    # tier switch), unless the caller explicitly requests a forced repin to free.
     pinned_id = thread.pinned_llm_config_id
     if (
         not force_repin_free
-        and thread.pinned_auto_mode == AUTO_FASTEST_MODE
         and pinned_id is not None
         and int(pinned_id) in candidate_by_id
     ):
         pinned_cfg = candidate_by_id[int(pinned_id)]
@@ -152,6 +296,15 @@ async def resolve_or_get_pinned_llm_config_id(
             pinned_id,
             _tier_of(pinned_cfg),
         )
+        logger.info(
+            "auto_pin_resolved thread_id=%s config_id=%s tier=%s "
+            "auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
+            thread_id,
+            pinned_id,
+            _tier_of(pinned_cfg),
+            pinned_cfg.get("auto_pin_tier", "?"),
+            int(pinned_cfg.get("quality_score") or 0),
+        )
         return AutoPinResolution(
             resolved_llm_config_id=int(pinned_id),
             resolved_tier=_tier_of(pinned_cfg),
@@ -159,11 +312,10 @@ async def resolve_or_get_pinned_llm_config_id(
         )
     if pinned_id is not None:
         logger.info(
-            "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
+            "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
             thread_id,
             search_space_id,
             pinned_id,
-            thread.pinned_auto_mode,
         )
 
     premium_eligible = (
@@ -179,13 +331,11 @@ async def resolve_or_get_pinned_llm_config_id(
             "Auto mode could not find an eligible LLM config for this user and quota state"
         )
 
-    selected_cfg = _deterministic_pick(eligible, thread_id)
+    selected_cfg, top_k_size = _select_pin(eligible, thread_id)
     selected_id = int(selected_cfg["id"])
     selected_tier = _tier_of(selected_cfg)
 
     thread.pinned_llm_config_id = selected_id
-    thread.pinned_auto_mode = AUTO_FASTEST_MODE
-    thread.pinned_at = datetime.now(UTC)
     await session.commit()
 
     if force_repin_free:
@@ -216,6 +366,18 @@ async def resolve_or_get_pinned_llm_config_id(
         selected_tier,
         premium_eligible,
     )
+
+    logger.info(
+        "auto_pin_resolved thread_id=%s config_id=%s tier=%s "
+        "auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
+        thread_id,
+        selected_id,
+        selected_tier,
+        selected_cfg.get("auto_pin_tier", "?"),
+        int(selected_cfg.get("quality_score") or 0),
+        top_k_size,
+    )
+
     return AutoPinResolution(
         resolved_llm_config_id=selected_id,
         resolved_tier=selected_tier,
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index fbd42b458..8a7b2919a 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -208,6 +208,12 @@ class LLMRouterService:
         """
         Initialize the router with global LLM configurations.
 
+        Configs with ``router_pool_eligible=False`` (currently dynamic
+        OpenRouter *free* entries) are skipped so they stay out of the shared
+        router pool used by title-gen / sub-agent ``model="auto"`` flows.
+        Skipped entries remain available for user-facing Auto-mode thread
+        pinning via ``auto_model_pin_service``.
+
         Args:
             global_configs: List of global LLM config dictionaries from YAML
             router_settings: Optional router settings (routing_strategy, num_retries, etc.)
@@ -221,6 +227,8 @@ class LLMRouterService:
         model_list = []
         premium_models: set[str] = set()
         for config in global_configs:
+            if config.get("router_pool_eligible") is False:
+                continue
             deployment = cls._config_to_deployment(config)
             if deployment:
                 model_list.append(deployment)
@@ -309,10 +317,45 @@ class LLMRouterService:
             logger.error(f"Failed to initialize LLM Router: {e}")
             instance._router = None
 
+    @classmethod
+    def rebuild(
+        cls,
+        global_configs: list[dict],
+        router_settings: dict | None = None,
+    ) -> None:
+        """Reset the router and re-run ``initialize`` with fresh configs.
+
+        ``initialize`` short-circuits once it has run to avoid re-creating the
+        LiteLLM Router on every request; ``rebuild`` deliberately clears
+        ``_initialized`` so a caller (e.g. background OpenRouter refresh)
+        can force the pool to be rebuilt after catalogue changes.
+        """
+        instance = cls.get_instance()
+        instance._initialized = False
+        instance._router = None
+        instance._model_list = []
+        instance._premium_model_strings = set()
+        cls.initialize(global_configs, router_settings)
+
     @classmethod
     def is_premium_model(cls, model_string: str) -> bool:
-        """Return True if *model_string* (as reported by LiteLLM) belongs to a
-        premium-tier deployment in the router pool."""
+        """Return True if *model_string* belongs to a premium-tier deployment
+        in the LiteLLM router pool.
+
+        Scope: only covers configs with ``router_pool_eligible`` truthy. That
+        includes static YAML premium configs AND dynamic OpenRouter *premium*
+        entries (which opt in at generation time). Dynamic OpenRouter *free*
+        entries are deliberately kept out of the router pool — OpenRouter
+        enforces free-tier limits globally per account, so per-deployment
+        router accounting can't represent them correctly — and therefore
+        return ``False`` here, which matches their ``billing_tier="free"``
+        (no premium quota).
+
+        For per-request premium checks on an arbitrary config (static or
+        dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
+        that reflects the per-config ``billing_tier`` directly and is what
+        user-facing Auto-mode thread pinning uses to bill correctly.
+        """
         instance = cls.get_instance()
         return model_string in instance._premium_model_strings
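
Callers outside the scheduled refresh can force the same rebuild; this sketch mirrors the call the refresh path makes (hedged: it assumes `ROUTER_SETTINGS` may be absent, hence the `getattr` fallback):

```python
# Ops sketch: pick up catalogue changes without waiting for the refresh tick.
from app.config import config as app_config
from app.services.llm_router_service import LLMRouterService

LLMRouterService.rebuild(
    app_config.GLOBAL_LLM_CONFIGS,
    getattr(app_config, "ROUTER_SETTINGS", None),
)
```
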
diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py
index 1245f73aa..7e856d015 100644
--- a/surfsense_backend/app/services/openrouter_integration_service.py
+++ b/surfsense_backend/app/services/openrouter_integration_service.py
@@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
 """
 
 import asyncio
+import hashlib
 import logging
 import threading
+import time
 from typing import Any
 
 import httpx
 
+from app.services.quality_score import (
+    _HEALTH_BLEND_WEIGHT,
+    _HEALTH_ENRICH_CONCURRENCY,
+    _HEALTH_ENRICH_TOP_N_FREE,
+    _HEALTH_ENRICH_TOP_N_PREMIUM,
+    _HEALTH_FAIL_RATIO_FALLBACK,
+    _HEALTH_FETCH_TIMEOUT_SEC,
+    aggregate_health,
+    static_score_or,
+)
+
 logger = logging.getLogger(__name__)
 
 OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
+OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
+    "https://openrouter.ai/api/v1/models/{model_id}/endpoints"
+)
 
 # Sentinel value stored on each generated config so we can distinguish
 # dynamic OpenRouter entries from hand-written YAML entries during refresh.
 _OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
 
+# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
+# enough headroom to avoid frequent collisions for OpenRouter's catalogue
+# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
+_STABLE_ID_HASH_WIDTH = 9_000_000
+
+
+def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
+    """Derive a deterministic negative config ID from ``model_id``.
+
+    The same ``model_id`` always hashes to the same base value so thread pins
+    survive catalogue churn (models appearing/disappearing/reordering between
+    refreshes). On collision we decrement until we find an unused slot; this
+    keeps the mapping stable for the first config that claimed a slot and
+    only shifts collisions, which is much less disruptive than the legacy
+    index-based scheme that reshuffled every ID when the catalogue changed.
+    """
+    digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
+    base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
+    cid = base
+    while cid in taken:
+        cid -= 1
+    taken.add(cid)
+    return cid
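
Determinism is the whole point of this helper, and it's cheap to check — a toy sketch with an illustrative model id:

```python
taken_a: set[int] = set()
taken_b: set[int] = set()
a = _stable_config_id("anthropic/claude-sonnet-4", -10000, taken_a)
# Same model on a later refresh, against a differently populated catalogue:
b = _stable_config_id("anthropic/claude-sonnet-4", -10000, taken_b)
assert a == b          # pin survives catalogue churn (absent a collision)
assert a <= -10000     # always at or below the configured id_offset
```
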
+
+
+def _openrouter_tier(model: dict) -> str:
+    """Classify an OpenRouter model as ``"free"`` or ``"premium"``.
+
+    Per OpenRouter's API contract, a model is free if:
+    - Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
+    - Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.
+
+    Anything else (missing pricing, non-zero pricing) falls through to
+    ``"premium"`` so we never under-charge users. This derivation runs off the
+    already-cached /api/v1/models payload, so it adds no network cost.
+    """
+    if model.get("id", "").endswith(":free"):
+        return "free"
+    pricing = model.get("pricing") or {}
+    prompt = str(pricing.get("prompt", "")).strip()
+    completion = str(pricing.get("completion", "")).strip()
+    if prompt == "0" and completion == "0":
+        return "free"
+    return "premium"
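
Worked examples of the derivation (model ids and pricing values illustrative only):

```python
assert _openrouter_tier({"id": "meta-llama/llama-3.3-70b-instruct:free"}) == "free"
assert _openrouter_tier({"id": "qwen/qwen3-32b",
                         "pricing": {"prompt": "0", "completion": "0"}}) == "free"
assert _openrouter_tier({"id": "anthropic/claude-sonnet-4",
                         "pricing": {"prompt": "0.000003",
                                     "completion": "0.000015"}}) == "premium"
assert _openrouter_tier({"id": "some/model"}) == "premium"  # missing pricing
```
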
 
 
 def _is_text_output_model(model: dict) -> bool:
     """Return True if the model produces text output only (skip image/audio generators)."""
@@ -56,6 +117,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
     # Deep-research models reject standard params (temperature, etc.)
     "openai/o3-deep-research",
     "openai/o4-mini-deep-research",
+    # OpenRouter's own meta-router over free models. We already enumerate every
+    # concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
+    # pinning handles churn via the repair path, so exposing an additional
+    # indirection layer would only duplicate the capability with an opaque slug.
+    "openrouter/free",
 }
 
 _EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
@@ -113,20 +179,41 @@ def _generate_configs(
     raw_models: list[dict],
     settings: dict[str, Any],
 ) -> list[dict]:
-    """
-    Convert raw OpenRouter model entries into global LLM config dicts.
+    """Convert raw OpenRouter model entries into global LLM config dicts.
 
-    Models are sorted by ID for deterministic, stable ID assignment across
-    restarts and refreshes.
+    Tier (``billing_tier``) is derived per-model from OpenRouter's own API
+    signals via ``_openrouter_tier`` — there is no longer a uniform YAML
+    override. Config IDs are derived via ``_stable_config_id`` so they
+    survive catalogue churn across refreshes.
+
+    Router-pool membership is tier-aware:
+
+    - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
+      so sub-agent ``model="auto"`` flows benefit from load balancing and
+      failover across the curated YAML configs and the OR premium passthrough.
+    - Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
+      Router tracks rate limits per deployment, but OpenRouter enforces a
+      single global free-tier quota (~20 RPM + 50-1000 daily requests
+      account-wide across every ``:free`` model), so rotating across many
+      free deployments would only burn the shared bucket faster. Free OR
+      models remain fully available for user-facing Auto-mode thread pinning
+      via ``auto_model_pin_service``.
+
+    OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
+    via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
+    because our own Auto (Fastest) pin + 24 h refresh + repair logic already
+    cover the catalogue-churn case.
     """
     id_offset: int = settings.get("id_offset", -10000)
     api_key: str = settings.get("api_key", "")
-    billing_tier: str = settings.get("billing_tier", "premium")
-    anonymous_enabled: bool = settings.get("anonymous_enabled", False)
     seo_enabled: bool = settings.get("seo_enabled", False)
     quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
     rpm: int = settings.get("rpm", 200)
-    tpm: int = settings.get("tpm", 1000000)
+    tpm: int = settings.get("tpm", 1_000_000)
+    free_rpm: int = settings.get("free_rpm", 20)
+    free_tpm: int = settings.get("free_tpm", 100_000)
+    anon_paid: bool = settings.get("anonymous_enabled_paid", False)
+    anon_free: bool = settings.get("anonymous_enabled_free", False)
     litellm_params: dict = settings.get("litellm_params") or {}
     system_instructions: str = settings.get("system_instructions", "")
     use_default: bool = settings.get("use_default_system_instructions", True)
@@ -142,19 +229,24 @@ def _generate_configs(
         and _is_allowed_model(m)
         and "/" in m.get("id", "")
     ]
-    text_models.sort(key=lambda m: m["id"])
 
     configs: list[dict] = []
-    for idx, model in enumerate(text_models):
+    taken: set[int] = set()
+    now_ts = int(time.time())
+
+    for model in text_models:
         model_id: str = model["id"]
         name: str = model.get("name", model_id)
+        tier = _openrouter_tier(model)
+
+        static_q = static_score_or(model, now_ts=now_ts)
 
         cfg: dict[str, Any] = {
-            "id": id_offset - idx,
+            "id": _stable_config_id(model_id, id_offset, taken),
             "name": name,
             "description": f"{name} via OpenRouter",
-            "billing_tier": billing_tier,
-            "anonymous_enabled": anonymous_enabled,
+            "billing_tier": tier,
+            "anonymous_enabled": anon_free if tier == "free" else anon_paid,
             "seo_enabled": seo_enabled,
             "seo_slug": None,
             "quota_reserve_tokens": quota_reserve_tokens,
@@ -162,13 +254,28 @@ def _generate_configs(
             "model_name": model_id,
             "api_key": api_key,
             "api_base": "",
-            "rpm": rpm,
-            "tpm": tpm,
+            "rpm": free_rpm if tier == "free" else rpm,
+            "tpm": free_tpm if tier == "free" else tpm,
             "litellm_params": dict(litellm_params),
             "system_instructions": system_instructions,
             "use_default_system_instructions": use_default,
             "citations_enabled": citations_enabled,
+            # Premium OR deployments join the LiteLLM router pool so sub-agent
+            # model="auto" flows can load-balance / fail over across them.
+            # Free OR deployments stay out: OpenRouter's free tier is a single
+            # account-wide quota, so per-deployment routing can't spread load
+            # there — it just drains the shared bucket faster.
+            "router_pool_eligible": tier == "premium",
             _OPENROUTER_DYNAMIC_MARKER: True,
+            # Auto (Fastest) ranking metadata. ``quality_score`` is initialised
+            # to the static score and gets re-blended with health on the next
+            # ``_enrich_health`` pass (synchronous on refresh, deferred on cold
+            # start so startup latency is unchanged).
+            "auto_pin_tier": "B" if tier == "premium" else "C",
+            "quality_score_static": static_q,
+            "quality_score_health": None,
+            "quality_score": static_q,
+            "health_gated": False,
         }
 
         configs.append(cfg)
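
For orientation, a free model flowing through the loop above ends up with roughly this shape — all values invented; only the keys and the tier-dependent choices matter:

```python
cfg = {
    "id": -8_765_432,              # _stable_config_id(model_id, -10000, taken)
    "billing_tier": "free",        # derived by _openrouter_tier
    "anonymous_enabled": False,    # anon_free for free-tier models
    "router_pool_eligible": False, # free OR models stay out of the pool
    "auto_pin_tier": "C",          # dynamic free tier for _select_pin
    "rpm": 20, "tpm": 100_000,     # free_rpm / free_tpm knobs
    "quality_score_static": 47,    # invented; static score until health blends
    "quality_score": 47,
    "quality_score_health": None,
    "health_gated": False,
}
```
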
@@ -187,6 +294,12 @@ class OpenRouterIntegrationService:
         self._configs_by_id: dict[int, dict] = {}
         self._initialized = False
         self._refresh_task: asyncio.Task | None = None
+        # Last-good per-model health snapshot. Survives across refresh
+        # cycles so a transient OpenRouter /endpoints outage doesn't drop
+        # every cfg back to static-only scoring.
+        # Shape: {model_name: {"gated": bool, "score": float | None}}
+        self._health_cache: dict[str, dict[str, Any]] = {}
+        self._enrich_task: asyncio.Task | None = None
 
     @classmethod
     def get_instance(cls) -> "OpenRouterIntegrationService":
@@ -220,12 +333,27 @@ class OpenRouterIntegrationService:
         self._configs_by_id = {c["id"]: c for c in self._configs}
         self._initialized = True
 
+        tier_counts = self._tier_counts(self._configs)
         logger.info(
-            "OpenRouter integration: loaded %d models (IDs %d to %d)",
+            "OpenRouter integration: loaded %d models (free=%d, premium=%d)",
             len(self._configs),
-            self._configs[0]["id"] if self._configs else 0,
-            self._configs[-1]["id"] if self._configs else 0,
+            tier_counts["free"],
+            tier_counts["premium"],
         )
+
+        # Schedule the first health-enrichment pass as a deferred task so
+        # cold-start latency is unchanged. Only valid when an event loop is
+        # already running (e.g. FastAPI lifespan); Celery worker init is
+        # fully sync so we silently skip — its first refresh tick (or the
+        # next refresh from the web process) will populate health data.
+        try:
+            loop = asyncio.get_running_loop()
+            self._enrich_task = loop.create_task(
+                self._enrich_health_safely(self._configs)
+            )
+        except RuntimeError:
+            pass
+
         return self._configs
 
     # ------------------------------------------------------------------
@@ -254,7 +382,225 @@ class OpenRouterIntegrationService:
 
         self._configs = new_configs
         self._configs_by_id = new_by_id
-        logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
+        # Catalogue churn invalidates per-config "recently healthy" credit
+        # earned by the previous turn's preflight. Drop the whole table so
+        # the next turn re-probes against the freshly loaded configs.
+        try:
+            from app.services.auto_model_pin_service import clear_healthy
+
+            clear_healthy()
+        except Exception:
+            logger.debug(
+                "OpenRouter refresh: clear_healthy import skipped", exc_info=True
+            )
+
+        tier_counts = self._tier_counts(new_configs)
+        logger.info(
+            "OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
+            len(new_configs),
+            tier_counts["free"],
+            tier_counts["premium"],
+        )
+
+        # Re-blend health scores against the freshly fetched catalogue. Also
+        # re-stamps health for any YAML-curated cfg with provider==OPENROUTER
+        # so a hand-picked dead OR model is gated like a dynamic one.
+        await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
+
+        # Rebuild the LiteLLM router so freshly fetched configs flow through
+        # (dynamic OR premium entries now opt into the pool, free ones stay
+        # out; a refresh also needs to pick up any static-config edits and
+        # reset cached context-window profiles).
+        try:
+            from app.config import config as _app_config
+            from app.services.llm_router_service import (
+                LLMRouterService,
+                _router_instance_cache as _chat_router_cache,
+            )
+
+            LLMRouterService.rebuild(
+                _app_config.GLOBAL_LLM_CONFIGS,
+                getattr(_app_config, "ROUTER_SETTINGS", None),
+            )
+            _chat_router_cache.clear()
+        except Exception as exc:
+            logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
+
+    @staticmethod
+    def _tier_counts(configs: list[dict]) -> dict[str, int]:
+        counts = {"free": 0, "premium": 0}
+        for cfg in configs:
+            tier = str(cfg.get("billing_tier", "")).lower()
+            if tier in counts:
+                counts[tier] += 1
+        return counts
+
+    # ------------------------------------------------------------------
+    # Auto (Fastest) health enrichment
+    # ------------------------------------------------------------------
+
+    async def _enrich_health_safely(
+        self, configs: list[dict], *, log_summary: bool = True
+    ) -> None:
+        """Wrapper around ``_enrich_health`` that swallows all errors.
+
+        Health enrichment is best-effort: any failure must leave cfgs in
+        their static-only state and never break refresh / startup.
+        """
+        try:
+            await self._enrich_health(configs, log_summary=log_summary)
+        except Exception:
+            logger.exception("OpenRouter health enrichment failed")
+
+    async def _enrich_health(
+        self, configs: list[dict], *, log_summary: bool = True
+    ) -> None:
+        """Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
+        the resulting health score into ``cfg["quality_score"]``.
+
+        Bounded fan-out: top-N per tier by ``quality_score_static`` only,
+        with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
+        outbound HTTP. Misses fall back to a per-model last-good cache; if
+        the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
+        the entire previous cycle's cache for this run.
+        """
+        or_cfgs = [
+            c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
+        ]
+        if not or_cfgs:
+            return
+
+        premium_pool = sorted(
+            [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
+            key=lambda c: -int(c.get("quality_score_static") or 0),
+        )[:_HEALTH_ENRICH_TOP_N_PREMIUM]
+        free_pool = sorted(
+            [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
+            key=lambda c: -int(c.get("quality_score_static") or 0),
+        )[:_HEALTH_ENRICH_TOP_N_FREE]
+        # De-duplicate while preserving order: a cfg shouldn't fall in both
+        # tiers, but defensive code is cheap here.
+        seen_ids: set[int] = set()
+        selected: list[dict] = []
+        for cfg in premium_pool + free_pool:
+            cid = int(cfg.get("id", 0))
+            if cid in seen_ids:
+                continue
+            seen_ids.add(cid)
+            selected.append(cfg)
+
+        if not selected:
+            return
+
+        api_key = str(self._settings.get("api_key") or "")
+        semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
+
+        async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
+            results = await asyncio.gather(
+                *(
+                    self._fetch_endpoints(client, semaphore, api_key, cfg)
+                    for cfg in selected
+                )
+            )
+
+        fail_count = sum(1 for _, _, err in results if err is not None)
+        fail_ratio = fail_count / len(results) if results else 0.0
+        degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
+        if degraded:
+            logger.warning(
+                "auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
+                "using_last_good_cache=true",
+                fail_ratio,
+                len(results),
+            )
+
+        # Per-cfg health update.
+        for cfg, endpoints, err in results:
+            model_name = str(cfg.get("model_name", ""))
+            if not degraded and err is None and endpoints is not None:
+                gated, h_score = aggregate_health(endpoints)
+                cfg["health_gated"] = bool(gated)
+                cfg["quality_score_health"] = h_score
+                self._health_cache[model_name] = {
+                    "gated": bool(gated),
+                    "score": h_score,
+                }
+            else:
+                cached = self._health_cache.get(model_name)
+                if cached is not None:
+                    cfg["health_gated"] = bool(cached.get("gated", False))
+                    cfg["quality_score_health"] = cached.get("score")
+                # else: keep current values (initial defaults from
+                # _generate_configs / load_global_llm_configs).
+
+        # Blend health into the final score for every OR cfg, including
+        # those outside the enriched top-N (they fall through to static).
+        gated_count = 0
+        by_provider: dict[str, int] = {}
+        for cfg in or_cfgs:
+            static_q = int(cfg.get("quality_score_static") or 0)
+            h = cfg.get("quality_score_health")
+            if h is not None and not cfg.get("health_gated"):
+                blended = (
+                    _HEALTH_BLEND_WEIGHT * float(h)
+                    + (1 - _HEALTH_BLEND_WEIGHT) * static_q
+                )
+                cfg["quality_score"] = round(blended)
+            else:
+                cfg["quality_score"] = static_q
+
+            if cfg.get("health_gated"):
+                gated_count += 1
+                model_id = str(cfg.get("model_name", ""))
+                provider_slug = (
+                    model_id.split("/", 1)[0] if "/" in model_id else "unknown"
+                )
+                by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
+
+        if log_summary:
+            logger.info(
+                "auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
+                "total_enriched=%d",
+                gated_count,
+                dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
+                fail_ratio,
+                len(selected),
+            )
+
+    @staticmethod
+    async def _fetch_endpoints(
+        client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
+        api_key: str,
+        cfg: dict,
+    ) -> tuple[dict, list[dict] | None, Exception | None]:
+        """Fetch ``/api/v1/models/{id}/endpoints`` for one cfg.
+
+        Returns ``(cfg, endpoints, err)`` so the caller can keep batched
+        results aligned with their cfgs without raising.
+        """
+        model_id = str(cfg.get("model_name", ""))
+        if not model_id:
+            return cfg, None, ValueError("missing model_name")
+
+        url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id)
+        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+
+        async with semaphore:
+            try:
+                resp = await client.get(url, headers=headers)
+                resp.raise_for_status()
+                data = resp.json()
+            except Exception as exc:
+                return cfg, None, exc
+
+        payload = data.get("data") if isinstance(data, dict) else None
+        if not isinstance(payload, dict):
+            return cfg, None, ValueError("malformed endpoints payload")
+        endpoints = payload.get("endpoints")
+        if not isinstance(endpoints, list):
+            return cfg, [], None
+        return cfg, endpoints, None
 
     async def _refresh_loop(self, interval_hours: float) -> None:
         interval_sec = interval_hours * 3600
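
The blend arithmetic is deliberately trivial; a worked example using the 0.5 weight defined in quality_score below:

```python
_HEALTH_BLEND_WEIGHT = 0.5   # mirrors quality_score._HEALTH_BLEND_WEIGHT
static_q, h = 70, 90.0       # static score 70, fresh non-gated health 90
blended = _HEALTH_BLEND_WEIGHT * h + (1 - _HEALTH_BLEND_WEIGHT) * static_q
assert round(blended) == 80  # 0.5 * 90 + 0.5 * 70
# A gated cfg ignores health entirely and falls back to static_q (70).
```
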
diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py
new file mode 100644
index 000000000..2fb37de21
--- /dev/null
+++ b/surfsense_backend/app/services/quality_score.py
@@ -0,0 +1,380 @@
+"""Pure-function quality scoring for Auto (Fastest) model selection.
+
+This module imports nothing from the service / request-path layers. All
+numbers are computed once during the OpenRouter refresh tick (or YAML load)
+and cached on the cfg dict, so the chat hot path only does a precomputed
+sort and a SHA256 pick.
+
+Score components (0-100 scale, higher is better):
+
+* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
+  (provider prestige + ``created`` recency + pricing band + context window +
+  capabilities + narrow tiny/legacy slug penalty).
+* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
+  an operator-trust bonus (the operator deliberately picked this model).
+* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
+  responses; returns ``(gated, score_or_none)``.
+
+The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
+:mod:`app.services.openrouter_integration_service` because that's the only
+caller that sees both halves.
+"""
+
+from __future__ import annotations
+
+# ---------------------------------------------------------------------------
+# Tunables (constants, not flags)
+# ---------------------------------------------------------------------------
+
+# Top-K size for deterministic spread inside the locked tier.
+_QUALITY_TOP_K: int = 5
+
+# Hard health gate: any cfg whose best non-null uptime is below this %
+# is excluded from Auto-mode selection entirely.
+_HEALTH_GATE_UPTIME_PCT: float = 90.0
+
+# Health/static blend weight when a cfg has fresh /endpoints data.
+_HEALTH_BLEND_WEIGHT: float = 0.5
+
+# Static bonus applied to YAML cfgs because the operator hand-picked them.
+_OPERATOR_TRUST_BONUS: int = 20
+
+# /endpoints fan-out is bounded per refresh tick.
+_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
+_HEALTH_ENRICH_TOP_N_FREE: int = 30
+_HEALTH_ENRICH_CONCURRENCY: int = 15
+_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
+
+# If at least this fraction of /endpoints fetches fail in a refresh cycle,
+# fall back to the previous cycle's last-good cache instead of writing
+# partial / stale health values.
+_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
+
+# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
+# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
+# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
+# blanket-penalising them suppresses high-quality picks.
+_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
+    "-1b-",
+    "-1.2b-",
+    "-1.5b-",
+    "-2b-",
+    "-3b-",
+    "gemma-3n",
+    "lfm-",
+    "-base",
+    "-distill",
+    ":nitro",
+    "-preview",
+)
+
+
+# ---------------------------------------------------------------------------
+# Provider prestige tables
+# ---------------------------------------------------------------------------
+
+# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
+# Tiers are coarse: frontier labs > strong open / fast-moving labs >
+# specialist labs > everything else.
+PROVIDER_PRESTIGE_OR: dict[str, int] = {
+    # Frontier labs
+    "openai": 50,
+    "anthropic": 50,
+    "google": 50,
+    "x-ai": 50,
+    # Strong open / fast-moving labs
+    "deepseek": 38,
+    "qwen": 38,
+    "meta-llama": 38,
+    "mistralai": 38,
+    "cohere": 38,
+    "nvidia": 38,
+    "alibaba": 38,
+    # Specialist / regional / strong second-tier
+    "microsoft": 28,
+    "01-ai": 28,
+    "minimax": 28,
+    "moonshot": 28,
+    "z-ai": 28,
+    "nousresearch": 28,
+    "ai21": 28,
+    "perplexity": 28,
+    # Smaller / niche providers
+    "liquid": 18,
+    "cognitivecomputations": 18,
+    "venice": 18,
+    "inflection": 18,
+}
+
+# YAML provider field (the upstream API shape the operator selected).
+PROVIDER_PRESTIGE_YAML: dict[str, int] = { + "AZURE_OPENAI": 50, + "OPENAI": 50, + "ANTHROPIC": 50, + "GOOGLE": 50, + "VERTEX_AI": 50, + "GEMINI": 50, + "XAI": 50, + "MISTRAL": 38, + "DEEPSEEK": 38, + "COHERE": 38, + "GROQ": 30, + "TOGETHER_AI": 28, + "FIREWORKS_AI": 28, + "PERPLEXITY": 28, + "MINIMAX": 28, + "BEDROCK": 28, + "OPENROUTER": 25, + "OLLAMA": 12, + "CUSTOM": 12, +} + + +# --------------------------------------------------------------------------- +# Pure scoring helpers +# --------------------------------------------------------------------------- + +# Calibrated against the live /api/v1/models bulk dump. Frontier models +# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5, +# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band; +# anything older trails off. +_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = ( + (60, 20), + (180, 16), + (365, 12), + (540, 9), + (730, 6), + (1095, 3), +) + + +def created_recency_signal(created_ts: int | None, now_ts: int) -> int: + """Return 0-20 based on how recently the model was published. + + Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for + YAML cfgs). Models without a usable timestamp get 0 (we don't penalise, + we just don't reward). + """ + if created_ts is None or created_ts <= 0 or now_ts <= 0: + return 0 + age_days = max(0, (now_ts - int(created_ts)) // 86_400) + for cutoff, score in _RECENCY_BANDS_DAYS: + if age_days <= cutoff: + return score + return 0 + + +def pricing_band( + prompt: str | float | int | None, + completion: str | float | int | None, +) -> int: + """Return 0-15 based on combined prompt+completion cost per 1M tokens. + + Higher-priced models tend to be the larger / more capable ones. A free + model returns 0 (we use other signals to rank free-vs-free instead). + Uncoercible inputs are treated as 0 rather than raising. 
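+
+    Worked examples (per-token inputs; $18 per 1M combined lands in the
+    second band)::
+
+        >>> pricing_band("0.000003", "0.000015")
+        12
+        >>> pricing_band(None, "not-a-number")
+        0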
+ """ + + def _to_float(value) -> float: + if value is None: + return 0.0 + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + p = _to_float(prompt) + c = _to_float(completion) + total_per_million = (p + c) * 1_000_000 + + if total_per_million >= 20.0: + return 15 + if total_per_million >= 5.0: + return 12 + if total_per_million >= 1.0: + return 9 + if total_per_million >= 0.3: + return 6 + if total_per_million >= 0.05: + return 4 + if total_per_million > 0.0: + return 2 + return 0 + + +def context_signal(ctx: int | None) -> int: + """Return 0-10 based on the model's context window.""" + if not ctx or ctx <= 0: + return 0 + if ctx >= 1_000_000: + return 10 + if ctx >= 400_000: + return 8 + if ctx >= 200_000: + return 6 + if ctx >= 128_000: + return 4 + if ctx >= 100_000: + return 2 + return 0 + + +def capabilities_signal(supported_parameters: list[str] | None) -> int: + """Return 0-5 for capabilities that matter for our agent flows.""" + if not supported_parameters: + return 0 + params = set(supported_parameters) + score = 0 + if "tools" in params: + score += 2 + if "structured_outputs" in params or "response_format" in params: + score += 2 + if "reasoning" in params or "include_reasoning" in params: + score += 1 + return min(score, 5) + + +def slug_penalty(model_id: str) -> int: + """Return a non-positive number; matches the narrow tiny/legacy patterns.""" + if not model_id: + return 0 + needle = model_id.lower() + for pattern in _TINY_LEGACY_PENALTY_PATTERNS: + if pattern in needle: + return -10 + return 0 + + +def _provider_prestige_or(model_id: str) -> int: + if "/" not in model_id: + return 0 + slug = model_id.split("/", 1)[0].lower() + return PROVIDER_PRESTIGE_OR.get(slug, 15) + + +def static_score_or(or_model: dict, *, now_ts: int) -> int: + """Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale.""" + model_id = str(or_model.get("id", "")) + pricing = or_model.get("pricing") or {} + + score = ( + _provider_prestige_or(model_id) + + created_recency_signal(or_model.get("created"), now_ts) + + pricing_band(pricing.get("prompt"), pricing.get("completion")) + + context_signal(or_model.get("context_length")) + + capabilities_signal(or_model.get("supported_parameters")) + + slug_penalty(model_id) + ) + return max(0, min(100, int(score))) + + +def static_score_yaml(cfg: dict) -> int: + """Score a YAML-curated cfg on a 0-100 scale. + + Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately + listed this model. Pricing / context fall through to lazy ``litellm`` + lookups; failures are silent (we just lose those sub-points). + """ + provider = str(cfg.get("provider", "")).upper() + base = PROVIDER_PRESTIGE_YAML.get(provider, 15) + + model_name = cfg.get("model_name") or "" + litellm_params = cfg.get("litellm_params") or {} + lookup_name = ( + litellm_params.get("base_model") or litellm_params.get("model") or model_name + ) + + ctx = 0 + p_cost: float = 0.0 + c_cost: float = 0.0 + try: + from litellm import get_model_info # lazy: avoid cold-import cost + + info = get_model_info(lookup_name) or {} + ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0) + p_cost = float(info.get("input_cost_per_token") or 0.0) + c_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception: + # Unknown to litellm — that's fine for prestige+operator-bonus weighting. 
+ pass + + score = ( + base + + _OPERATOR_TRUST_BONUS + + pricing_band(p_cost, c_cost) + + context_signal(ctx) + + slug_penalty(str(model_name)) + ) + return max(0, min(100, int(score))) + + +# --------------------------------------------------------------------------- +# Health aggregation +# --------------------------------------------------------------------------- + + +def _coerce_pct(value) -> float | None: + try: + if value is None: + return None + f = float(value) + except (TypeError, ValueError): + return None + if f < 0: + return None + # OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it + # as a 0-100 percentage. Normalise. + return f * 100.0 if f <= 1.0 else f + + +def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]: + """Pick the best (highest) non-null uptime across all endpoints. + + Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` > + ``uptime_last_5m``. Returns ``(uptime_pct, window_used)``. + """ + for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + return max(values), window + return None, None + + +def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]: + """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``. + + Hard gate (returns ``(True, None)``): + * ``endpoints`` empty, + * no endpoint reports ``status == 0`` (OK), or + * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``. + + On a pass, returns a 0-100 health score blending uptime, status, and a + freshness-weighted recent uptime sample. + """ + if not endpoints: + return True, None + + any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints) + if not any_ok: + return True, None + + best_uptime, _ = _best_uptime(endpoints) + if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT: + return True, None + + # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing. 
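+    # Worked example: best uptime 98, at least one OK endpoint, 5m sample 95
+    # -> 0.50*98 + 0.30*100 + 0.20*95 = 98.0 (weights applied below).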
+ freshness = None + for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + freshness = max(values) + break + + uptime_term = best_uptime + status_term = 100.0 if any_ok else 0.0 + freshness_term = freshness if freshness is not None else best_uptime + + score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term + return False, max(0.0, min(100.0, score)) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5abcb63eb..dbfe9a67b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -64,7 +64,12 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT -from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id +from app.services.auto_model_pin_service import ( + is_recently_healthy, + mark_healthy, + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -299,20 +304,17 @@ def _tool_output_has_error(tool_output: Any) -> bool: return False -def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None: +def _extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: if isinstance(tool_output, dict): path_value = tool_output.get("path") if isinstance(path_value, str) and path_value.strip(): return path_value.strip() - text = _tool_output_to_text(tool_output) - if tool_name == "write_file": - match = re.search(r"Updated file\s+(.+)$", text.strip()) - if match: - return match.group(1).strip() - if tool_name == "edit_file": - match = re.search(r"in '([^']+)'", text) - if match: - return match.group(1).strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip() return None @@ -414,6 +416,108 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None: return None +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _is_provider_rate_limited(exc: BaseException) -> bool: + """Best-effort detection for provider-side runtime throttling. 
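+
+    A hypothetical worked example (assuming ``_parse_error_payload`` can
+    recover the JSON body from ``str(exc)``)::
+
+        >>> class FakeProviderError(Exception): ...
+        >>> _is_provider_rate_limited(FakeProviderError('{"error": {"code": 429}}'))
+        True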
+ + Covers LiteLLM/OpenRouter shapes like: + - class name contains ``RateLimit`` + - nested payload ``{"error": {"code": 429}}`` + - nested payload ``{"error": {"type": "rate_limit_error"}}`` + """ + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + +_PREFLIGHT_TIMEOUT_SEC: float = 2.5 +_PREFLIGHT_MAX_TOKENS: int = 1 + + +async def _preflight_llm(llm: Any) -> None: + """Issue a minimal completion to confirm the pinned model isn't 429'ing. + + Used before agent build / planner / classifier / title-gen so a known-bad + free OpenRouter deployment is detected and repinned before it cascades + into multiple wasted internal calls. The probe is intentionally cheap: + one token, low timeout, tagged ``surfsense:internal`` so token tracking + and SSE pipelines treat it as overhead rather than user output. + + Raises the original exception when the provider responds with a + rate-limit-shaped error so the caller can drive the cooldown/repin + branch via :func:`_is_provider_rate_limited`. Other transient failures + are swallowed — the caller continues to the normal stream path and the + in-stream recovery loop remains the safety net. + """ + from litellm import acompletion + + model = getattr(llm, "model", None) + if not model or model == "auto": + # Auto-mode router doesn't have a single deployment to ping; the + # router itself handles per-deployment rate-limit accounting. + return + + try: + await acompletion( + model=model, + messages=[{"role": "user", "content": "ping"}], + api_key=getattr(llm, "api_key", None), + api_base=getattr(llm, "api_base", None), + max_tokens=_PREFLIGHT_MAX_TOKENS, + timeout=_PREFLIGHT_TIMEOUT_SEC, + stream=False, + metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]}, + ) + except Exception as exc: + if _is_provider_rate_limited(exc): + raise + logging.getLogger(__name__).debug( + "auto_pin_preflight non_rate_limit_error model=%s err=%s", + model, + exc, + ) + + def _classify_stream_exception( exc: Exception, *, @@ -449,19 +553,7 @@ def _classify_stream_exception( None, ) - parsed = _parse_error_payload(raw) - provider_error_type = "" - if parsed: - top_type = parsed.get("type") - if isinstance(top_type, str): - provider_error_type = top_type.lower() - nested = parsed.get("error") - if isinstance(nested, dict): - nested_type = nested.get("type") - if isinstance(nested_type, str): - provider_error_type = nested_type.lower() - - if provider_error_type == "rate_limit_error": + if _is_provider_rate_limited(exc): return ( "rate_limited", "RATE_LIMITED", @@ -619,6 +711,7 @@ async def _stream_agent_events( # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + file_path_by_run: dict[str, str] = {} # parity_v2 only: live tool-call argument streaming. 
``index_to_meta`` # is keyed by the chunk's ``index`` field — LangChain @@ -797,6 +890,10 @@ async def _stream_agent_events( tool_input = event.get("data", {}).get("input", {}) if tool_name in ("write_file", "edit_file"): result.write_attempted = True + if isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip() and run_id: + file_path_by_run[run_id] = file_path.strip() if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -1203,6 +1300,7 @@ async def _stream_agent_events( run_id = event.get("run_id", "") tool_name = event.get("name", "unknown_tool") raw_output = event.get("data", {}).get("output", "") + staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None if tool_name == "update_memory": called_update_memory = True @@ -1716,6 +1814,9 @@ async def _stream_agent_events( resolved_path = _extract_resolved_file_path( tool_name=tool_name, tool_output=tool_output, + tool_input={"file_path": staged_file_path} + if staged_file_path + else None, ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): @@ -2326,6 +2427,91 @@ async def stream_new_chat( yield streaming_service.format_done() return + # Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs + # (negative ids selected via ``resolve_or_get_pinned_llm_config_id``) + # whose health hasn't already been confirmed within the TTL window. + # Detecting a 429 here lets us repin BEFORE the planner/classifier/ + # title-generation LLM calls fan out and each independently hit the + # same upstream rate limit. + if ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ): + _t_preflight = time.perf_counter() + try: + await _preflight_llm(llm) + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + # Trust the freshly-resolved cfg for the remainder of this + # turn rather than recursing into another preflight; the + # in-stream 429 recovery loop is still in place as the + # safety net if even this fallback hits an upstream cap. 
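+            # Seeding the TTL health cache here means the next turn on this
+            # thread skips the preflight probe for the fallback config.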
+ mark_healthy(llm_config_id) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + # Create connector service _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) @@ -2671,54 +2857,155 @@ async def stream_new_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - initial_step_title=initial_title, - initial_step_items=initial_items, - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background task finishes - if title_task is not None and title_task.done() and not title_emitted: - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=input_state, + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking", + initial_step_id=initial_step_id, + initial_step_title=initial_title, + initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title + _first_event_logged = True + yield sse + + # Inject title update mid-stream as soon as the background + # task finishes. 
+ if ( + title_task is not None + and title_task.done() + and not title_emitted + ): + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter( + NewChatThread.id == chat_id + ) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update( + chat_id, generated_title + ) + title_emitted = True + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) + ) + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + # The failed attempt may still hold the per-thread busy mutex + # (middleware teardown can lag behind raised provider errors). + # Force release before we retry within the same request. + end_turn(str(chat_id)) + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, ) - title_emitted = True + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error: + raise stream_exc + + # Title generation uses the initial llm object. After a runtime + # repin we keep the stream focused on response recovery and skip + # title generation for this turn. + if title_task is not None and not title_task.done(): + title_task.cancel() + title_task = None + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_new_chat] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", @@ -3187,6 +3474,84 @@ async def stream_resume_chat( yield streaming_service.format_done() return + # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: + # one cheap probe before the agent is rebuilt so a 429'd pin gets + # repinned without burning planner/classifier/title calls first. 
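+    # (As on the new-chat path: a requested id of 0 means Auto, and a
+    # negative resolved id denotes a thread-pinned global config.)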
+ if ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ): + _t_preflight = time.perf_counter() + try: + await _preflight_llm(llm) + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + mark_healthy(llm_config_id) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) @@ -3265,31 +3630,114 @@ async def stream_resume_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume={"decisions": decisions}), - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking-resume", - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=Command(resume={"decisions": decisions}), + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - 
_t_total, + chat_id, + ) + _first_event_logged = True + yield sse + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) ) - _first_event_logged = True - yield sse + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + # Ensure the same-request recovery retry does not trip the + # BusyMutex lock retained by the failed attempt. + end_turn(str(chat_id)) + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error: + raise stream_exc + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_resume] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", time.perf_counter() - _t_stream_start, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index c923dc499..f0161f605 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None: assert not manager.lock_for(thread_id).locked() assert not get_cancel_event(thread_id).is_set() assert is_cancel_requested(thread_id) is False + + +@pytest.mark.asyncio +async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None: + """A stale aafter call from attempt A must not unlock attempt B. + + Repro flow: + 1) attempt A acquires thread lock + 2) forced end_turn clears A so retry can proceed + 3) attempt B acquires same thread lock + 4) stale attempt-A aafter runs late + + Expected: B lock remains held. 
+ """ + thread_id = "stale-aafter-lock" + runtime = _Runtime(thread_id) + attempt_a = BusyMutexMiddleware() + attempt_b = BusyMutexMiddleware() + + await attempt_a.abefore_agent({}, runtime) + lock = manager.lock_for(thread_id) + assert lock.locked() + + end_turn(thread_id) + assert not lock.locked() + + await attempt_b.abefore_agent({}, runtime) + assert lock.locked() + + # Stale cleanup from attempt A must not release attempt B's lock. + await attempt_a.aafter_agent({}, runtime) + assert lock.locked() + + await attempt_b.aafter_agent({}, runtime) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 0a2342e05..49b3621c7 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,13 +6,26 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( - AUTO_FASTEST_MODE, + clear_healthy, + clear_runtime_cooldown, + is_recently_healthy, + mark_healthy, + mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) pytestmark = pytest.mark.unit +@pytest.fixture(autouse=True) +def _clear_runtime_cooldown_map(): + clear_runtime_cooldown() + clear_healthy() + yield + clear_runtime_cooldown() + clear_healthy() + + @dataclass class _FakeQuotaResult: allowed: bool @@ -45,14 +58,11 @@ def _thread( *, search_space_id: int = 10, pinned_llm_config_id: int | None = None, - pinned_auto_mode: str | None = None, ): return SimpleNamespace( id=1, search_space_id=search_space_id, pinned_llm_config_id=pinned_llm_config_id, - pinned_auto_mode=pinned_auto_mode, - pinned_at=None, ) @@ -93,8 +103,6 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): ) assert result.resolved_llm_config_id in {-1, -2} assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id - assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE - assert session.thread.pinned_at is not None assert session.commit_count == 1 @@ -102,9 +110,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -228,9 +234,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -275,9 +279,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -325,9 +327,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-2, 
pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-2)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -345,8 +345,6 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): ) assert result.resolved_llm_config_id == 7 assert session.thread.pinned_llm_config_id is None - assert session.thread.pinned_auto_mode is None - assert session.thread.pinned_at is None assert session.commit_count == 1 @@ -354,9 +352,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-999)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -383,3 +379,543 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): assert result.resolved_llm_config_id == -2 assert session.thread.pinned_llm_config_id == -2 assert session.commit_count == 1 + + +# --------------------------------------------------------------------------- +# Quality-aware pin selection (Auto Fastest upgrade) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_health_gated_config_is_excluded_from_selection(monkeypatch): + """A cfg flagged ``health_gated`` must never be picked even if it has + the highest score among eligible cfgs.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 95, + "health_gated": True, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): + """Premium-eligible users with Tier A available should never spill to + Tier B even if a B cfg ranks higher by ``quality_score``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 70, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5", + "api_key": "k-or", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 95, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + 
search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch): + """Free-only user with no Tier A free cfg should pick from Tier C.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash:free", + "api_key": "k-or", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_top_k_picks_only_high_score_models(monkeypatch): + """Different thread IDs should spread across top-K, never pick the + obvious low-quality cfg even when it sits in the candidate list.""" + from app.config import config + + high_score_cfgs = [ + { + "id": -i, + "provider": "AZURE_OPENAI", + "model_name": f"gpt-x-{i}", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + } + for i in range(1, 6) # 5 high-quality Tier A cfgs + ] + low_score_trap = { + "id": -99, + "provider": "AZURE_OPENAI", + "model_name": "tiny-legacy", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + "health_gated": False, + } + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [*high_score_cfgs, low_score_trap], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + high_score_ids = {c["id"] for c in high_score_cfgs} + seen = set() + for thread_id in range(1, 50): + session = _FakeSession(_thread()) + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + seen.add(result.resolved_llm_config_id) + assert result.resolved_llm_config_id != -99, ( + "low-score trap cfg should never be picked" + ) + assert result.resolved_llm_config_id in high_score_ids + + # Spread across at least a couple of top-K cfgs. + assert len(seen) > 1 + + +@pytest.mark.asyncio +async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): + """An *already* pinned cfg that later flips to ``health_gated`` should + still not be reused — gated cfgs are filtered out of the candidate + pool, which forces a repair to a healthy cfg. 
+ + This guards the no-silent-tier-switch invariant: we don't keep using + a known-broken model just because the thread happened to be pinned + to it before the gate fired.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 50, + "health_gated": True, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): + """Existing pin reuse must short-circuit the new tier/score logic.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 50, # lower than -2 + "health_gated": False, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5-pro", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 99, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): + """A runtime-cooled config should be excluded from candidate reuse. + + This enables one-shot recovery from transient provider 429 bursts: we can + mark the pinned cfg as cooled down and force a repair to another eligible + cfg on the next resolution. 
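+
+    Sketch of the recovery sequence exercised below::
+
+        mark_runtime_cooldown(-1, reason="provider_rate_limited",
+                              cooldown_seconds=600)
+        # the next resolution skips -1 and repairs the pin to -2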
+ """ + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on healthy pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + clear_runtime_cooldown(-1) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch): + """Runtime retry should never repin the just-failed config.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + exclude_config_ids={-1}, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +# 
--------------------------------------------------------------------------- +# Healthy-status cache (preflight TTL companion) +# --------------------------------------------------------------------------- + + +def test_mark_healthy_then_is_recently_healthy_true_within_ttl(): + mark_healthy(-42, ttl_seconds=60) + assert is_recently_healthy(-42) is True + + +def test_healthy_expires_after_ttl(monkeypatch): + import app.services.auto_model_pin_service as svc + + real_time = svc.time.time + base = real_time() + + monkeypatch.setattr(svc.time, "time", lambda: base) + mark_healthy(-7, ttl_seconds=10) + assert is_recently_healthy(-7) is True + + monkeypatch.setattr(svc.time, "time", lambda: base + 11) + assert is_recently_healthy(-7) is False + + +def test_mark_runtime_cooldown_invalidates_healthy_cache(): + mark_healthy(-9, ttl_seconds=60) + assert is_recently_healthy(-9) is True + + mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60) + assert is_recently_healthy(-9) is False + + +def test_clear_healthy_removes_single_entry(): + mark_healthy(-11, ttl_seconds=60) + mark_healthy(-12, ttl_seconds=60) + clear_healthy(-11) + assert is_recently_healthy(-11) is False + assert is_recently_healthy(-12) is True + + +def test_clear_healthy_no_args_drops_all_entries(): + mark_healthy(-21, ttl_seconds=60) + mark_healthy(-22, ttl_seconds=60) + clear_healthy() + assert is_recently_healthy(-21) is False + assert is_recently_healthy(-22) is False diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py new file mode 100644 index 000000000..c309ff881 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -0,0 +1,226 @@ +"""LLMRouterService pool-filter / rebuild tests. + +These tests focus on the *config plumbing* (which configs enter the router +pool, rebuild resets state correctly). They stub out the underlying +``litellm.Router`` so we don't need real API keys or network access. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from app.services.llm_router_service import LLMRouterService + +pytestmark = pytest.mark.unit + + +def _fake_yaml_config( + *, + id: int, + model_name: str, + billing_tier: str = "free", +) -> dict: + return { + "id": id, + "name": f"yaml-{id}", + "provider": "OPENAI", + "model_name": model_name, + "api_key": "sk-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 100, + "tpm": 100_000, + "litellm_params": {}, + } + + +def _fake_openrouter_config( + *, + id: int, + model_name: str, + billing_tier: str, + router_pool_eligible: bool | None = None, +) -> dict: + """Build a synthetic dynamic-OR config dict for router-pool tests. + + Defaults mirror Strategy 3: premium OR enters the pool, free OR stays + out. Callers can override ``router_pool_eligible`` to simulate legacy + configs or to regression-test the filter mechanics directly. 
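+
+    For example::
+
+        _fake_openrouter_config(id=-1, model_name="m", billing_tier="premium")
+        # -> router_pool_eligible defaults to True
+        _fake_openrouter_config(id=-2, model_name="m", billing_tier="free")
+        # -> router_pool_eligible defaults to False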
+ """ + if router_pool_eligible is None: + router_pool_eligible = billing_tier == "premium" + return { + "id": id, + "name": f"or-{id}", + "provider": "OPENROUTER", + "model_name": model_name, + "api_key": "sk-or-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 20 if billing_tier == "free" else 200, + "tpm": 100_000 if billing_tier == "free" else 1_000_000, + "litellm_params": {}, + "router_pool_eligible": router_pool_eligible, + } + + +def _reset_router_singleton() -> None: + instance = LLMRouterService.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + + +def test_router_pool_includes_or_premium_excludes_or_free(): + """Strategy 3: premium OR joins the pool, free OR stays out. + + Dynamic OpenRouter premium entries opt into load balancing alongside + curated YAML configs. Dynamic OR free entries are intentionally kept + out because OpenRouter's free tier enforces a single account-global + quota bucket that per-deployment router accounting can't represent. + """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ), + _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ), + ] + + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + # YAML premium + YAML free + dynamic OR premium are all in the pool. + # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced). + assert pool_models == { + "openai/gpt-4o", + "openai/gpt-4o-mini", + "openrouter/openai/gpt-4o", + } + + prem = LLMRouterService.get_instance()._premium_model_strings + # YAML premium is fingerprinted under both its model_string and its + # ``base_model`` form (existing behavior we don't want to regress). + assert "openai/gpt-4o" in prem + # Dynamic OR premium is now fingerprinted as premium so pool-level + # calls through the router are billed against premium quota. + assert "openrouter/openai/gpt-4o" in prem + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True + # Dynamic OR free never enters the pool, so it's never counted as premium. + assert ( + LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free") + is False + ) + + +def test_router_pool_filter_mechanics_respect_override(): + """The ``router_pool_eligible`` filter itself works independently of tier. + + Regression guard: if a future refactor ever sets the flag False on a + premium config (e.g. for maintenance), that config MUST be skipped by + ``initialize`` even though its tier is premium. 
+ """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_openrouter_config( + id=-10_001, + model_name="openai/gpt-4o", + billing_tier="premium", + router_pool_eligible=False, # opt out despite being premium + ), + ] + + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + assert pool_models == {"openai/gpt-4o"} + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False + + +def test_rebuild_refreshes_pool_after_configs_change(): + _reset_router_singleton() + configs_v1 = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + ] + configs_v2 = [ + *configs_v1, + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + ] + + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + + LLMRouterService.initialize(configs_v1) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``initialize`` should be a no-op here (already initialized). + LLMRouterService.initialize(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``rebuild`` must clear the guard and re-run with the new configs. + LLMRouterService.rebuild(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 2 + + +def test_auto_model_pin_candidates_include_dynamic_openrouter(): + """Dynamic OR configs must remain Auto-mode thread-pin candidates. + + Guards against a future regression where someone adds the + ``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``. 
+ """ + from app.config import config + from app.services.auto_model_pin_service import _global_candidates + + or_premium = _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ) + or_free = _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ) + original = config.GLOBAL_LLM_CONFIGS + try: + config.GLOBAL_LLM_CONFIGS = [or_premium, or_free] + candidate_ids = {c["id"] for c in _global_candidates()} + assert candidate_ids == {-10_001, -10_002} + finally: + config.GLOBAL_LLM_CONFIGS = original diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py new file mode 100644 index 000000000..085740032 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -0,0 +1,216 @@ +"""Unit tests for the dynamic OpenRouter integration.""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _OPENROUTER_DYNAMIC_MARKER, + _generate_configs, + _openrouter_tier, + _stable_config_id, +) + +pytestmark = pytest.mark.unit + + +def _minimal_openrouter_model( + *, + model_id: str, + pricing: dict | None = None, + name: str | None = None, +) -> dict: + """Return a synthetic OpenRouter /api/v1/models entry. + + The real API payload includes a lot of fields; we only populate what + ``_generate_configs`` actually inspects (architecture, tool support, + context, pricing, id). + """ + return { + "id": model_id, + "name": name or model_id, + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"}, + } + + +# --------------------------------------------------------------------------- +# _openrouter_tier +# --------------------------------------------------------------------------- + + +def test_openrouter_tier_free_suffix(): + assert _openrouter_tier({"id": "foo/bar:free"}) == "free" + + +def test_openrouter_tier_zero_pricing(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0", "completion": "0"}, + } + assert _openrouter_tier(model) == "free" + + +def test_openrouter_tier_paid(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + } + assert _openrouter_tier(model) == "premium" + + +def test_openrouter_tier_missing_pricing_is_premium(): + assert _openrouter_tier({"id": "foo/bar"}) == "premium" + assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium" + + +# --------------------------------------------------------------------------- +# _stable_config_id +# --------------------------------------------------------------------------- + + +def test_stable_config_id_deterministic(): + taken1: set[int] = set() + taken2: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken1) + b = _stable_config_id("openai/gpt-4o", -10_000, taken2) + assert a == b + assert a < 0 + + +def test_stable_config_id_collision_decrements(): + """When two model_ids hash to the same slot, the second should decrement.""" + taken: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken) + # Force a collision by pre-populating ``taken`` with a slot we know will be + # picked. 
+ taken_forced = {a} + b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced) + assert b != a + assert b == a - 1 + assert b in taken_forced + + +def test_stable_config_id_different_models_different_ids(): + taken: set[int] = set() + ids = { + _stable_config_id("openai/gpt-4o", -10_000, taken), + _stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken), + _stable_config_id("google/gemini-2.0-flash", -10_000, taken), + } + assert len(ids) == 3 + + +def test_stable_config_id_survives_catalogue_churn(): + """Removing a model should not shift other models' IDs (the bug we fix).""" + taken1: set[int] = set() + id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1) + _ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1) + id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1) + + taken2: set[int] = set() + id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2) + id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2) + + assert id_a1 == id_a2 + assert id_c1 == id_c2 + + +# --------------------------------------------------------------------------- +# _generate_configs +# --------------------------------------------------------------------------- + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, +} + + +def test_generate_configs_respects_tier(): + """Premium OR models opt into the router pool; free OR models stay out. + + Strategy-3 split: premium participates in LiteLLM Router load balancing, + free stays excluded because OpenRouter enforces a shared global free-tier + bucket that per-deployment router accounting can't represent. + """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="meta-llama/llama-3.3-70b-instruct:free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + paid = by_model["openai/gpt-4o"] + assert paid["billing_tier"] == "premium" + assert paid["rpm"] == 200 + assert paid["tpm"] == 1_000_000 + assert paid["anonymous_enabled"] is False + assert paid["router_pool_eligible"] is True + assert paid[_OPENROUTER_DYNAMIC_MARKER] is True + + free = by_model["meta-llama/llama-3.3-70b-instruct:free"] + assert free["billing_tier"] == "free" + assert free["rpm"] == 20 + assert free["tpm"] == 100_000 + assert free["anonymous_enabled"] is True + assert free["router_pool_eligible"] is False + + +def test_generate_configs_excludes_upstream_openrouter_free_router(): + """OpenRouter's own ``openrouter/free`` meta-router must never become a card. + + The upstream API returns this as a first-class zero-priced model, so + without an explicit blocklist entry it would slip through every other + filter (text output, tool calling, 200k context, non-Amazon) and land + in the selector as a duplicate of the concrete ``:free`` cards. The + exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that. 
+ """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="openrouter/free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openrouter/free" not in model_names + assert "openai/gpt-4o" in model_names + + +def test_generate_configs_drops_non_text_and_non_tool_models(): + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + { # image-output model + "id": "openai/dall-e", + "architecture": {"output_modalities": ["image"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + { # text but no tool calling + "id": "openai/completion-only", + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": [], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = [c["model_name"] for c in cfgs] + assert "openai/gpt-4o" in model_names + assert "openai/dall-e" not in model_names + assert "openai/completion-only" not in model_names diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py new file mode 100644 index 000000000..4eb1f2295 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py @@ -0,0 +1,108 @@ +"""Tests for deprecated-key warnings and back-compat in +``load_openrouter_integration_settings``. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +def _write_yaml(tmp_path: Path, body: str) -> Path: + cfg_dir = tmp_path / "app" / "config" + cfg_dir.mkdir(parents=True) + cfg_path = cfg_dir / "global_llm_config.yaml" + cfg_path.write_text(body, encoding="utf-8") + return cfg_path + + +def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + +def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + billing_tier: "premium" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert "billing_tier is deprecated" in captured + + +def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert settings["anonymous_enabled_paid"] is True + assert settings["anonymous_enabled_free"] is True + assert "anonymous_enabled is" in captured + assert "deprecated" in captured + + +def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys): + """If both legacy and new keys are present, new keys win (setdefault).""" + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + 
api_key: "sk-or-test" + anonymous_enabled: true + anonymous_enabled_paid: false + anonymous_enabled_free: false +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + capsys.readouterr() + assert settings is not None + assert settings["anonymous_enabled_paid"] is False + assert settings["anonymous_enabled_free"] is False + + +def test_disabled_integration_returns_none(monkeypatch, tmp_path): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: false + api_key: "sk-or-test" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + assert load_openrouter_integration_settings() is None diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py new file mode 100644 index 000000000..1c74aa928 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -0,0 +1,331 @@ +"""Unit tests for the OpenRouter ``_enrich_health`` background task.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, +) +from app.services.quality_score import ( + _HEALTH_FAIL_RATIO_FALLBACK, +) + +pytestmark = pytest.mark.unit + + +def _or_cfg( + *, + cid: int, + model_name: str, + tier: str = "premium", + static_score: int = 50, +) -> dict: + return { + "id": cid, + "provider": "OPENROUTER", + "model_name": model_name, + "billing_tier": tier, + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_score, + "quality_score_health": None, + "quality_score": static_score, + "health_gated": False, + } + + +class _StubResponse: + def __init__(self, *, payload: dict, status_code: int = 200): + self._payload = payload + self.status_code = status_code + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._payload + + +class _StubAsyncClient: + """Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``.""" + + def __init__(self, responder): + self._responder = responder + self.requests: list[str] = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url: str, headers: dict | None = None) -> _StubResponse: + self.requests.append(url) + return self._responder(url) + + +def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient: + """Replace ``httpx.AsyncClient`` for the duration of the test.""" + client = _StubAsyncClient(responder) + monkeypatch.setattr( + "app.services.openrouter_integration_service.httpx.AsyncClient", + lambda *_args, **_kwargs: client, + ) + return client + + +def _healthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + } + ] + } + } + + +def _unhealthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.62, + "uptime_last_5m": 0.50, + } + ] + } + } + + +# --------------------------------------------------------------------------- +# Bounded fan-out + happy path +# --------------------------------------------------------------------------- + 
+ +async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch): + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60), + ] + + def responder(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload=_healthy_payload()) + return _StubResponse(payload=_unhealthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {"api_key": ""} + await service._enrich_health(cfgs) + + healthy = next(c for c in cfgs if c["id"] == -1) + gated = next(c for c in cfgs if c["id"] == -2) + + assert healthy["health_gated"] is False + assert healthy["quality_score_health"] is not None + assert healthy["quality_score"] >= healthy["quality_score_static"] + + assert gated["health_gated"] is True + assert gated["quality_score"] == gated["quality_score_static"] + + +async def test_enrich_health_only_touches_or_provider(monkeypatch): + """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" + yaml_cfg = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score_static": 80, + "quality_score": 80, + "health_gated": False, + } + or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku") + + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: + requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg, or_cfg]) + + assert all("anthropic/claude-haiku" in r for r in requests) + # YAML cfg is untouched. + assert yaml_cfg["quality_score"] == 80 + assert yaml_cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Failure ratio fallback +# --------------------------------------------------------------------------- + + +async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high( + monkeypatch, +): + """If >= 25% of fetches fail, keep last-good cache instead of writing + partial data.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80), + _or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65), + _or_cfg(cid=-4, model_name="venice/something", static_score=50), + ] + + service = OpenRouterIntegrationService() + service._settings = {} + # Pre-seed last-good cache with a known-healthy snapshot. + service._health_cache = { + "anthropic/claude-haiku": {"gated": False, "score": 95.0}, + } + + def all_fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, all_fail) + await service._enrich_health(cfgs) + + # Above threshold ⇒ degraded; last-good cache wins for the cached cfg. + cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku") + assert cached_hit["quality_score_health"] == 95.0 + assert cached_hit["health_gated"] is False + # Confirm the threshold constant we're testing against is real. 
+ assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0 + + +async def test_enrich_health_keeps_static_only_with_no_cache_and_failures( + monkeypatch, +): + """If a fetch fails and there's no last-good cache, the cfg keeps its + static-only ``quality_score`` and is *not* gated by default.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + ] + + def fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, fail) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + cfg = cfgs[0] + assert cfg["health_gated"] is False + assert cfg["quality_score"] == cfg["quality_score_static"] + assert cfg["quality_score_health"] is None + + +# --------------------------------------------------------------------------- +# Last-good cache: success populates, next failure reuses +# --------------------------------------------------------------------------- + + +async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure( + monkeypatch, +): + cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70) + + service = OpenRouterIntegrationService() + service._settings = {} + + def healthy(_url: str) -> _StubResponse: + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, healthy) + await service._enrich_health([cfg]) + + assert "anthropic/claude-haiku" in service._health_cache + cached_score = service._health_cache["anthropic/claude-haiku"]["score"] + assert cached_score is not None + + # Next cycle: enough other healthy cfgs so failure ratio stays below + # the 25% threshold even when this one fails individually. + other_cfgs = [ + _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60) + for i in range(10) + ] + cfg["quality_score_health"] = None + cfg["quality_score"] = cfg["quality_score_static"] + + def mixed(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload={}, status_code=500) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, mixed) + await service._enrich_health([cfg, *other_cfgs]) + + assert cfg["quality_score_health"] == cached_score + assert cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Bounded fan-out: respects top-N caps +# --------------------------------------------------------------------------- + + +async def test_enrich_health_bounds_premium_fanout(monkeypatch): + """Top-N premium cap is honoured even when many cfgs are present.""" + from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM + + cfgs = [ + _or_cfg( + cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i + ) + for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20) + ] + + seen: list[str] = [] + + def responder(url: str) -> _StubResponse: + seen.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM + + +async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): + """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" + yaml_cfg: dict[str, Any] = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + } + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: 
+ requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg]) + assert requests == [] diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py new file mode 100644 index 000000000..6fbc8fd62 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -0,0 +1,345 @@ +"""Unit tests for the Auto (Fastest) quality scoring module.""" + +from __future__ import annotations + +import time + +import pytest + +from app.services.quality_score import ( + _HEALTH_GATE_UPTIME_PCT, + _OPERATOR_TRUST_BONUS, + aggregate_health, + capabilities_signal, + context_signal, + created_recency_signal, + pricing_band, + slug_penalty, + static_score_or, + static_score_yaml, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# created_recency_signal +# --------------------------------------------------------------------------- + + +def test_created_recency_signal_recent_model_scores_high(): + now = 1_750_000_000 # ~mid-2025 + one_month_ago = now - (30 * 86_400) + assert created_recency_signal(one_month_ago, now) == 20 + + +def test_created_recency_signal_old_model_scores_zero(): + now = 1_750_000_000 + five_years_ago = now - (5 * 365 * 86_400) + assert created_recency_signal(five_years_ago, now) == 0 + + +def test_created_recency_signal_missing_timestamp_is_neutral(): + now = 1_750_000_000 + assert created_recency_signal(None, now) == 0 + assert created_recency_signal(0, now) == 0 + + +def test_created_recency_signal_monotonic_decay(): + now = 1_750_000_000 + scores = [ + created_recency_signal(now - days * 86_400, now) + for days in (30, 120, 300, 500, 700, 1000, 1500) + ] + assert scores == sorted(scores, reverse=True) + + +# --------------------------------------------------------------------------- +# pricing_band +# --------------------------------------------------------------------------- + + +def test_pricing_band_free_returns_zero(): + assert pricing_band("0", "0") == 0 + assert pricing_band(0.0, 0.0) == 0 + assert pricing_band(None, None) == 0 + + +def test_pricing_band_handles_unparseable(): + assert pricing_band("not-a-number", "0") == 0 + assert pricing_band({}, []) == 0 # type: ignore[arg-type] + + +def test_pricing_band_premium_tiers_increase_with_price(): + cheap = pricing_band("0.0000003", "0.0000005") + mid = pricing_band("0.000003", "0.000015") + flagship = pricing_band("0.00001", "0.00005") + assert 0 < cheap < mid < flagship + + +# --------------------------------------------------------------------------- +# context_signal +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "ctx,expected", + [ + (1_500_000, 10), + (1_000_000, 10), + (500_000, 8), + (200_000, 6), + (128_000, 4), + (100_000, 2), + (50_000, 0), + (0, 0), + (None, 0), + ], +) +def test_context_signal_bands(ctx, expected): + assert context_signal(ctx) == expected + + +# --------------------------------------------------------------------------- +# capabilities_signal +# --------------------------------------------------------------------------- + + +def test_capabilities_signal_caps_at_five(): + assert ( + capabilities_signal( + ["tools", "structured_outputs", "reasoning", "include_reasoning"] + ) + <= 5 + ) + + +def 
test_capabilities_signal_tools_only(): + assert capabilities_signal(["tools"]) == 2 + + +def test_capabilities_signal_empty(): + assert capabilities_signal(None) == 0 + assert capabilities_signal([]) == 0 + + +# --------------------------------------------------------------------------- +# slug_penalty +# --------------------------------------------------------------------------- + + +def test_slug_penalty_demotes_tiny_models(): + assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0 + assert slug_penalty("liquid/lfm-7b") < 0 + assert slug_penalty("google/gemma-3n-e4b-it") < 0 + + +def test_slug_penalty_skips_capable_mini_nano_lite_models(): + """Critical Option C+ regression: don't penalise modern frontier + models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.).""" + assert slug_penalty("openai/gpt-5-mini") == 0 + assert slug_penalty("openai/gpt-5-nano") == 0 + assert slug_penalty("google/gemini-2.5-flash-lite") == 0 + assert slug_penalty("anthropic/claude-haiku-4.5") == 0 + + +def test_slug_penalty_demotes_legacy_variants(): + assert slug_penalty("openai/o1-preview") < 0 + assert slug_penalty("foo/bar-base") < 0 + assert slug_penalty("foo/bar-distill") < 0 + + +def test_slug_penalty_empty_input(): + assert slug_penalty("") == 0 + + +# --------------------------------------------------------------------------- +# static_score_or +# --------------------------------------------------------------------------- + + +def _or_model( + *, + model_id: str, + created: int | None = None, + prompt: str = "0.000003", + completion: str = "0.000015", + context: int = 200_000, + params: list[str] | None = None, +) -> dict: + return { + "id": model_id, + "created": created, + "pricing": {"prompt": prompt, "completion": completion}, + "context_length": context, + "supported_parameters": params if params is not None else ["tools"], + } + + +def test_static_score_or_frontier_premium_beats_free_tiny(): + now = 1_750_000_000 + frontier = _or_model( + model_id="openai/gpt-5", + created=now - (60 * 86_400), + prompt="0.000005", + completion="0.000020", + context=400_000, + params=["tools", "structured_outputs", "reasoning"], + ) + tiny_free = _or_model( + model_id="meta-llama/llama-3.2-1b-instruct:free", + created=now - (5 * 365 * 86_400), + prompt="0", + completion="0", + context=128_000, + params=["tools"], + ) + assert static_score_or(frontier, now_ts=now) > static_score_or( + tiny_free, now_ts=now + ) + + +def test_static_score_or_score_is_clamped_0_to_100(): + now = int(time.time()) + score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now) + assert 0 <= score <= 100 + + +def test_static_score_or_unknown_provider_is_neutral_not_zero(): + now = int(time.time()) + score = static_score_or( + _or_model(model_id="some-new-lab/some-model"), + now_ts=now, + ) + assert score > 0 + + +def test_static_score_or_recent_release_beats_year_old_same_provider(): + now = 1_750_000_000 + fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400)) + old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400)) + assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now) + + +# --------------------------------------------------------------------------- +# static_score_yaml +# --------------------------------------------------------------------------- + + +def test_static_score_yaml_includes_operator_bonus(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + score = 
static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_unknown_provider_still_carries_bonus(): + cfg = { + "provider": "SOME_NEW_PROVIDER", + "model_name": "weird-model", + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_clamped_0_to_100(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + assert 0 <= static_score_yaml(cfg) <= 100 + + +# --------------------------------------------------------------------------- +# aggregate_health +# --------------------------------------------------------------------------- + + +def test_aggregate_health_gates_when_uptime_below_threshold(): + """Live data showed Venice-routed cfgs at 53-68%; this guards that the + 90% gate excludes them.""" + venice_endpoints = [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.60, + "uptime_last_5m": 0.50, + }, + { + "status": 0, + "uptime_last_30m": 0.65, + "uptime_last_1d": 0.68, + "uptime_last_5m": 0.62, + }, + ] + gated, score = aggregate_health(venice_endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_passes_for_healthy_provider(): + healthy = [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + }, + ] + gated, score = aggregate_health(healthy) + assert gated is False + assert score is not None + assert score >= _HEALTH_GATE_UPTIME_PCT + + +def test_aggregate_health_picks_best_endpoint_across_multiple(): + """Multi-endpoint aggregation should reward the best non-null uptime.""" + mixed = [ + {"status": 0, "uptime_last_30m": 0.55}, + {"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate + ] + gated, score = aggregate_health(mixed) + assert gated is False + assert score is not None + + +def test_aggregate_health_empty_endpoints_gated(): + gated, score = aggregate_health([]) + assert gated is True + assert score is None + + +def test_aggregate_health_no_status_zero_gated(): + """Even with high uptime, no OK status means the cfg is broken upstream.""" + endpoints = [ + {"status": 1, "uptime_last_30m": 0.99}, + {"status": 2, "uptime_last_30m": 0.98}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_all_uptime_null_gated(): + endpoints = [ + {"status": 0, "uptime_last_30m": None, "uptime_last_1d": None}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_pct_normalisation(): + """OpenRouter returns 0-1 fractions; some endpoints surface 0-100% + percentages. 
Both should reach the same gate decision.""" + fraction_form = [{"status": 0, "uptime_last_30m": 0.95}] + pct_form = [{"status": 0, "uptime_last_30m": 95.0}] + g1, s1 = aggregate_health(fraction_form) + g2, s2 = aggregate_health(pct_form) + assert g1 == g2 == False # noqa: E712 + assert s1 is not None and s2 is not None + assert abs(s1 - s2) < 0.5 diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 5935d73ae..cc8157464 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -14,6 +14,7 @@ from app.tasks.chat.stream_new_chat import ( _classify_stream_exception, _contract_enforcement_active, _evaluate_file_contract_outcome, + _extract_resolved_file_path, _log_chat_stream_error, _tool_output_has_error, ) @@ -28,6 +29,39 @@ def test_tool_output_error_detection(): assert not _tool_output_has_error({"result": "Updated file /notes.md"}) +def test_extract_resolved_file_path_prefers_structured_path(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"status": "completed", "path": "/docs/note.md"}, + tool_input=None, + ) + == "/docs/note.md" + ) + + +def test_extract_resolved_file_path_falls_back_to_tool_input(): + assert ( + _extract_resolved_file_path( + tool_name="edit_file", + tool_output={"status": "completed", "result": "updated"}, + tool_input={"file_path": "/docs/edited.md"}, + ) + == "/docs/edited.md" + ) + + +def test_extract_resolved_file_path_does_not_parse_result_text(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"result": "Updated file /docs/from-text.md"}, + tool_input=None, + ) + is None + ) + + def test_file_write_contract_outcome_reasons(): result = StreamResult(intent_detected="file_write") passed, reason = _evaluate_file_contract_outcome(result) @@ -159,6 +193,84 @@ def test_stream_exception_classifies_rate_limited(): assert extra is None +def test_stream_exception_classifies_openrouter_429_payload(): + exc = Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' + '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' + ) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + assert extra is None + + +@pytest.mark.asyncio +async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch): + """``_preflight_llm`` is best-effort. + + - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the + caller can drive the cooldown/repin branch. + - On any other transient failure it MUST swallow the error so the normal + stream path continues without surfacing preflight noise to the user. 
+ """ + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + class _RateLimitedError(Exception): + """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers.""" + + rate_calls: list[dict] = [] + other_calls: list[dict] = [] + + async def _fake_acompletion_429(**kwargs): + rate_calls.append(kwargs) + raise _RateLimitedError("simulated 429") + + async def _fake_acompletion_other(**kwargs): + other_calls.append(kwargs) + raise RuntimeError("some unrelated transient failure") + + fake_llm = SimpleNamespace( + model="openrouter/google/gemma-4-31b-it:free", + api_key="test", + api_base=None, + ) + + import litellm # type: ignore[import-not-found] + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429) + with pytest.raises(_RateLimitedError): + await _preflight_llm(fake_llm) + assert len(rate_calls) == 1 + assert rate_calls[0]["max_tokens"] == 1 + assert rate_calls[0]["stream"] is False + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other) + # MUST NOT raise: non-rate-limit failures are swallowed. + await _preflight_llm(fake_llm) + assert len(other_calls) == 1 + + +@pytest.mark.asyncio +async def test_preflight_skipped_for_auto_router_model(): + """Router-mode ``model='auto'`` has no single deployment to ping; the + LiteLLM router itself owns per-deployment rate-limit accounting, so the + preflight helper must short-circuit instead of issuing a probe.""" + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None) + # Should return without raising or making any LiteLLM call. + await _preflight_llm(fake_llm) + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx index 67d9edab0..85bc4aaa6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx @@ -1,11 +1,8 @@ "use client"; -import { useQueryClient } from "@tanstack/react-query"; import { CheckCircle2 } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; -import { useEffect } from "react"; -import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; import { Button } from "@/components/ui/button"; import { Card, @@ -18,14 +15,8 @@ import { export default function PurchaseSuccessPage() { const params = useParams(); - const queryClient = useQueryClient(); const searchSpaceId = String(params.search_space_id ?? ""); - useEffect(() => { - void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY }); - void queryClient.invalidateQueries({ queryKey: ["token-status"] }); - }, [queryClient]); - return (
- {processChildrenWithCitations(children, urlMap)}
+ {standalonePath ? (
[remaining lines of this hunk, and the diff header naming its file, lost in extraction]
diff --git a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx
index a4d760dba..a3f028858 100644
--- a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx
+++ b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx
@@ -1,23 +1,18 @@
"use client";
-import { useQuery } from "@tanstack/react-query";
+import { useQuery } from "@rocicorp/zero/react";
import { Progress } from "@/components/ui/progress";
import { useIsAnonymous } from "@/contexts/anonymous-mode";
-import { stripeApiService } from "@/lib/apis/stripe-api.service";
+import { queries } from "@/zero/queries";
export function PremiumTokenUsageDisplay() {
const isAnonymous = useIsAnonymous();
- const { data: tokenStatus } = useQuery({
- queryKey: ["token-status"],
- queryFn: () => stripeApiService.getTokenStatus(),
- staleTime: 60_000,
- enabled: !isAnonymous,
- });
+ const [me] = useQuery(queries.user.me({}));
- if (!tokenStatus) return null;
+ if (isAnonymous || !me) return null;
const usagePercentage = Math.min(
- (tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100,
+ (me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100,
100
);
@@ -31,8 +26,7 @@ export function PremiumTokenUsageDisplay() {
- {formatTokens(tokenStatus.premium_tokens_used)} /{" "}
- {formatTokens(tokenStatus.premium_tokens_limit)} tokens
+ {formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens
{usagePercentage.toFixed(0)}%
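// Why the staleTime/enabled knobs could go: Zero's useQuery subscribes to the
// replicated "user" row, so this component re-renders the moment
// premium_tokens_used changes server-side (WAL -> zero-cache -> IndexedDB)
// instead of waiting out a 60s REST polling window. The entire data path the
// component still needs is the two lines already shown above:
//
//   const [me] = useQuery(queries.user.me({}));
//   if (isAnonymous || !me) return null;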
diff --git a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx
index adad52792..d5038ea05 100644
--- a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx
+++ b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx
@@ -12,9 +12,9 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode";
import { cn } from "@/lib/utils";
import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize";
import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types";
+import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay";
import { ChatListItem } from "./ChatListItem";
import { NavSection } from "./NavSection";
-import { PageUsageDisplay } from "./PageUsageDisplay";
import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay";
import { SidebarButton } from "./SidebarButton";
import { SidebarCollapseButton } from "./SidebarCollapseButton";
@@ -338,9 +338,7 @@ function SidebarUsageFooter({
return (
- {pageUsage && (
- <PageUsageDisplay … />
- )}
+ <AuthenticatedPageUsageDisplay … />
[file and hunk headers for the BuyTokensContent diff lost in extraction]
stripeApiService.getTokenStatus(),
});
+ // Live per-user usage via Zero.
+ const [me] = useZeroQuery(queries.user.me({}));
+
const purchaseMutation = useMutation({
mutationFn: stripeApiService.createTokenCheckoutSession,
onSuccess: (response) => {
@@ -54,12 +60,11 @@ export function BuyTokensContent() {
);
}
- const usagePercentage = tokenStatus
- ? Math.min(
- (tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100,
- 100
- )
- : 0;
+ const used = me?.premiumTokensUsed ?? 0;
+ const limit = me?.premiumTokensLimit ?? 0;
+ // Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)).
+ const remaining = Math.max(0, limit - used);
+ const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
return (
@@ -68,18 +73,17 @@ export function BuyTokensContent() {
$1 per 1M tokens, pay as you go
- {tokenStatus && (
+ {me && (
- {tokenStatus.premium_tokens_used.toLocaleString()} /{" "}
- {tokenStatus.premium_tokens_limit.toLocaleString()} premium tokens
+ {used.toLocaleString()} / {limit.toLocaleString()} premium tokens
{usagePercentage.toFixed(0)}%
- {tokenStatus.premium_tokens_remaining.toLocaleString()} tokens remaining
+ {remaining.toLocaleString()} tokens remaining
)}
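// Worked example of the arithmetic above, with illustrative values:
//
//   used = 750_000, limit = 1_000_000
//   remaining       = Math.max(0, 1_000_000 - 750_000)                        // 250_000
//   usagePercentage = Math.min((750_000 / Math.max(1_000_000, 1)) * 100, 100) // 75
//
// Math.max(limit, 1) keeps the division defined while a row still has
// limit = 0, and the Math.min cap pins overdrawn accounts at a full bar.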
diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts
index 0cd81c0b7..630c88d16 100644
--- a/surfsense_web/lib/apis/documents-api.service.ts
+++ b/surfsense_web/lib/apis/documents-api.service.ts
@@ -5,6 +5,7 @@ import {
type DeleteDocumentRequest,
deleteDocumentRequest,
deleteDocumentResponse,
+ documentTitleRead,
type GetDocumentByChunkRequest,
type GetDocumentChunksRequest,
type GetDocumentRequest,
@@ -269,6 +270,17 @@ class DocumentsApiService {
);
};
+ getDocumentByVirtualPath = async (request: { search_space_id: number; virtual_path: string }) => {
+ const params = new URLSearchParams({
+ search_space_id: String(request.search_space_id),
+ virtual_path: request.virtual_path,
+ });
+ return baseApiService.get(
+ `/api/v1/documents/by-virtual-path?${params.toString()}`,
+ documentTitleRead
+ );
+ };
+
/**
* Get document type counts
*/
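// Hedged usage sketch for the new wrapper: `documentsApiService` is assumed
// to be the module's exported singleton of DocumentsApiService, and the
// returned shape is whatever the documentTitleRead schema validates.

const doc = await documentsApiService.getDocumentByVirtualPath({
	search_space_id: 42,
	virtual_path: "/docs/note.md",
});
// doc: the documentTitleRead-parsed payload for the document at that path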
diff --git a/surfsense_web/zero/queries/index.ts b/surfsense_web/zero/queries/index.ts
index bc332114e..fbf1bd76e 100644
--- a/surfsense_web/zero/queries/index.ts
+++ b/surfsense_web/zero/queries/index.ts
@@ -3,6 +3,7 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat";
import { connectorQueries, documentQueries } from "./documents";
import { folderQueries } from "./folders";
import { notificationQueries } from "./inbox";
+import { userQueries } from "./user";
export const queries = defineQueries({
notifications: notificationQueries,
@@ -12,4 +13,5 @@ export const queries = defineQueries({
messages: messageQueries,
comments: commentQueries,
chatSession: chatSessionQueries,
+ user: userQueries,
});
diff --git a/surfsense_web/zero/queries/user.ts b/surfsense_web/zero/queries/user.ts
new file mode 100644
index 000000000..30e71a482
--- /dev/null
+++ b/surfsense_web/zero/queries/user.ts
@@ -0,0 +1,11 @@
+import { defineQuery } from "@rocicorp/zero";
+import { z } from "zod";
+import { zql } from "../schema/index";
+
+export const userQueries = {
+ me: defineQuery(z.object({}), ({ ctx }) => {
+ const userId = ctx?.userId;
+ if (!userId) return zql.user.where("id", "__none__").one();
+ return zql.user.where("id", userId).one();
+ }),
+};
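// Consumption sketch for queries.user.me (hypothetical component, mirroring
// the PremiumTokenUsageDisplay usage in this diff). When ctx.userId is
// absent, the "__none__" sentinel matches no row, so `me` stays undefined
// and callers fall through their null-guards instead of seeing another
// user's row:

import { useQuery } from "@rocicorp/zero/react";
import { queries } from "@/zero/queries";

export function PagesUsedLabel() {
	const [me] = useQuery(queries.user.me({}));
	if (!me) return null;
	return (
		<span>
			{me.pagesUsed} / {me.pagesLimit} pages
		</span>
	);
}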
diff --git a/surfsense_web/zero/schema/index.ts b/surfsense_web/zero/schema/index.ts
index bba561580..3cca0f24a 100644
--- a/surfsense_web/zero/schema/index.ts
+++ b/surfsense_web/zero/schema/index.ts
@@ -3,6 +3,7 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./
import { documentTable, searchSourceConnectorTable } from "./documents";
import { folderTable } from "./folders";
import { notificationTable } from "./inbox";
+import { userTable } from "./user";
const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({
message: one({
@@ -34,6 +35,7 @@ export const schema = createSchema({
newChatMessageTable,
chatCommentTable,
chatSessionStateTable,
+ userTable,
],
relationships: [chatCommentRelationships, newChatMessageRelationships],
});
diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts
new file mode 100644
index 000000000..0e6234db5
--- /dev/null
+++ b/surfsense_web/zero/schema/user.ts
@@ -0,0 +1,11 @@
+import { number, string, table } from "@rocicorp/zero";
+
+export const userTable = table("user")
+ .columns({
+ id: string(),
+ pagesLimit: number().from("pages_limit"),
+ pagesUsed: number().from("pages_used"),
+ premiumTokensLimit: number().from("premium_tokens_limit"),
+ premiumTokensUsed: number().from("premium_tokens_used"),
+ })
+ .primaryKey("id");
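// Sketch of how the .from() mappings line up with the column-list publication
// added in migration 139 (illustrative, not generated output):
//
//   Postgres "user" publication:  id, pages_limit, pages_used,
//                                 premium_tokens_limit, premium_tokens_used
//   Zero client row:              { id, pagesLimit, pagesUsed,
//                                   premiumTokensLimit, premiumTokensUsed }
//
// A column added on either side needs a matching entry on the other, or the
// replicated row simply won't carry it into IndexedDB.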