Merge pull request #1332 from AnishSarkar22/feat/model-pinnning-mode
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

feat: Auto-pin quality scoring, OpenRouter tier refactor and live usage sidebar
This commit is contained in:
Rohan Verma 2026-05-01 15:57:19 -07:00 committed by GitHub
commit 451a98936e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 3975 additions and 319 deletions

View file

@ -4,10 +4,12 @@ Revision ID: 138
Revises: 137
Create Date: 2026-04-30
Add thread-level fields to persist Auto (Fastest) model pinning metadata:
- pinned_llm_config_id: concrete resolved config id used for this thread
- pinned_auto_mode: auto policy identifier (currently "auto_fastest")
- pinned_at: timestamp when the pin was created/refreshed
Add a single thread-level column to persist the Auto (Fastest) model pin:
- pinned_llm_config_id: concrete resolved global LLM config id used for this
thread. NULL means "no pin; Auto will resolve on next turn".
The column is unindexed: all reads are by new_chat_threads.id (primary key),
so a secondary index would be dead write amplification.
"""
from __future__ import annotations
@ -27,29 +29,14 @@ def upgrade() -> None:
"ALTER TABLE new_chat_threads "
"ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
)
op.execute(
"ALTER TABLE new_chat_threads "
"ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)"
)
op.execute(
"ALTER TABLE new_chat_threads "
"ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id "
"ON new_chat_threads (pinned_llm_config_id)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode "
"ON new_chat_threads (pinned_auto_mode)"
)
def downgrade() -> None:
# Drop any shape the thread row may be carrying. The extra columns and
# indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
# makes each statement a safe no-op on the lean shape.
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
op.execute(

View file

@ -0,0 +1,160 @@
"""add user table to zero_publication with column list
Adds the "user" table to zero_publication with a column-list publication
so that only the 5 fields driving the live usage meters are replicated
through WAL -> zero-cache -> browser IndexedDB:
id, pages_limit, pages_used,
premium_tokens_limit, premium_tokens_used
Sensitive columns (hashed_password, email, oauth_account, display_name,
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
included in the publication, so they never enter WAL replication.
Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
(it is already DEFAULT today since "user" was never in the
TABLES_WITH_FULL_IDENTITY list of migration 117).
IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)
Revision ID: 139
Revises: 138
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "139"
down_revision: str | None = "138"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
# Document column list as left by migration 117. Must match exactly.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Five fields needed by the live usage meters (sidebar Tokens/Pages,
# Buy Tokens content). Keep this list narrow on purpose: anything added
# here flows into WAL and IndexedDB for every connected browser.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_tokens_limit",
"premium_tokens_used",
]
def _terminate_blocked_pids(conn, table: str) -> None:
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
conn.execute(
sa.text(
"SELECT pg_terminate_backend(l.pid) "
"FROM pg_locks l "
"JOIN pg_class c ON c.oid = l.relation "
"WHERE c.relname = :tbl "
" AND l.pid != pg_backend_pid()"
),
{"tbl": table},
)
def _has_zero_version(conn, table: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = '_0_version'"
),
{"tbl": table},
).fetchone()
is not None
)
def _build_publication_ddl(
documents_has_zero_ver: bool, user_has_zero_ver: bool
) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
doc_col_list = ", ".join(doc_cols)
user_col_list = ", ".join(user_cols)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state, "
f'"user" ({user_col_list})'
)
def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
doc_col_list = ", ".join(doc_cols)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state"
)
def upgrade() -> None:
conn = op.get_bind()
# asyncpg requires LOCK TABLE inside a transaction block. Alembic already
# opened one via context.begin_transaction(), but the driver still errors
# unless we use an explicit SAVEPOINT (nested transaction) for this block.
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
# Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
# migration 117, so this is already DEFAULT. Re-assert anyway so
# the column-list publication stays valid (DEFAULT identity only
# requires the PK to be in the column list).
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
)
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
documents_has_zero_ver = _has_zero_version(conn, "documents")
conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))

View file

@ -61,6 +61,9 @@ class _ThreadLockManager:
self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
# Monotonic per-thread epoch used to prevent stale middleware
# teardown from releasing a newer turn's lock.
self._turn_epoch: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@ -107,6 +110,14 @@ class _ThreadLockManager:
self._cancel_requested_at_ms.pop(thread_id, None)
self._cancel_attempt_count.pop(thread_id, None)
def bump_turn_epoch(self, thread_id: str) -> int:
epoch = self._turn_epoch.get(thread_id, 0) + 1
self._turn_epoch[thread_id] = epoch
return epoch
def current_turn_epoch(self, thread_id: str) -> int:
return self._turn_epoch.get(thread_id, 0)
def end_turn(self, thread_id: str) -> None:
"""Best-effort terminal cleanup for a thread turn.
@ -114,6 +125,10 @@ class _ThreadLockManager:
finally-blocks where middleware teardown might be skipped due to abort
or disconnect edge-cases.
"""
# Invalidate any in-flight middleware holder first. This guarantees a
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
# retry that already acquired the lock for the same thread.
self.bump_turn_epoch(thread_id)
lock = self._locks.get(thread_id)
if lock is not None and lock.locked():
lock.release()
@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
super().__init__()
self._require_thread_id = require_thread_id
self.tools = []
# Per-call locks owned by this middleware. We track them as
# an instance attribute so ``aafter_agent`` knows which lock
# to release.
self._held_locks: dict[str, asyncio.Lock] = {}
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
# only releases when its epoch still matches the manager's current
# epoch for the thread, preventing stale unlock races.
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
@staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
if lock.locked():
raise BusyError(request_id=thread_id)
await lock.acquire()
self._held_locks[thread_id] = lock
epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
thread_id = self._thread_id(runtime)
if thread_id is None:
return None
lock = self._held_locks.pop(thread_id, None)
if lock is not None and lock.locked():
held = self._held_locks.pop(thread_id, None)
if held is None:
return None
lock, held_epoch = held
if held_epoch != manager.current_turn_epoch(thread_id):
# Stale teardown from an older attempt (e.g. runtime-recovery path
# already advanced epoch). Do not touch current lock/cancel state.
return None
if lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.

View file

@ -63,6 +63,27 @@ def load_global_llm_configs():
else:
seen_slugs[slug] = cfg.get("id", 0)
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
# Tier A — operator-curated, locked first when premium-eligible.
# The OpenRouter refresh tick later re-stamps health for any cfg
# whose provider == "OPENROUTER" via _enrich_health.
try:
from app.services.quality_score import static_score_yaml
for cfg in configs:
cfg["auto_pin_tier"] = "A"
static_q = static_score_yaml(cfg)
cfg["quality_score_static"] = static_q
cfg["quality_score"] = static_q
cfg["quality_score_health"] = None
# YAML cfgs whose provider is OPENROUTER are also subject
# to health gating against their own /endpoints data — a
# hand-picked dead OR model is still dead. _enrich_health
# re-stamps health_gated for them on the next refresh tick.
cfg["health_gated"] = False
except Exception as e:
print(f"Warning: Failed to score global LLM configs: {e}")
return configs
except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}")
@ -194,6 +215,9 @@ def load_openrouter_integration_settings() -> dict | None:
"""
Load OpenRouter integration settings from the YAML config.
Emits startup warnings for deprecated keys (``billing_tier``,
``anonymous_enabled``) and seeds their replacements for back-compat.
Returns:
dict with settings if present and enabled, None otherwise
"""
@ -206,9 +230,31 @@ def load_openrouter_integration_settings() -> dict | None:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
settings = data.get("openrouter_integration")
if settings and settings.get("enabled"):
return settings
return None
if not settings or not settings.get("enabled"):
return None
if "billing_tier" in settings:
print(
"Warning: openrouter_integration.billing_tier is deprecated; "
"tier is now derived per model from OpenRouter data "
"(':free' suffix or zero pricing). Remove this key."
)
if "anonymous_enabled" in settings:
print(
"Warning: openrouter_integration.anonymous_enabled is "
"deprecated; use anonymous_enabled_paid and/or "
"anonymous_enabled_free instead. Both new flags have been "
"seeded from the legacy value for back-compat."
)
settings.setdefault(
"anonymous_enabled_paid", settings["anonymous_enabled"]
)
settings.setdefault(
"anonymous_enabled_free", settings["anonymous_enabled"]
)
return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
return None
@ -217,9 +263,14 @@ def load_openrouter_integration_settings() -> dict | None:
def initialize_openrouter_integration():
"""
If enabled, fetch all OpenRouter models and append them to
config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
Should be called BEFORE initialize_llm_router() so the router
correctly excludes premium models from Auto mode.
config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
is derived per-model from OpenRouter's API signals (``:free`` suffix or
zero pricing), so free OpenRouter models correctly skip premium quota.
Should be called BEFORE initialize_llm_router(). Dynamic entries are
tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
by title-gen / sub-agent flows) remains scoped to curated YAML configs,
while user-facing Auto-mode thread pinning still considers them.
"""
settings = load_openrouter_integration_settings()
if not settings:
@ -235,9 +286,13 @@ def initialize_openrouter_integration():
if new_configs:
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
premium_count = sum(
1 for c in new_configs if c.get("billing_tier") == "premium"
)
print(
f"Info: OpenRouter integration added {len(new_configs)} models "
f"(billing_tier={settings.get('billing_tier', 'premium')})"
f"(free={free_count}, premium={premium_count})"
)
else:
print("Info: OpenRouter integration enabled but no models fetched")

View file

@ -245,31 +245,53 @@ global_llm_configs:
# =============================================================================
# When enabled, dynamically fetches ALL available models from the OpenRouter API
# and injects them as global configs. This gives premium users access to any model
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
# while free-tier OpenRouter models show up with a green Free badge and do NOT
# consume premium quota.
# Models are fetched at startup and refreshed periodically in the background.
# All calls go through LiteLLM with the openrouter/ prefix.
openrouter_integration:
enabled: false
api_key: "sk-or-your-openrouter-api-key"
# billing_tier: "premium" or "free". Controls whether users need premium tokens.
billing_tier: "premium"
# anonymous_enabled: set true to also show OpenRouter models to no-login users
anonymous_enabled: false
# Tier is derived PER MODEL from OpenRouter's own API signals:
# - id ends with ":free" -> billing_tier=free
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
# - otherwise -> billing_tier=premium
# No global billing_tier knob is honored; any legacy value emits a startup warning.
# Anonymous access is split by tier so operators can expose only free
# models to no-login users without leaking paid inference.
anonymous_enabled_paid: false
anonymous_enabled_free: false
seo_enabled: false
# quota_reserve_tokens: tokens reserved per call for quota enforcement
quota_reserve_tokens: 4000
# id_offset: starting negative ID for dynamically generated configs.
# Must not overlap with your static global_llm_configs IDs above.
# id_offset: base negative ID for dynamically generated configs.
# Model IDs are derived deterministically via BLAKE2b so they survive
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
id_offset: -10000
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
refresh_interval_hours: 24
# rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
# OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
# upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
# These values only matter if you set billing_tier to "free" (adding them to Auto mode).
# For premium-only models they are cosmetic. Set conservatively or match your account tier.
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
# for per-deployment accounting when OR premium models participate in the
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
# real account limits live at https://openrouter.ai/settings/limits.
rpm: 200
tpm: 1000000
# Rate limits for FREE OpenRouter models. Informational only: free OR
# models are intentionally kept OUT of the LiteLLM Router pool, because
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
# 50-1000 daily requests across every ":free" model combined) —
# per-deployment router accounting can't represent a shared bucket
# correctly. Free OR models stay fully available in the model selector
# and for user-facing Auto thread pinning.
free_rpm: 20
free_tpm: 100000
litellm_params:
max_tokens: 16384
system_instructions: ""

View file

@ -638,13 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False,
server_default="false",
)
# Auto model pinning metadata:
# - pinned_llm_config_id stores the concrete resolved model config id.
# - pinned_auto_mode indicates which auto policy produced the pin.
# This allows Auto (Fastest) to resolve once per thread and stay stable.
pinned_llm_config_id = Column(Integer, nullable=True, index=True)
pinned_auto_mode = Column(String(32), nullable=True, index=True)
pinned_at = Column(TIMESTAMP(timezone=True), nullable=True)
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
# config id. NULL means no pin; Auto will resolve on the next turn.
# Single-writer invariant: only app.services.auto_model_pin_service sets
# or clears this column (plus bulk clears when a search space's
# agent_llm_id changes). Unindexed: all reads are by primary key.
pinned_llm_config_id = Column(Integer, nullable=True)
# Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads")

View file

@ -745,6 +745,51 @@ async def search_document_titles(
) from e
@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
async def get_document_by_virtual_path(
search_space_id: int,
virtual_path: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Resolve a knowledge-base document id by exact virtual path."""
try:
await check_permission(
session,
user,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
)
result = await session.execute(
select(
Document.id,
Document.title,
Document.document_type,
).filter(
Document.search_space_id == search_space_id,
Document.document_metadata["virtual_path"].as_string() == virtual_path,
)
)
row = result.first()
if row is None:
raise HTTPException(status_code=404, detail="Document not found")
return DocumentTitleRead(
id=row.id,
title=row.title,
document_type=row.document_type,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to resolve document by virtual path: {e!s}",
) from e
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
async def get_documents_status(
search_space_id: int,

View file

@ -803,11 +803,7 @@ async def update_llm_preferences(
await session.execute(
update(NewChatThread)
.where(NewChatThread.search_space_id == search_space_id)
.values(
pinned_llm_config_id=None,
pinned_auto_mode=None,
pinned_at=None,
)
.values(pinned_llm_config_id=None)
)
logger.info(
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",

View file

@ -2,16 +2,23 @@
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
resolve that virtual mode to one concrete global LLM config exactly once and
persist the chosen config id on ``new_chat_threads`` so subsequent turns are
stable.
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
subsequent turns are stable.
Single-writer invariant: this module is the only writer of
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
Therefore a non-NULL value unambiguously means "this thread has an
Auto-resolved pin"; no separate source/policy column is needed.
"""
from __future__ import annotations
import hashlib
import logging
import threading
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import select
@ -19,12 +26,28 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import NewChatThread
from app.services.quality_score import _QUALITY_TOP_K
from app.services.token_quota_service import TokenQuotaService
logger = logging.getLogger(__name__)
AUTO_FASTEST_ID = 0
AUTO_FASTEST_MODE = "auto_fastest"
_RUNTIME_COOLDOWN_SECONDS = 600
_HEALTHY_TTL_SECONDS = 45
# In-memory runtime cooldown map for configs that recently hard-failed at
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
# the same unhealthy config from being reselected immediately during repair.
_runtime_cooldown_until: dict[int, float] = {}
_runtime_cooldown_lock = threading.Lock()
# Short-TTL "recently healthy" cache for configs that just passed a runtime
# preflight ping. Lets back-to-back turns on the same model skip the probe
# without eroding correctness — entries auto-expire and are wiped any time
# the same config is cooled down or the OR catalogue is refreshed.
_healthy_until: dict[int, float] = {}
_healthy_lock = threading.Lock()
@dataclass
@ -43,9 +66,117 @@ def _is_usable_global_config(cfg: dict) -> bool:
)
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
now = time.time() if now_ts is None else now_ts
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
for cid in stale:
_runtime_cooldown_until.pop(cid, None)
def _is_runtime_cooled_down(config_id: int) -> bool:
with _runtime_cooldown_lock:
_prune_runtime_cooldowns()
return config_id in _runtime_cooldown_until
def mark_runtime_cooldown(
config_id: int,
*,
reason: str = "rate_limited",
cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
) -> None:
"""Temporarily suppress a config from Auto selection.
Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
config that is currently unhealthy does not get immediately reused on the
same thread during repair.
"""
if cooldown_seconds <= 0:
cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
until = time.time() + int(cooldown_seconds)
with _runtime_cooldown_lock:
_runtime_cooldown_until[int(config_id)] = until
_prune_runtime_cooldowns()
# A cooled cfg can never be "recently healthy"; drop any stale credit so
# the next turn that resolves to it (after cooldown) re-runs preflight.
clear_healthy(int(config_id))
logger.info(
"auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
config_id,
reason,
cooldown_seconds,
)
def clear_runtime_cooldown(config_id: int | None = None) -> None:
"""Test/ops helper to clear runtime cooldown entries."""
with _runtime_cooldown_lock:
if config_id is None:
_runtime_cooldown_until.clear()
return
_runtime_cooldown_until.pop(int(config_id), None)
def _prune_healthy(now_ts: float | None = None) -> None:
now = time.time() if now_ts is None else now_ts
stale = [cid for cid, until in _healthy_until.items() if until <= now]
for cid in stale:
_healthy_until.pop(cid, None)
def is_recently_healthy(config_id: int) -> bool:
"""Return True if ``config_id`` passed preflight within the TTL window."""
with _healthy_lock:
_prune_healthy()
return int(config_id) in _healthy_until
def mark_healthy(
config_id: int,
*,
ttl_seconds: int = _HEALTHY_TTL_SECONDS,
) -> None:
"""Record that ``config_id`` just passed a preflight probe.
Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The
healthy state is intentionally process-local it's a latency hint, not a
correctness primitive so multi-worker drift is acceptable.
"""
if ttl_seconds <= 0:
ttl_seconds = _HEALTHY_TTL_SECONDS
until = time.time() + int(ttl_seconds)
with _healthy_lock:
_healthy_until[int(config_id)] = until
_prune_healthy()
def clear_healthy(config_id: int | None = None) -> None:
"""Drop one (or all) healthy-cache entries.
Called from runtime cooldown and OR catalogue refresh so a freshly cooled
or replaced config never carries stale "healthy" credit.
"""
with _healthy_lock:
if config_id is None:
_healthy_until.clear()
return
_healthy_until.pop(int(config_id), None)
def _global_candidates() -> list[dict]:
"""Return Auto-eligible global cfgs.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
can't be picked as the thread's pin. Also excludes configs currently
in runtime cooldown (e.g. temporary 429 bursts).
"""
candidates = [
cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)
cfg
for cfg in config.GLOBAL_LLM_CONFIGS
if _is_usable_global_config(cfg)
and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
]
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
@ -54,10 +185,26 @@ def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread.
Tier policy is lock-first: prefer Tier A (operator-curated YAML)
cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
Tier A cfg is eligible after upstream filters. Within the locked
pool, sort by ``quality_score`` and pick from the top-K via
``SHA256(thread_id)`` so different new threads spread across the
best models without ever picking a low-ranked one.
Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
structured logging in the caller.
"""
tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
pool = tier_a if tier_a else eligible
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
top_k = pool[:_QUALITY_TOP_K]
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
idx = int.from_bytes(digest[:8], "big") % len(candidates)
return candidates[idx]
idx = int.from_bytes(digest[:8], "big") % len(top_k)
return top_k[idx], len(top_k)
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
@ -89,11 +236,12 @@ async def resolve_or_get_pinned_llm_config_id(
user_id: str | UUID | None,
selected_llm_config_id: int,
force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
For non-auto selections, this function clears existing auto pin metadata and
returns the selected id as-is.
For non-auto selections, this function clears any existing pin and returns
the selected id as-is.
"""
thread = (
(
@ -113,16 +261,10 @@ async def resolve_or_get_pinned_llm_config_id(
f"Thread {thread_id} does not belong to search space {search_space_id}"
)
# Explicit model selected: clear stale auto pin metadata.
# Explicit model selected: clear any stale pin.
if selected_llm_config_id != AUTO_FASTEST_ID:
if (
thread.pinned_llm_config_id is not None
or thread.pinned_auto_mode is not None
or thread.pinned_at is not None
):
if thread.pinned_llm_config_id is not None:
thread.pinned_llm_config_id = None
thread.pinned_auto_mode = None
thread.pinned_at = None
await session.commit()
return AutoPinResolution(
resolved_llm_config_id=selected_llm_config_id,
@ -130,17 +272,19 @@ async def resolve_or_get_pinned_llm_config_id(
from_existing_pin=False,
)
candidates = _global_candidates()
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
]
if not candidates:
raise ValueError("No usable global LLM configs are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse existing valid pin without re-checking current quota (no silent tier switch),
# unless the caller explicitly requests a forced repin to free.
# Reuse an existing valid pin without re-checking current quota (no silent
# tier switch), unless the caller explicitly requests a forced repin to free.
pinned_id = thread.pinned_llm_config_id
if (
not force_repin_free
and thread.pinned_auto_mode == AUTO_FASTEST_MODE
and pinned_id is not None
and int(pinned_id) in candidate_by_id
):
@ -152,6 +296,15 @@ async def resolve_or_get_pinned_llm_config_id(
pinned_id,
_tier_of(pinned_cfg),
)
logger.info(
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
"auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
thread_id,
pinned_id,
_tier_of(pinned_cfg),
pinned_cfg.get("auto_pin_tier", "?"),
int(pinned_cfg.get("quality_score") or 0),
)
return AutoPinResolution(
resolved_llm_config_id=int(pinned_id),
resolved_tier=_tier_of(pinned_cfg),
@ -159,11 +312,10 @@ async def resolve_or_get_pinned_llm_config_id(
)
if pinned_id is not None:
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,
search_space_id,
pinned_id,
thread.pinned_auto_mode,
)
premium_eligible = (
@ -179,13 +331,11 @@ async def resolve_or_get_pinned_llm_config_id(
"Auto mode could not find an eligible LLM config for this user and quota state"
)
selected_cfg = _deterministic_pick(eligible, thread_id)
selected_cfg, top_k_size = _select_pin(eligible, thread_id)
selected_id = int(selected_cfg["id"])
selected_tier = _tier_of(selected_cfg)
thread.pinned_llm_config_id = selected_id
thread.pinned_auto_mode = AUTO_FASTEST_MODE
thread.pinned_at = datetime.now(UTC)
await session.commit()
if force_repin_free:
@ -216,6 +366,18 @@ async def resolve_or_get_pinned_llm_config_id(
selected_tier,
premium_eligible,
)
logger.info(
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
"auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
thread_id,
selected_id,
selected_tier,
selected_cfg.get("auto_pin_tier", "?"),
int(selected_cfg.get("quality_score") or 0),
top_k_size,
)
return AutoPinResolution(
resolved_llm_config_id=selected_id,
resolved_tier=selected_tier,

View file

@ -208,6 +208,12 @@ class LLMRouterService:
"""
Initialize the router with global LLM configurations.
Configs with ``router_pool_eligible=False`` are skipped so that
dynamic OpenRouter entries stay out of the shared router pool used
by title-gen / sub-agent ``model="auto"`` flows. Those dynamic
entries are still available for user-facing Auto-mode thread pinning
via ``auto_model_pin_service``.
Args:
global_configs: List of global LLM config dictionaries from YAML
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
@ -221,6 +227,8 @@ class LLMRouterService:
model_list = []
premium_models: set[str] = set()
for config in global_configs:
if config.get("router_pool_eligible") is False:
continue
deployment = cls._config_to_deployment(config)
if deployment:
model_list.append(deployment)
@ -309,10 +317,45 @@ class LLMRouterService:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
@classmethod
def rebuild(
cls,
global_configs: list[dict],
router_settings: dict | None = None,
) -> None:
"""Reset the router and re-run ``initialize`` with fresh configs.
``initialize`` short-circuits once it has run to avoid re-creating the
LiteLLM Router on every request; ``rebuild`` deliberately clears
``_initialized`` so a caller (e.g. background OpenRouter refresh)
can force the pool to be rebuilt after catalogue changes.
"""
instance = cls.get_instance()
instance._initialized = False
instance._router = None
instance._model_list = []
instance._premium_model_strings = set()
cls.initialize(global_configs, router_settings)
@classmethod
def is_premium_model(cls, model_string: str) -> bool:
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
premium-tier deployment in the router pool."""
"""Return True if *model_string* belongs to a premium-tier deployment
in the LiteLLM router pool.
Scope: only covers configs with ``router_pool_eligible`` truthy. That
includes static YAML premium configs AND dynamic OpenRouter *premium*
entries (which opt in at generation time). Dynamic OpenRouter *free*
entries are deliberately kept out of the router pool OpenRouter
enforces free-tier limits globally per account, so per-deployment
router accounting can't represent them correctly — and therefore
return ``False`` here, which matches their ``billing_tier="free"``
(no premium quota).
For per-request premium checks on an arbitrary config (static or
dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
that reflects the per-config ``billing_tier`` directly and is what
user-facing Auto-mode thread pinning uses to bill correctly.
"""
instance = cls.get_instance()
return model_string in instance._premium_model_strings

View file

@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
"""
import asyncio
import hashlib
import logging
import threading
import time
from typing import Any
import httpx
from app.services.quality_score import (
_HEALTH_BLEND_WEIGHT,
_HEALTH_ENRICH_CONCURRENCY,
_HEALTH_ENRICH_TOP_N_FREE,
_HEALTH_ENRICH_TOP_N_PREMIUM,
_HEALTH_FAIL_RATIO_FALLBACK,
_HEALTH_FETCH_TIMEOUT_SEC,
aggregate_health,
static_score_or,
)
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
"https://openrouter.ai/api/v1/models/{model_id}/endpoints"
)
# Sentinel value stored on each generated config so we can distinguish
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
# enough headroom to avoid frequent collisions for OpenRouter's catalogue
# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
_STABLE_ID_HASH_WIDTH = 9_000_000
def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
"""Derive a deterministic negative config ID from ``model_id``.
The same ``model_id`` always hashes to the same base value so thread pins
survive catalogue churn (models appearing/disappearing/reordering between
refreshes). On collision we decrement until we find an unused slot; this
keeps the mapping stable for the first config that claimed a slot and
only shifts collisions, which is much less disruptive than the legacy
index-based scheme that reshuffled every ID when the catalogue changed.
"""
digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
cid = base
while cid in taken:
cid -= 1
taken.add(cid)
return cid
def _openrouter_tier(model: dict) -> str:
"""Classify an OpenRouter model as ``"free"`` or ``"premium"``.
Per OpenRouter's API contract, a model is free if:
- Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
- Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.
Anything else (missing pricing, non-zero pricing) falls through to
``"premium"`` so we never under-charge users. This derivation runs off the
already-cached /api/v1/models payload, so it adds no network cost.
"""
if model.get("id", "").endswith(":free"):
return "free"
pricing = model.get("pricing") or {}
prompt = str(pricing.get("prompt", "")).strip()
completion = str(pricing.get("completion", "")).strip()
if prompt == "0" and completion == "0":
return "free"
return "premium"
def _is_text_output_model(model: dict) -> bool:
"""Return True if the model produces text output only (skip image/audio generators)."""
@ -56,6 +117,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
# Deep-research models reject standard params (temperature, etc.)
"openai/o3-deep-research",
"openai/o4-mini-deep-research",
# OpenRouter's own meta-router over free models. We already enumerate every
# concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
# pinning handles churn via the repair path, so exposing an additional
# indirection layer would only duplicate the capability with an opaque slug.
"openrouter/free",
}
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
@ -113,20 +179,41 @@ def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
) -> list[dict]:
"""
Convert raw OpenRouter model entries into global LLM config dicts.
"""Convert raw OpenRouter model entries into global LLM config dicts.
Models are sorted by ID for deterministic, stable ID assignment across
restarts and refreshes.
Tier (``billing_tier``) is derived per-model from OpenRouter's own API
signals via ``_openrouter_tier`` there is no longer a uniform YAML
override. Config IDs are derived via ``_stable_config_id`` so they
survive catalogue churn across refreshes.
Router-pool membership is tier-aware:
- Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
so sub-agent ``model="auto"`` flows benefit from load balancing and
failover across the curated YAML configs and the OR premium passthrough.
- Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
Router tracks rate limits per deployment, but OpenRouter enforces a
single global free-tier quota (~20 RPM + 50-1000 daily requests
account-wide across every ``:free`` model), so rotating across many
free deployments would only burn the shared bucket faster. Free OR
models remain fully available for user-facing Auto-mode thread pinning
via ``auto_model_pin_service``.
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
cover the catalogue-churn case.
"""
id_offset: int = settings.get("id_offset", -10000)
api_key: str = settings.get("api_key", "")
billing_tier: str = settings.get("billing_tier", "premium")
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
seo_enabled: bool = settings.get("seo_enabled", False)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1000000)
tpm: int = settings.get("tpm", 1_000_000)
free_rpm: int = settings.get("free_rpm", 20)
free_tpm: int = settings.get("free_tpm", 100_000)
anon_paid: bool = settings.get("anonymous_enabled_paid", False)
anon_free: bool = settings.get("anonymous_enabled_free", False)
litellm_params: dict = settings.get("litellm_params") or {}
system_instructions: str = settings.get("system_instructions", "")
use_default: bool = settings.get("use_default_system_instructions", True)
@ -142,19 +229,24 @@ def _generate_configs(
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
text_models.sort(key=lambda m: m["id"])
configs: list[dict] = []
for idx, model in enumerate(text_models):
taken: set[int] = set()
now_ts = int(time.time())
for model in text_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
static_q = static_score_or(model, now_ts=now_ts)
cfg: dict[str, Any] = {
"id": id_offset - idx,
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter",
"billing_tier": billing_tier,
"anonymous_enabled": anonymous_enabled,
"billing_tier": tier,
"anonymous_enabled": anon_free if tier == "free" else anon_paid,
"seo_enabled": seo_enabled,
"seo_slug": None,
"quota_reserve_tokens": quota_reserve_tokens,
@ -162,13 +254,28 @@ def _generate_configs(
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"rpm": rpm,
"tpm": tpm,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"system_instructions": system_instructions,
"use_default_system_instructions": use_default,
"citations_enabled": citations_enabled,
# Premium OR deployments join the LiteLLM router pool so sub-agent
# model="auto" flows can load-balance / fail over across them.
# Free OR deployments stay out: OpenRouter's free tier is a single
# account-wide quota, so per-deployment routing can't spread load
# there — it just drains the shared bucket faster.
"router_pool_eligible": tier == "premium",
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
# start so startup latency is unchanged).
"auto_pin_tier": "B" if tier == "premium" else "C",
"quality_score_static": static_q,
"quality_score_health": None,
"quality_score": static_q,
"health_gated": False,
}
configs.append(cfg)
@ -187,6 +294,12 @@ class OpenRouterIntegrationService:
self._configs_by_id: dict[int, dict] = {}
self._initialized = False
self._refresh_task: asyncio.Task | None = None
# Last-good per-model health snapshot. Survives across refresh
# cycles so a transient OpenRouter /endpoints outage doesn't drop
# every cfg back to static-only scoring.
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -220,12 +333,27 @@ class OpenRouterIntegrationService:
self._configs_by_id = {c["id"]: c for c in self._configs}
self._initialized = True
tier_counts = self._tier_counts(self._configs)
logger.info(
"OpenRouter integration: loaded %d models (IDs %d to %d)",
"OpenRouter integration: loaded %d models (free=%d, premium=%d)",
len(self._configs),
self._configs[0]["id"] if self._configs else 0,
self._configs[-1]["id"] if self._configs else 0,
tier_counts["free"],
tier_counts["premium"],
)
# Schedule the first health-enrichment pass as a deferred task so
# cold-start latency is unchanged. Only valid when an event loop is
# already running (e.g. FastAPI lifespan); Celery worker init is
# fully sync so we silently skip — its first refresh tick (or the
# next refresh from the web process) will populate health data.
try:
loop = asyncio.get_running_loop()
self._enrich_task = loop.create_task(
self._enrich_health_safely(self._configs)
)
except RuntimeError:
pass
return self._configs
# ------------------------------------------------------------------
@ -254,7 +382,225 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
try:
from app.services.auto_model_pin_service import clear_healthy
clear_healthy()
except Exception:
logger.debug(
"OpenRouter refresh: clear_healthy import skipped", exc_info=True
)
tier_counts = self._tier_counts(new_configs)
logger.info(
"OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
len(new_configs),
tier_counts["free"],
tier_counts["premium"],
)
# Re-blend health scores against the freshly fetched catalogue. Also
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
# reset cached context-window profiles).
try:
from app.config import config as _app_config
from app.services.llm_router_service import (
LLMRouterService,
_router_instance_cache as _chat_router_cache,
)
LLMRouterService.rebuild(
_app_config.GLOBAL_LLM_CONFIGS,
getattr(_app_config, "ROUTER_SETTINGS", None),
)
_chat_router_cache.clear()
except Exception as exc:
logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
@staticmethod
def _tier_counts(configs: list[dict]) -> dict[str, int]:
counts = {"free": 0, "premium": 0}
for cfg in configs:
tier = str(cfg.get("billing_tier", "")).lower()
if tier in counts:
counts[tier] += 1
return counts
# ------------------------------------------------------------------
# Auto (Fastest) health enrichment
# ------------------------------------------------------------------
async def _enrich_health_safely(
self, configs: list[dict], *, log_summary: bool = True
) -> None:
"""Wrapper around ``_enrich_health`` that swallows all errors.
Health enrichment is best-effort: any failure must leave cfgs in
their static-only state and never break refresh / startup.
"""
try:
await self._enrich_health(configs, log_summary=log_summary)
except Exception:
logger.exception("OpenRouter health enrichment failed")
async def _enrich_health(
self, configs: list[dict], *, log_summary: bool = True
) -> None:
"""Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
the resulting health score into ``cfg["quality_score"]``.
Bounded fan-out: top-N per tier by ``quality_score_static`` only,
with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
outbound HTTP. Misses fall back to a per-model last-good cache; if
the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
the entire previous cycle's cache for this run.
"""
or_cfgs = [
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
]
if not or_cfgs:
return
premium_pool = sorted(
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
key=lambda c: -int(c.get("quality_score_static") or 0),
)[:_HEALTH_ENRICH_TOP_N_PREMIUM]
free_pool = sorted(
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
key=lambda c: -int(c.get("quality_score_static") or 0),
)[:_HEALTH_ENRICH_TOP_N_FREE]
# De-duplicate while preserving order: a cfg shouldn't fall in both
# tiers, but defensive code is cheap here.
seen_ids: set[int] = set()
selected: list[dict] = []
for cfg in premium_pool + free_pool:
cid = int(cfg.get("id", 0))
if cid in seen_ids:
continue
seen_ids.add(cid)
selected.append(cfg)
if not selected:
return
api_key = str(self._settings.get("api_key") or "")
semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
results = await asyncio.gather(
*(
self._fetch_endpoints(client, semaphore, api_key, cfg)
for cfg in selected
)
)
fail_count = sum(1 for _, _, err in results if err is not None)
fail_ratio = fail_count / len(results) if results else 0.0
degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
if degraded:
logger.warning(
"auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
"using_last_good_cache=true",
fail_ratio,
len(results),
)
# Per-cfg health update.
for cfg, endpoints, err in results:
model_name = str(cfg.get("model_name", ""))
if not degraded and err is None and endpoints is not None:
gated, h_score = aggregate_health(endpoints)
cfg["health_gated"] = bool(gated)
cfg["quality_score_health"] = h_score
self._health_cache[model_name] = {
"gated": bool(gated),
"score": h_score,
}
else:
cached = self._health_cache.get(model_name)
if cached is not None:
cfg["health_gated"] = bool(cached.get("gated", False))
cfg["quality_score_health"] = cached.get("score")
# else: keep current values (initial defaults from
# _generate_configs / load_global_llm_configs).
# Blend health into the final score for every OR cfg, including
# those outside the enriched top-N (they fall through to static).
gated_count = 0
by_provider: dict[str, int] = {}
for cfg in or_cfgs:
static_q = int(cfg.get("quality_score_static") or 0)
h = cfg.get("quality_score_health")
if h is not None and not cfg.get("health_gated"):
blended = (
_HEALTH_BLEND_WEIGHT * float(h)
+ (1 - _HEALTH_BLEND_WEIGHT) * static_q
)
cfg["quality_score"] = round(blended)
else:
cfg["quality_score"] = static_q
if cfg.get("health_gated"):
gated_count += 1
model_id = str(cfg.get("model_name", ""))
provider_slug = (
model_id.split("/", 1)[0] if "/" in model_id else "unknown"
)
by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
if log_summary:
logger.info(
"auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
"total_enriched=%d",
gated_count,
dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
fail_ratio,
len(selected),
)
@staticmethod
async def _fetch_endpoints(
client: httpx.AsyncClient,
semaphore: asyncio.Semaphore,
api_key: str,
cfg: dict,
) -> tuple[dict, list[dict] | None, Exception | None]:
"""Fetch ``/api/v1/models/{id}/endpoints`` for one cfg.
Returns ``(cfg, endpoints, err)`` so the caller can keep batched
results aligned with their cfgs without raising.
"""
model_id = str(cfg.get("model_name", ""))
if not model_id:
return cfg, None, ValueError("missing model_name")
url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
async with semaphore:
try:
resp = await client.get(url, headers=headers)
resp.raise_for_status()
data = resp.json()
except Exception as exc:
return cfg, None, exc
payload = data.get("data") if isinstance(data, dict) else None
if not isinstance(payload, dict):
return cfg, None, ValueError("malformed endpoints payload")
endpoints = payload.get("endpoints")
if not isinstance(endpoints, list):
return cfg, [], None
return cfg, endpoints, None
async def _refresh_loop(self, interval_hours: float) -> None:
interval_sec = interval_hours * 3600

View file

@ -0,0 +1,380 @@
"""Pure-function quality scoring for Auto (Fastest) model selection.
This module is import-free of any service / request-path dependencies. All
numbers are computed once during the OpenRouter refresh tick (or YAML load)
and cached on the cfg dict, so the chat hot path only does a precomputed
sort and a SHA256 pick.
Score components (0-100 scale, higher is better):
* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
(provider prestige + ``created`` recency + pricing band + context window
+ capabilities + narrow tiny/legacy slug penalty).
* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
an operator-trust bonus (the operator deliberately picked this model).
* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
responses; returns ``(gated, score_or_none)``.
The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
:mod:`app.services.openrouter_integration_service` because that's the only
caller that sees both halves.
"""
from __future__ import annotations
# ---------------------------------------------------------------------------
# Tunables (constants, not flags)
# ---------------------------------------------------------------------------
# Top-K size for deterministic spread inside the locked tier.
_QUALITY_TOP_K: int = 5
# Hard health gate: any cfg whose best non-null uptime is below this %
# is excluded from Auto-mode selection entirely.
_HEALTH_GATE_UPTIME_PCT: float = 90.0
# Health/static blend weight when a cfg has fresh /endpoints data.
_HEALTH_BLEND_WEIGHT: float = 0.5
# Static bonus applied to YAML cfgs because the operator hand-picked them.
_OPERATOR_TRUST_BONUS: int = 20
# /endpoints fan-out is bounded per refresh tick.
_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
_HEALTH_ENRICH_TOP_N_FREE: int = 30
_HEALTH_ENRICH_CONCURRENCY: int = 15
_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
# If at least this fraction of /endpoints fetches fail in a refresh cycle,
# fall back to the previous cycle's last-good cache instead of writing
# partial / stale health values.
_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
# blanket-penalising them suppresses high-quality picks.
_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
"-1b-",
"-1.2b-",
"-1.5b-",
"-2b-",
"-3b-",
"gemma-3n",
"lfm-",
"-base",
"-distill",
":nitro",
"-preview",
)
# ---------------------------------------------------------------------------
# Provider prestige tables
# ---------------------------------------------------------------------------
# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
# Tiers are coarse: frontier labs > strong open / fast-moving labs >
# specialist labs > everything else.
PROVIDER_PRESTIGE_OR: dict[str, int] = {
# Frontier labs
"openai": 50,
"anthropic": 50,
"google": 50,
"x-ai": 50,
# Strong open / fast-moving labs
"deepseek": 38,
"qwen": 38,
"meta-llama": 38,
"mistralai": 38,
"cohere": 38,
"nvidia": 38,
"alibaba": 38,
# Specialist / regional / strong second-tier
"microsoft": 28,
"01-ai": 28,
"minimax": 28,
"moonshot": 28,
"z-ai": 28,
"nousresearch": 28,
"ai21": 28,
"perplexity": 28,
# Smaller / niche providers
"liquid": 18,
"cognitivecomputations": 18,
"venice": 18,
"inflection": 18,
}
# YAML provider field (the upstream API shape the operator selected).
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
"AZURE_OPENAI": 50,
"OPENAI": 50,
"ANTHROPIC": 50,
"GOOGLE": 50,
"VERTEX_AI": 50,
"GEMINI": 50,
"XAI": 50,
"MISTRAL": 38,
"DEEPSEEK": 38,
"COHERE": 38,
"GROQ": 30,
"TOGETHER_AI": 28,
"FIREWORKS_AI": 28,
"PERPLEXITY": 28,
"MINIMAX": 28,
"BEDROCK": 28,
"OPENROUTER": 25,
"OLLAMA": 12,
"CUSTOM": 12,
}
# ---------------------------------------------------------------------------
# Pure scoring helpers
# ---------------------------------------------------------------------------
# Calibrated against the live /api/v1/models bulk dump. Frontier models
# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5,
# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band;
# anything older trails off.
_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = (
(60, 20),
(180, 16),
(365, 12),
(540, 9),
(730, 6),
(1095, 3),
)
def created_recency_signal(created_ts: int | None, now_ts: int) -> int:
"""Return 0-20 based on how recently the model was published.
Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for
YAML cfgs). Models without a usable timestamp get 0 (we don't penalise,
we just don't reward).
"""
if created_ts is None or created_ts <= 0 or now_ts <= 0:
return 0
age_days = max(0, (now_ts - int(created_ts)) // 86_400)
for cutoff, score in _RECENCY_BANDS_DAYS:
if age_days <= cutoff:
return score
return 0
def pricing_band(
prompt: str | float | int | None,
completion: str | float | int | None,
) -> int:
"""Return 0-15 based on combined prompt+completion cost per 1M tokens.
Higher-priced models tend to be the larger / more capable ones. A free
model returns 0 (we use other signals to rank free-vs-free instead).
Uncoercible inputs are treated as 0 rather than raising.
"""
def _to_float(value) -> float:
if value is None:
return 0.0
try:
return float(value)
except (TypeError, ValueError):
return 0.0
p = _to_float(prompt)
c = _to_float(completion)
total_per_million = (p + c) * 1_000_000
if total_per_million >= 20.0:
return 15
if total_per_million >= 5.0:
return 12
if total_per_million >= 1.0:
return 9
if total_per_million >= 0.3:
return 6
if total_per_million >= 0.05:
return 4
if total_per_million > 0.0:
return 2
return 0
def context_signal(ctx: int | None) -> int:
"""Return 0-10 based on the model's context window."""
if not ctx or ctx <= 0:
return 0
if ctx >= 1_000_000:
return 10
if ctx >= 400_000:
return 8
if ctx >= 200_000:
return 6
if ctx >= 128_000:
return 4
if ctx >= 100_000:
return 2
return 0
def capabilities_signal(supported_parameters: list[str] | None) -> int:
"""Return 0-5 for capabilities that matter for our agent flows."""
if not supported_parameters:
return 0
params = set(supported_parameters)
score = 0
if "tools" in params:
score += 2
if "structured_outputs" in params or "response_format" in params:
score += 2
if "reasoning" in params or "include_reasoning" in params:
score += 1
return min(score, 5)
def slug_penalty(model_id: str) -> int:
"""Return a non-positive number; matches the narrow tiny/legacy patterns."""
if not model_id:
return 0
needle = model_id.lower()
for pattern in _TINY_LEGACY_PENALTY_PATTERNS:
if pattern in needle:
return -10
return 0
def _provider_prestige_or(model_id: str) -> int:
if "/" not in model_id:
return 0
slug = model_id.split("/", 1)[0].lower()
return PROVIDER_PRESTIGE_OR.get(slug, 15)
def static_score_or(or_model: dict, *, now_ts: int) -> int:
"""Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale."""
model_id = str(or_model.get("id", ""))
pricing = or_model.get("pricing") or {}
score = (
_provider_prestige_or(model_id)
+ created_recency_signal(or_model.get("created"), now_ts)
+ pricing_band(pricing.get("prompt"), pricing.get("completion"))
+ context_signal(or_model.get("context_length"))
+ capabilities_signal(or_model.get("supported_parameters"))
+ slug_penalty(model_id)
)
return max(0, min(100, int(score)))
def static_score_yaml(cfg: dict) -> int:
"""Score a YAML-curated cfg on a 0-100 scale.
Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately
listed this model. Pricing / context fall through to lazy ``litellm``
lookups; failures are silent (we just lose those sub-points).
"""
provider = str(cfg.get("provider", "")).upper()
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
model_name = cfg.get("model_name") or ""
litellm_params = cfg.get("litellm_params") or {}
lookup_name = (
litellm_params.get("base_model") or litellm_params.get("model") or model_name
)
ctx = 0
p_cost: float = 0.0
c_cost: float = 0.0
try:
from litellm import get_model_info # lazy: avoid cold-import cost
info = get_model_info(lookup_name) or {}
ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0)
p_cost = float(info.get("input_cost_per_token") or 0.0)
c_cost = float(info.get("output_cost_per_token") or 0.0)
except Exception:
# Unknown to litellm — that's fine for prestige+operator-bonus weighting.
pass
score = (
base
+ _OPERATOR_TRUST_BONUS
+ pricing_band(p_cost, c_cost)
+ context_signal(ctx)
+ slug_penalty(str(model_name))
)
return max(0, min(100, int(score)))
# ---------------------------------------------------------------------------
# Health aggregation
# ---------------------------------------------------------------------------
def _coerce_pct(value) -> float | None:
try:
if value is None:
return None
f = float(value)
except (TypeError, ValueError):
return None
if f < 0:
return None
# OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it
# as a 0-100 percentage. Normalise.
return f * 100.0 if f <= 1.0 else f
def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]:
"""Pick the best (highest) non-null uptime across all endpoints.
Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` >
``uptime_last_5m``. Returns ``(uptime_pct, window_used)``.
"""
for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"):
values = [_coerce_pct(ep.get(window)) for ep in endpoints]
values = [v for v in values if v is not None]
if values:
return max(values), window
return None, None
def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
"""Aggregate a model's per-endpoint health into ``(gated, score_or_none)``.
Hard gate (returns ``(True, None)``):
* ``endpoints`` empty,
* no endpoint reports ``status == 0`` (OK), or
* best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``.
On a pass, returns a 0-100 health score blending uptime, status, and a
freshness-weighted recent uptime sample.
"""
if not endpoints:
return True, None
any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints)
if not any_ok:
return True, None
best_uptime, _ = _best_uptime(endpoints)
if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT:
return True, None
# Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing.
freshness = None
for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
values = [_coerce_pct(ep.get(window)) for ep in endpoints]
values = [v for v in values if v is not None]
if values:
freshness = max(values)
break
uptime_term = best_uptime
status_term = 100.0 if any_ok else 0.0
freshness_term = freshness if freshness is not None else best_uptime
score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term
return False, max(0.0, min(100.0, score))

View file

@ -64,7 +64,12 @@ from app.db import (
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
from app.services.auto_model_pin_service import (
is_recently_healthy,
mark_healthy,
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@ -299,20 +304,17 @@ def _tool_output_has_error(tool_output: Any) -> bool:
return False
def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
def _extract_resolved_file_path(
*, tool_name: str, tool_output: Any, tool_input: Any | None = None
) -> str | None:
if isinstance(tool_output, dict):
path_value = tool_output.get("path")
if isinstance(path_value, str) and path_value.strip():
return path_value.strip()
text = _tool_output_to_text(tool_output)
if tool_name == "write_file":
match = re.search(r"Updated file\s+(.+)$", text.strip())
if match:
return match.group(1).strip()
if tool_name == "edit_file":
match = re.search(r"in '([^']+)'", text)
if match:
return match.group(1).strip()
if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict):
file_path = tool_input.get("file_path")
if isinstance(file_path, str) and file_path.strip():
return file_path.strip()
return None
@ -414,6 +416,108 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
return None
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
if not isinstance(parsed, dict):
return None
candidates: list[Any] = [parsed.get("code")]
nested = parsed.get("error")
if isinstance(nested, dict):
candidates.append(nested.get("code"))
for value in candidates:
try:
if value is None:
continue
return int(value)
except Exception:
continue
return None
def _is_provider_rate_limited(exc: BaseException) -> bool:
"""Best-effort detection for provider-side runtime throttling.
Covers LiteLLM/OpenRouter shapes like:
- class name contains ``RateLimit``
- nested payload ``{"error": {"code": 429}}``
- nested payload ``{"error": {"type": "rate_limit_error"}}``
"""
raw = str(exc)
lowered = raw.lower()
if "ratelimit" in type(exc).__name__.lower():
return True
parsed = _parse_error_payload(raw)
provider_code = _extract_provider_error_code(parsed)
if provider_code == 429:
return True
provider_error_type = ""
if parsed:
top_type = parsed.get("type")
if isinstance(top_type, str):
provider_error_type = top_type.lower()
nested = parsed.get("error")
if isinstance(nested, dict):
nested_type = nested.get("type")
if isinstance(nested_type, str):
provider_error_type = nested_type.lower()
if provider_error_type == "rate_limit_error":
return True
return (
"rate limited" in lowered
or "rate-limited" in lowered
or "temporarily rate-limited upstream" in lowered
)
_PREFLIGHT_TIMEOUT_SEC: float = 2.5
_PREFLIGHT_MAX_TOKENS: int = 1
async def _preflight_llm(llm: Any) -> None:
"""Issue a minimal completion to confirm the pinned model isn't 429'ing.
Used before agent build / planner / classifier / title-gen so a known-bad
free OpenRouter deployment is detected and repinned before it cascades
into multiple wasted internal calls. The probe is intentionally cheap:
one token, low timeout, tagged ``surfsense:internal`` so token tracking
and SSE pipelines treat it as overhead rather than user output.
Raises the original exception when the provider responds with a
rate-limit-shaped error so the caller can drive the cooldown/repin
branch via :func:`_is_provider_rate_limited`. Other transient failures
are swallowed the caller continues to the normal stream path and the
in-stream recovery loop remains the safety net.
"""
from litellm import acompletion
model = getattr(llm, "model", None)
if not model or model == "auto":
# Auto-mode router doesn't have a single deployment to ping; the
# router itself handles per-deployment rate-limit accounting.
return
try:
await acompletion(
model=model,
messages=[{"role": "user", "content": "ping"}],
api_key=getattr(llm, "api_key", None),
api_base=getattr(llm, "api_base", None),
max_tokens=_PREFLIGHT_MAX_TOKENS,
timeout=_PREFLIGHT_TIMEOUT_SEC,
stream=False,
metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
)
except Exception as exc:
if _is_provider_rate_limited(exc):
raise
logging.getLogger(__name__).debug(
"auto_pin_preflight non_rate_limit_error model=%s err=%s",
model,
exc,
)
def _classify_stream_exception(
exc: Exception,
*,
@ -449,19 +553,7 @@ def _classify_stream_exception(
None,
)
parsed = _parse_error_payload(raw)
provider_error_type = ""
if parsed:
top_type = parsed.get("type")
if isinstance(top_type, str):
provider_error_type = top_type.lower()
nested = parsed.get("error")
if isinstance(nested, dict):
nested_type = nested.get("type")
if isinstance(nested_type, str):
provider_error_type = nested_type.lower()
if provider_error_type == "rate_limit_error":
if _is_provider_rate_limited(exc):
return (
"rate_limited",
"RATE_LIMITED",
@ -619,6 +711,7 @@ async def _stream_agent_events(
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}
file_path_by_run: dict[str, str] = {}
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
@ -797,6 +890,10 @@ async def _stream_agent_events(
tool_input = event.get("data", {}).get("input", {})
if tool_name in ("write_file", "edit_file"):
result.write_attempted = True
if isinstance(tool_input, dict):
file_path = tool_input.get("file_path")
if isinstance(file_path, str) and file_path.strip() and run_id:
file_path_by_run[run_id] = file_path.strip()
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@ -1203,6 +1300,7 @@ async def _stream_agent_events(
run_id = event.get("run_id", "")
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None
if tool_name == "update_memory":
called_update_memory = True
@ -1716,6 +1814,9 @@ async def _stream_agent_events(
resolved_path = _extract_resolved_file_path(
tool_name=tool_name,
tool_output=tool_output,
tool_input={"file_path": staged_file_path}
if staged_file_path
else None,
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
@ -2326,6 +2427,91 @@ async def stream_new_chat(
yield streaming_service.format_done()
return
# Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
# (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
# whose health hasn't already been confirmed within the TTL window.
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
# title-generation LLM calls fan out and each independently hit the
# same upstream rate limit.
if (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
):
_t_preflight = time.perf_counter()
try:
await _preflight_llm(llm)
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
if not _is_provider_rate_limited(preflight_exc):
raise
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error or not llm:
yield _emit_stream_error(
message=llm_load_error or "Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
# Trust the freshly-resolved cfg for the remainder of this
# turn rather than recursing into another preflight; the
# in-stream 429 recovery loop is still in place as the
# safety net if even this fallback hits an upstream cap.
mark_healthy(llm_config_id)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model failed preflight; switched to another "
"eligible model and continuing."
),
extra={
"auto_runtime_recover": True,
"preflight": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
@ -2671,54 +2857,155 @@ async def stream_new_chat(
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=input_state,
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
runtime_rate_limit_recovered = False
while True:
try:
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=input_state,
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(
chat_id, generated_title
_first_event_logged = True
yield sse
# Inject title update mid-stream as soon as the background
# task finishes.
if (
title_task is not None
and title_task.done()
and not title_emitted
):
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(
NewChatThread.id == chat_id
)
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(
chat_id, generated_title
)
title_emitted = True
break
except Exception as stream_exc:
can_runtime_recover = (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and llm_config_id < 0
and not _first_event_logged
and _is_provider_rate_limited(stream_exc)
)
if not can_runtime_recover:
raise
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
# The failed attempt may still hold the per-thread busy mutex
# (middleware teardown can lag behind raised provider errors).
# Force release before we retry within the same request.
end_turn(str(chat_id))
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
)
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
title_emitted = True
).resolved_llm_config_id
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error:
raise stream_exc
# Title generation uses the initial llm object. After a runtime
# repin we keep the stream focused on response recovery and skip
# title generation for this turn.
if title_task is not None and not title_task.done():
title_task.cancel()
title_task = None
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t0,
)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
continue
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
@ -3187,6 +3474,84 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
# one cheap probe before the agent is rebuilt so a 429'd pin gets
# repinned without burning planner/classifier/title calls first.
if (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
):
_t_preflight = time.perf_counter()
try:
await _preflight_llm(llm)
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
if not _is_provider_rate_limited(preflight_exc):
raise
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error or not llm:
yield _emit_stream_error(
message=llm_load_error or "Failed to create LLM instance",
error_kind="server_error",
error_code="SERVER_ERROR",
)
yield streaming_service.format_done()
return
mark_healthy(llm_config_id)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model failed preflight; switched to another "
"eligible model and continuing."
),
extra={
"auto_runtime_recover": True,
"preflight": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
@ -3265,31 +3630,114 @@ async def stream_resume_chat(
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=Command(resume={"decisions": decisions}),
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
runtime_rate_limit_recovered = False
while True:
try:
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=Command(resume={"decisions": decisions}),
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
break
except Exception as stream_exc:
can_runtime_recover = (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and llm_config_id < 0
and not _first_event_logged
and _is_provider_rate_limited(stream_exc)
)
_first_event_logged = True
yield sse
if not can_runtime_recover:
raise
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
# Ensure the same-request recovery retry does not trip the
# BusyMutex lock retained by the failed attempt.
end_turn(str(chat_id))
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
)
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id
llm, agent_config, llm_load_error = await _load_llm_bundle(
llm_config_id
)
if llm_load_error:
raise stream_exc
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t0,
)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
continue
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,

View file

@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
assert not manager.lock_for(thread_id).locked()
assert not get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is False
@pytest.mark.asyncio
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
"""A stale aafter call from attempt A must not unlock attempt B.
Repro flow:
1) attempt A acquires thread lock
2) forced end_turn clears A so retry can proceed
3) attempt B acquires same thread lock
4) stale attempt-A aafter runs late
Expected: B lock remains held.
"""
thread_id = "stale-aafter-lock"
runtime = _Runtime(thread_id)
attempt_a = BusyMutexMiddleware()
attempt_b = BusyMutexMiddleware()
await attempt_a.abefore_agent({}, runtime)
lock = manager.lock_for(thread_id)
assert lock.locked()
end_turn(thread_id)
assert not lock.locked()
await attempt_b.abefore_agent({}, runtime)
assert lock.locked()
# Stale cleanup from attempt A must not release attempt B's lock.
await attempt_a.aafter_agent({}, runtime)
assert lock.locked()
await attempt_b.aafter_agent({}, runtime)

View file

@ -6,13 +6,26 @@ from types import SimpleNamespace
import pytest
from app.services.auto_model_pin_service import (
AUTO_FASTEST_MODE,
clear_healthy,
clear_runtime_cooldown,
is_recently_healthy,
mark_healthy,
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _clear_runtime_cooldown_map():
clear_runtime_cooldown()
clear_healthy()
yield
clear_runtime_cooldown()
clear_healthy()
@dataclass
class _FakeQuotaResult:
allowed: bool
@ -45,14 +58,11 @@ def _thread(
*,
search_space_id: int = 10,
pinned_llm_config_id: int | None = None,
pinned_auto_mode: str | None = None,
):
return SimpleNamespace(
id=1,
search_space_id=search_space_id,
pinned_llm_config_id=pinned_llm_config_id,
pinned_auto_mode=pinned_auto_mode,
pinned_at=None,
)
@ -93,8 +103,6 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
)
assert result.resolved_llm_config_id in {-1, -2}
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE
assert session.thread.pinned_at is not None
assert session.commit_count == 1
@ -102,9 +110,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
async def test_next_turn_reuses_existing_pin(monkeypatch):
from app.config import config
session = _FakeSession(
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
)
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
@ -228,9 +234,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
from app.config import config
session = _FakeSession(
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
)
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
@ -275,9 +279,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
from app.config import config
session = _FakeSession(
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
)
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
@ -325,9 +327,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
async def test_explicit_user_model_change_clears_pin(monkeypatch):
from app.config import config
session = _FakeSession(
_thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE)
)
session = _FakeSession(_thread(pinned_llm_config_id=-2))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
@ -345,8 +345,6 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
)
assert result.resolved_llm_config_id == 7
assert session.thread.pinned_llm_config_id is None
assert session.thread.pinned_auto_mode is None
assert session.thread.pinned_at is None
assert session.commit_count == 1
@ -354,9 +352,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
from app.config import config
session = _FakeSession(
_thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE)
)
session = _FakeSession(_thread(pinned_llm_config_id=-999))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
@ -383,3 +379,543 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
assert result.resolved_llm_config_id == -2
assert session.thread.pinned_llm_config_id == -2
assert session.commit_count == 1
# ---------------------------------------------------------------------------
# Quality-aware pin selection (Auto Fastest upgrade)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
"""A cfg flagged ``health_gated`` must never be picked even if it has
the highest score among eligible cfgs."""
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "venice/dead-model",
"api_key": "k1",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 95,
"health_gated": True,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-flash",
"api_key": "k1",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 60,
"health_gated": False,
},
],
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
@pytest.mark.asyncio
async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
"""Premium-eligible users with Tier A available should never spill to
Tier B even if a B cfg ranks higher by ``quality_score``."""
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"api_key": "k-yaml",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 70,
"health_gated": False,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "openai/gpt-5",
"api_key": "k-or",
"billing_tier": "premium",
"auto_pin_tier": "B",
"quality_score": 95,
"health_gated": False,
},
],
)
async def _allowed(*_args, **_kwargs):
return _FakeQuotaResult(allowed=True)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -1
assert result.resolved_tier == "premium"
@pytest.mark.asyncio
async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch):
"""Free-only user with no Tier A free cfg should pick from Tier C."""
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"api_key": "k-yaml",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 100,
"health_gated": False,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-flash:free",
"api_key": "k-or",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 60,
"health_gated": False,
},
],
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
@pytest.mark.asyncio
async def test_top_k_picks_only_high_score_models(monkeypatch):
"""Different thread IDs should spread across top-K, never pick the
obvious low-quality cfg even when it sits in the candidate list."""
from app.config import config
high_score_cfgs = [
{
"id": -i,
"provider": "AZURE_OPENAI",
"model_name": f"gpt-x-{i}",
"api_key": "k",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 90,
"health_gated": False,
}
for i in range(1, 6) # 5 high-quality Tier A cfgs
]
low_score_trap = {
"id": -99,
"provider": "AZURE_OPENAI",
"model_name": "tiny-legacy",
"api_key": "k",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 10,
"health_gated": False,
}
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[*high_score_cfgs, low_score_trap],
)
async def _allowed(*_args, **_kwargs):
return _FakeQuotaResult(allowed=True)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_allowed,
)
high_score_ids = {c["id"] for c in high_score_cfgs}
seen = set()
for thread_id in range(1, 50):
session = _FakeSession(_thread())
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=thread_id,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
seen.add(result.resolved_llm_config_id)
assert result.resolved_llm_config_id != -99, (
"low-score trap cfg should never be picked"
)
assert result.resolved_llm_config_id in high_score_ids
# Spread across at least a couple of top-K cfgs.
assert len(seen) > 1
@pytest.mark.asyncio
async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
"""An *already* pinned cfg that later flips to ``health_gated`` should
still not be reused gated cfgs are filtered out of the candidate
pool, which forces a repair to a healthy cfg.
This guards the no-silent-tier-switch invariant: we don't keep using
a known-broken model just because the thread happened to be pinned
to it before the gate fired."""
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "venice/dead-model",
"api_key": "k",
"billing_tier": "premium",
"auto_pin_tier": "B",
"quality_score": 50,
"health_gated": True,
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"api_key": "k",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 90,
"health_gated": False,
},
],
)
async def _allowed(*_args, **_kwargs):
return _FakeQuotaResult(allowed=True)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False
@pytest.mark.asyncio
async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
"""Existing pin reuse must short-circuit the new tier/score logic."""
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"api_key": "k",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 50, # lower than -2
"health_gated": False,
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5-pro",
"api_key": "k",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 99,
"health_gated": False,
},
],
)
async def _must_not_call(*_args, **_kwargs):
raise AssertionError("premium_get_usage should not run on pin reuse")
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_must_not_call,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -1
assert result.from_existing_pin is True
assert session.commit_count == 0
@pytest.mark.asyncio
async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
"""A runtime-cooled config should be excluded from candidate reuse.
This enables one-shot recovery from transient provider 429 bursts: we can
mark the pinned cfg as cooled down and force a repair to another eligible
cfg on the next resolution.
"""
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 90,
"health_gated": False,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 80,
"health_gated": False,
},
],
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_blocked,
)
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False
@pytest.mark.asyncio
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 90,
"health_gated": False,
},
],
)
async def _must_not_call(*_args, **_kwargs):
raise AssertionError("premium_get_usage should not run on healthy pin reuse")
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_must_not_call,
)
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
clear_runtime_cooldown(-1)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -1
assert result.from_existing_pin is True
@pytest.mark.asyncio
async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
"""Runtime retry should never repin the just-failed config."""
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 90,
"health_gated": False,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 80,
"health_gated": False,
},
],
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
exclude_config_ids={-1},
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False
# ---------------------------------------------------------------------------
# Healthy-status cache (preflight TTL companion)
# ---------------------------------------------------------------------------
def test_mark_healthy_then_is_recently_healthy_true_within_ttl():
mark_healthy(-42, ttl_seconds=60)
assert is_recently_healthy(-42) is True
def test_healthy_expires_after_ttl(monkeypatch):
import app.services.auto_model_pin_service as svc
real_time = svc.time.time
base = real_time()
monkeypatch.setattr(svc.time, "time", lambda: base)
mark_healthy(-7, ttl_seconds=10)
assert is_recently_healthy(-7) is True
monkeypatch.setattr(svc.time, "time", lambda: base + 11)
assert is_recently_healthy(-7) is False
def test_mark_runtime_cooldown_invalidates_healthy_cache():
mark_healthy(-9, ttl_seconds=60)
assert is_recently_healthy(-9) is True
mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60)
assert is_recently_healthy(-9) is False
def test_clear_healthy_removes_single_entry():
mark_healthy(-11, ttl_seconds=60)
mark_healthy(-12, ttl_seconds=60)
clear_healthy(-11)
assert is_recently_healthy(-11) is False
assert is_recently_healthy(-12) is True
def test_clear_healthy_no_args_drops_all_entries():
mark_healthy(-21, ttl_seconds=60)
mark_healthy(-22, ttl_seconds=60)
clear_healthy()
assert is_recently_healthy(-21) is False
assert is_recently_healthy(-22) is False

View file

@ -0,0 +1,226 @@
"""LLMRouterService pool-filter / rebuild tests.
These tests focus on the *config plumbing* (which configs enter the router
pool, rebuild resets state correctly). They stub out the underlying
``litellm.Router`` so we don't need real API keys or network access.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.services.llm_router_service import LLMRouterService
pytestmark = pytest.mark.unit
def _fake_yaml_config(
*,
id: int,
model_name: str,
billing_tier: str = "free",
) -> dict:
return {
"id": id,
"name": f"yaml-{id}",
"provider": "OPENAI",
"model_name": model_name,
"api_key": "sk-test",
"api_base": "",
"billing_tier": billing_tier,
"rpm": 100,
"tpm": 100_000,
"litellm_params": {},
}
def _fake_openrouter_config(
*,
id: int,
model_name: str,
billing_tier: str,
router_pool_eligible: bool | None = None,
) -> dict:
"""Build a synthetic dynamic-OR config dict for router-pool tests.
Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
out. Callers can override ``router_pool_eligible`` to simulate legacy
configs or to regression-test the filter mechanics directly.
"""
if router_pool_eligible is None:
router_pool_eligible = billing_tier == "premium"
return {
"id": id,
"name": f"or-{id}",
"provider": "OPENROUTER",
"model_name": model_name,
"api_key": "sk-or-test",
"api_base": "",
"billing_tier": billing_tier,
"rpm": 20 if billing_tier == "free" else 200,
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
"litellm_params": {},
"router_pool_eligible": router_pool_eligible,
}
def _reset_router_singleton() -> None:
instance = LLMRouterService.get_instance()
instance._initialized = False
instance._router = None
instance._model_list = []
instance._premium_model_strings = set()
def test_router_pool_includes_or_premium_excludes_or_free():
"""Strategy 3: premium OR joins the pool, free OR stays out.
Dynamic OpenRouter premium entries opt into load balancing alongside
curated YAML configs. Dynamic OR free entries are intentionally kept
out because OpenRouter's free tier enforces a single account-global
quota bucket that per-deployment router accounting can't represent.
"""
_reset_router_singleton()
configs = [
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
_fake_openrouter_config(
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
),
_fake_openrouter_config(
id=-10_002,
model_name="meta-llama/llama-3.3-70b:free",
billing_tier="free",
),
]
with (
patch("app.services.llm_router_service.Router") as mock_router,
patch(
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
) as mock_ctx_fb,
):
mock_ctx_fb.side_effect = lambda ml: (ml, None)
mock_router.return_value = object()
LLMRouterService.initialize(configs)
pool_models = {
dep["litellm_params"]["model"]
for dep in LLMRouterService.get_instance()._model_list
}
# YAML premium + YAML free + dynamic OR premium are all in the pool.
# Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
assert pool_models == {
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openrouter/openai/gpt-4o",
}
prem = LLMRouterService.get_instance()._premium_model_strings
# YAML premium is fingerprinted under both its model_string and its
# ``base_model`` form (existing behavior we don't want to regress).
assert "openai/gpt-4o" in prem
# Dynamic OR premium is now fingerprinted as premium so pool-level
# calls through the router are billed against premium quota.
assert "openrouter/openai/gpt-4o" in prem
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
# Dynamic OR free never enters the pool, so it's never counted as premium.
assert (
LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
is False
)
def test_router_pool_filter_mechanics_respect_override():
"""The ``router_pool_eligible`` filter itself works independently of tier.
Regression guard: if a future refactor ever sets the flag False on a
premium config (e.g. for maintenance), that config MUST be skipped by
``initialize`` even though its tier is premium.
"""
_reset_router_singleton()
configs = [
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
_fake_openrouter_config(
id=-10_001,
model_name="openai/gpt-4o",
billing_tier="premium",
router_pool_eligible=False, # opt out despite being premium
),
]
with (
patch("app.services.llm_router_service.Router") as mock_router,
patch(
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
) as mock_ctx_fb,
):
mock_ctx_fb.side_effect = lambda ml: (ml, None)
mock_router.return_value = object()
LLMRouterService.initialize(configs)
pool_models = {
dep["litellm_params"]["model"]
for dep in LLMRouterService.get_instance()._model_list
}
assert pool_models == {"openai/gpt-4o"}
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
def test_rebuild_refreshes_pool_after_configs_change():
_reset_router_singleton()
configs_v1 = [
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
]
configs_v2 = [
*configs_v1,
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
]
with (
patch("app.services.llm_router_service.Router") as mock_router,
patch(
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
) as mock_ctx_fb,
):
mock_ctx_fb.side_effect = lambda ml: (ml, None)
mock_router.return_value = object()
LLMRouterService.initialize(configs_v1)
assert len(LLMRouterService.get_instance()._model_list) == 1
# ``initialize`` should be a no-op here (already initialized).
LLMRouterService.initialize(configs_v2)
assert len(LLMRouterService.get_instance()._model_list) == 1
# ``rebuild`` must clear the guard and re-run with the new configs.
LLMRouterService.rebuild(configs_v2)
assert len(LLMRouterService.get_instance()._model_list) == 2
def test_auto_model_pin_candidates_include_dynamic_openrouter():
"""Dynamic OR configs must remain Auto-mode thread-pin candidates.
Guards against a future regression where someone adds the
``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``.
"""
from app.config import config
from app.services.auto_model_pin_service import _global_candidates
or_premium = _fake_openrouter_config(
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
)
or_free = _fake_openrouter_config(
id=-10_002,
model_name="meta-llama/llama-3.3-70b:free",
billing_tier="free",
)
original = config.GLOBAL_LLM_CONFIGS
try:
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
candidate_ids = {c["id"] for c in _global_candidates()}
assert candidate_ids == {-10_001, -10_002}
finally:
config.GLOBAL_LLM_CONFIGS = original

View file

@ -0,0 +1,216 @@
"""Unit tests for the dynamic OpenRouter integration."""
from __future__ import annotations
import pytest
from app.services.openrouter_integration_service import (
_OPENROUTER_DYNAMIC_MARKER,
_generate_configs,
_openrouter_tier,
_stable_config_id,
)
pytestmark = pytest.mark.unit
def _minimal_openrouter_model(
*,
model_id: str,
pricing: dict | None = None,
name: str | None = None,
) -> dict:
"""Return a synthetic OpenRouter /api/v1/models entry.
The real API payload includes a lot of fields; we only populate what
``_generate_configs`` actually inspects (architecture, tool support,
context, pricing, id).
"""
return {
"id": model_id,
"name": name or model_id,
"architecture": {"output_modalities": ["text"]},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
}
# ---------------------------------------------------------------------------
# _openrouter_tier
# ---------------------------------------------------------------------------
def test_openrouter_tier_free_suffix():
assert _openrouter_tier({"id": "foo/bar:free"}) == "free"
def test_openrouter_tier_zero_pricing():
model = {
"id": "foo/bar",
"pricing": {"prompt": "0", "completion": "0"},
}
assert _openrouter_tier(model) == "free"
def test_openrouter_tier_paid():
model = {
"id": "foo/bar",
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
}
assert _openrouter_tier(model) == "premium"
def test_openrouter_tier_missing_pricing_is_premium():
assert _openrouter_tier({"id": "foo/bar"}) == "premium"
assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium"
# ---------------------------------------------------------------------------
# _stable_config_id
# ---------------------------------------------------------------------------
def test_stable_config_id_deterministic():
taken1: set[int] = set()
taken2: set[int] = set()
a = _stable_config_id("openai/gpt-4o", -10_000, taken1)
b = _stable_config_id("openai/gpt-4o", -10_000, taken2)
assert a == b
assert a < 0
def test_stable_config_id_collision_decrements():
"""When two model_ids hash to the same slot, the second should decrement."""
taken: set[int] = set()
a = _stable_config_id("openai/gpt-4o", -10_000, taken)
# Force a collision by pre-populating ``taken`` with a slot we know will be
# picked.
taken_forced = {a}
b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced)
assert b != a
assert b == a - 1
assert b in taken_forced
def test_stable_config_id_different_models_different_ids():
taken: set[int] = set()
ids = {
_stable_config_id("openai/gpt-4o", -10_000, taken),
_stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken),
_stable_config_id("google/gemini-2.0-flash", -10_000, taken),
}
assert len(ids) == 3
def test_stable_config_id_survives_catalogue_churn():
"""Removing a model should not shift other models' IDs (the bug we fix)."""
taken1: set[int] = set()
id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1)
_ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1)
id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1)
taken2: set[int] = set()
id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2)
id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2)
assert id_a1 == id_a2
assert id_c1 == id_c2
# ---------------------------------------------------------------------------
# _generate_configs
# ---------------------------------------------------------------------------
_SETTINGS_BASE: dict = {
"api_key": "sk-or-test",
"id_offset": -10_000,
"rpm": 200,
"tpm": 1_000_000,
"free_rpm": 20,
"free_tpm": 100_000,
"anonymous_enabled_paid": False,
"anonymous_enabled_free": True,
"quota_reserve_tokens": 4000,
}
def test_generate_configs_respects_tier():
"""Premium OR models opt into the router pool; free OR models stay out.
Strategy-3 split: premium participates in LiteLLM Router load balancing,
free stays excluded because OpenRouter enforces a shared global free-tier
bucket that per-deployment router accounting can't represent.
"""
raw = [
_minimal_openrouter_model(model_id="openai/gpt-4o"),
_minimal_openrouter_model(
model_id="meta-llama/llama-3.3-70b-instruct:free",
pricing={"prompt": "0", "completion": "0"},
),
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
by_model = {c["model_name"]: c for c in cfgs}
paid = by_model["openai/gpt-4o"]
assert paid["billing_tier"] == "premium"
assert paid["rpm"] == 200
assert paid["tpm"] == 1_000_000
assert paid["anonymous_enabled"] is False
assert paid["router_pool_eligible"] is True
assert paid[_OPENROUTER_DYNAMIC_MARKER] is True
free = by_model["meta-llama/llama-3.3-70b-instruct:free"]
assert free["billing_tier"] == "free"
assert free["rpm"] == 20
assert free["tpm"] == 100_000
assert free["anonymous_enabled"] is True
assert free["router_pool_eligible"] is False
def test_generate_configs_excludes_upstream_openrouter_free_router():
"""OpenRouter's own ``openrouter/free`` meta-router must never become a card.
The upstream API returns this as a first-class zero-priced model, so
without an explicit blocklist entry it would slip through every other
filter (text output, tool calling, 200k context, non-Amazon) and land
in the selector as a duplicate of the concrete ``:free`` cards. The
exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that.
"""
raw = [
_minimal_openrouter_model(model_id="openai/gpt-4o"),
_minimal_openrouter_model(
model_id="openrouter/free",
pricing={"prompt": "0", "completion": "0"},
),
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
model_names = {c["model_name"] for c in cfgs}
assert "openrouter/free" not in model_names
assert "openai/gpt-4o" in model_names
def test_generate_configs_drops_non_text_and_non_tool_models():
raw = [
_minimal_openrouter_model(model_id="openai/gpt-4o"),
{ # image-output model
"id": "openai/dall-e",
"architecture": {"output_modalities": ["image"]},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": {"prompt": "0.01", "completion": "0.01"},
},
{ # text but no tool calling
"id": "openai/completion-only",
"architecture": {"output_modalities": ["text"]},
"supported_parameters": [],
"context_length": 200_000,
"pricing": {"prompt": "0.01", "completion": "0.01"},
},
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
model_names = [c["model_name"] for c in cfgs]
assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names

View file

@ -0,0 +1,108 @@
"""Tests for deprecated-key warnings and back-compat in
``load_openrouter_integration_settings``.
"""
from __future__ import annotations
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
def _write_yaml(tmp_path: Path, body: str) -> Path:
cfg_dir = tmp_path / "app" / "config"
cfg_dir.mkdir(parents=True)
cfg_path = cfg_dir / "global_llm_config.yaml"
cfg_path.write_text(body, encoding="utf-8")
return cfg_path
def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
from app import config as config_module
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: true
api_key: "sk-or-test"
billing_tier: "premium"
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
settings = load_openrouter_integration_settings()
captured = capsys.readouterr().out
assert settings is not None
assert "billing_tier is deprecated" in captured
def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: true
api_key: "sk-or-test"
anonymous_enabled: true
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
settings = load_openrouter_integration_settings()
captured = capsys.readouterr().out
assert settings is not None
assert settings["anonymous_enabled_paid"] is True
assert settings["anonymous_enabled_free"] is True
assert "anonymous_enabled is" in captured
assert "deprecated" in captured
def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
"""If both legacy and new keys are present, new keys win (setdefault)."""
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: true
api_key: "sk-or-test"
anonymous_enabled: true
anonymous_enabled_paid: false
anonymous_enabled_free: false
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
settings = load_openrouter_integration_settings()
capsys.readouterr()
assert settings is not None
assert settings["anonymous_enabled_paid"] is False
assert settings["anonymous_enabled_free"] is False
def test_disabled_integration_returns_none(monkeypatch, tmp_path):
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: false
api_key: "sk-or-test"
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
assert load_openrouter_integration_settings() is None

View file

@ -0,0 +1,331 @@
"""Unit tests for the OpenRouter ``_enrich_health`` background task."""
from __future__ import annotations
from typing import Any
import pytest
from app.services.openrouter_integration_service import (
OpenRouterIntegrationService,
)
from app.services.quality_score import (
_HEALTH_FAIL_RATIO_FALLBACK,
)
pytestmark = pytest.mark.unit
def _or_cfg(
*,
cid: int,
model_name: str,
tier: str = "premium",
static_score: int = 50,
) -> dict:
return {
"id": cid,
"provider": "OPENROUTER",
"model_name": model_name,
"billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C",
"quality_score_static": static_score,
"quality_score_health": None,
"quality_score": static_score,
"health_gated": False,
}
class _StubResponse:
def __init__(self, *, payload: dict, status_code: int = 200):
self._payload = payload
self.status_code = status_code
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise RuntimeError(f"HTTP {self.status_code}")
def json(self) -> dict:
return self._payload
class _StubAsyncClient:
"""Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""
def __init__(self, responder):
self._responder = responder
self.requests: list[str] = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
self.requests.append(url)
return self._responder(url)
def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
"""Replace ``httpx.AsyncClient`` for the duration of the test."""
client = _StubAsyncClient(responder)
monkeypatch.setattr(
"app.services.openrouter_integration_service.httpx.AsyncClient",
lambda *_args, **_kwargs: client,
)
return client
def _healthy_payload() -> dict:
return {
"data": {
"endpoints": [
{
"status": 0,
"uptime_last_30m": 0.99,
"uptime_last_1d": 0.995,
"uptime_last_5m": 0.99,
}
]
}
}
def _unhealthy_payload() -> dict:
return {
"data": {
"endpoints": [
{
"status": 0,
"uptime_last_30m": 0.55,
"uptime_last_1d": 0.62,
"uptime_last_5m": 0.50,
}
]
}
}
# ---------------------------------------------------------------------------
# Bounded fan-out + happy path
# ---------------------------------------------------------------------------
async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
cfgs = [
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
_or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
]
def responder(url: str) -> _StubResponse:
if "anthropic" in url:
return _StubResponse(payload=_healthy_payload())
return _StubResponse(payload=_unhealthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {"api_key": ""}
await service._enrich_health(cfgs)
healthy = next(c for c in cfgs if c["id"] == -1)
gated = next(c for c in cfgs if c["id"] == -2)
assert healthy["health_gated"] is False
assert healthy["quality_score_health"] is not None
assert healthy["quality_score"] >= healthy["quality_score_static"]
assert gated["health_gated"] is True
assert gated["quality_score"] == gated["quality_score_static"]
async def test_enrich_health_only_touches_or_provider(monkeypatch):
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
yaml_cfg = {
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score_static": 80,
"quality_score": 80,
"health_gated": False,
}
or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")
requests: list[str] = []
def responder(url: str) -> _StubResponse:
requests.append(url)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health([yaml_cfg, or_cfg])
assert all("anthropic/claude-haiku" in r for r in requests)
# YAML cfg is untouched.
assert yaml_cfg["quality_score"] == 80
assert yaml_cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Failure ratio fallback
# ---------------------------------------------------------------------------
async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
monkeypatch,
):
"""If >= 25% of fetches fail, keep last-good cache instead of writing
partial data."""
cfgs = [
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
_or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80),
_or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65),
_or_cfg(cid=-4, model_name="venice/something", static_score=50),
]
service = OpenRouterIntegrationService()
service._settings = {}
# Pre-seed last-good cache with a known-healthy snapshot.
service._health_cache = {
"anthropic/claude-haiku": {"gated": False, "score": 95.0},
}
def all_fail(_url: str) -> _StubResponse:
return _StubResponse(payload={}, status_code=500)
_patch_async_client(monkeypatch, all_fail)
await service._enrich_health(cfgs)
# Above threshold ⇒ degraded; last-good cache wins for the cached cfg.
cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku")
assert cached_hit["quality_score_health"] == 95.0
assert cached_hit["health_gated"] is False
# Confirm the threshold constant we're testing against is real.
assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0
async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
monkeypatch,
):
"""If a fetch fails and there's no last-good cache, the cfg keeps its
static-only ``quality_score`` and is *not* gated by default."""
cfgs = [
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
]
def fail(_url: str) -> _StubResponse:
return _StubResponse(payload={}, status_code=500)
_patch_async_client(monkeypatch, fail)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health(cfgs)
cfg = cfgs[0]
assert cfg["health_gated"] is False
assert cfg["quality_score"] == cfg["quality_score_static"]
assert cfg["quality_score_health"] is None
# ---------------------------------------------------------------------------
# Last-good cache: success populates, next failure reuses
# ---------------------------------------------------------------------------
async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
monkeypatch,
):
cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
service = OpenRouterIntegrationService()
service._settings = {}
def healthy(_url: str) -> _StubResponse:
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, healthy)
await service._enrich_health([cfg])
assert "anthropic/claude-haiku" in service._health_cache
cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
assert cached_score is not None
# Next cycle: enough other healthy cfgs so failure ratio stays below
# the 25% threshold even when this one fails individually.
other_cfgs = [
_or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
for i in range(10)
]
cfg["quality_score_health"] = None
cfg["quality_score"] = cfg["quality_score_static"]
def mixed(url: str) -> _StubResponse:
if "anthropic" in url:
return _StubResponse(payload={}, status_code=500)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, mixed)
await service._enrich_health([cfg, *other_cfgs])
assert cfg["quality_score_health"] == cached_score
assert cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Bounded fan-out: respects top-N caps
# ---------------------------------------------------------------------------
async def test_enrich_health_bounds_premium_fanout(monkeypatch):
"""Top-N premium cap is honoured even when many cfgs are present."""
from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM
cfgs = [
_or_cfg(
cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i
)
for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20)
]
seen: list[str] = []
def responder(url: str) -> _StubResponse:
seen.append(url)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health(cfgs)
assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM
async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
yaml_cfg: dict[str, Any] = {
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"billing_tier": "premium",
}
requests: list[str] = []
def responder(url: str) -> _StubResponse:
requests.append(url)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health([yaml_cfg])
assert requests == []

View file

@ -0,0 +1,345 @@
"""Unit tests for the Auto (Fastest) quality scoring module."""
from __future__ import annotations
import time
import pytest
from app.services.quality_score import (
_HEALTH_GATE_UPTIME_PCT,
_OPERATOR_TRUST_BONUS,
aggregate_health,
capabilities_signal,
context_signal,
created_recency_signal,
pricing_band,
slug_penalty,
static_score_or,
static_score_yaml,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# created_recency_signal
# ---------------------------------------------------------------------------
def test_created_recency_signal_recent_model_scores_high():
now = 1_750_000_000 # ~mid-2025
one_month_ago = now - (30 * 86_400)
assert created_recency_signal(one_month_ago, now) == 20
def test_created_recency_signal_old_model_scores_zero():
now = 1_750_000_000
five_years_ago = now - (5 * 365 * 86_400)
assert created_recency_signal(five_years_ago, now) == 0
def test_created_recency_signal_missing_timestamp_is_neutral():
now = 1_750_000_000
assert created_recency_signal(None, now) == 0
assert created_recency_signal(0, now) == 0
def test_created_recency_signal_monotonic_decay():
now = 1_750_000_000
scores = [
created_recency_signal(now - days * 86_400, now)
for days in (30, 120, 300, 500, 700, 1000, 1500)
]
assert scores == sorted(scores, reverse=True)
# ---------------------------------------------------------------------------
# pricing_band
# ---------------------------------------------------------------------------
def test_pricing_band_free_returns_zero():
assert pricing_band("0", "0") == 0
assert pricing_band(0.0, 0.0) == 0
assert pricing_band(None, None) == 0
def test_pricing_band_handles_unparseable():
assert pricing_band("not-a-number", "0") == 0
assert pricing_band({}, []) == 0 # type: ignore[arg-type]
def test_pricing_band_premium_tiers_increase_with_price():
cheap = pricing_band("0.0000003", "0.0000005")
mid = pricing_band("0.000003", "0.000015")
flagship = pricing_band("0.00001", "0.00005")
assert 0 < cheap < mid < flagship
# ---------------------------------------------------------------------------
# context_signal
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"ctx,expected",
[
(1_500_000, 10),
(1_000_000, 10),
(500_000, 8),
(200_000, 6),
(128_000, 4),
(100_000, 2),
(50_000, 0),
(0, 0),
(None, 0),
],
)
def test_context_signal_bands(ctx, expected):
assert context_signal(ctx) == expected
# ---------------------------------------------------------------------------
# capabilities_signal
# ---------------------------------------------------------------------------
def test_capabilities_signal_caps_at_five():
assert (
capabilities_signal(
["tools", "structured_outputs", "reasoning", "include_reasoning"]
)
<= 5
)
def test_capabilities_signal_tools_only():
assert capabilities_signal(["tools"]) == 2
def test_capabilities_signal_empty():
assert capabilities_signal(None) == 0
assert capabilities_signal([]) == 0
# ---------------------------------------------------------------------------
# slug_penalty
# ---------------------------------------------------------------------------
def test_slug_penalty_demotes_tiny_models():
assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0
assert slug_penalty("liquid/lfm-7b") < 0
assert slug_penalty("google/gemma-3n-e4b-it") < 0
def test_slug_penalty_skips_capable_mini_nano_lite_models():
"""Critical Option C+ regression: don't penalise modern frontier
models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.)."""
assert slug_penalty("openai/gpt-5-mini") == 0
assert slug_penalty("openai/gpt-5-nano") == 0
assert slug_penalty("google/gemini-2.5-flash-lite") == 0
assert slug_penalty("anthropic/claude-haiku-4.5") == 0
def test_slug_penalty_demotes_legacy_variants():
assert slug_penalty("openai/o1-preview") < 0
assert slug_penalty("foo/bar-base") < 0
assert slug_penalty("foo/bar-distill") < 0
def test_slug_penalty_empty_input():
assert slug_penalty("") == 0
# ---------------------------------------------------------------------------
# static_score_or
# ---------------------------------------------------------------------------
def _or_model(
*,
model_id: str,
created: int | None = None,
prompt: str = "0.000003",
completion: str = "0.000015",
context: int = 200_000,
params: list[str] | None = None,
) -> dict:
return {
"id": model_id,
"created": created,
"pricing": {"prompt": prompt, "completion": completion},
"context_length": context,
"supported_parameters": params if params is not None else ["tools"],
}
def test_static_score_or_frontier_premium_beats_free_tiny():
now = 1_750_000_000
frontier = _or_model(
model_id="openai/gpt-5",
created=now - (60 * 86_400),
prompt="0.000005",
completion="0.000020",
context=400_000,
params=["tools", "structured_outputs", "reasoning"],
)
tiny_free = _or_model(
model_id="meta-llama/llama-3.2-1b-instruct:free",
created=now - (5 * 365 * 86_400),
prompt="0",
completion="0",
context=128_000,
params=["tools"],
)
assert static_score_or(frontier, now_ts=now) > static_score_or(
tiny_free, now_ts=now
)
def test_static_score_or_score_is_clamped_0_to_100():
now = int(time.time())
score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now)
assert 0 <= score <= 100
def test_static_score_or_unknown_provider_is_neutral_not_zero():
now = int(time.time())
score = static_score_or(
_or_model(model_id="some-new-lab/some-model"),
now_ts=now,
)
assert score > 0
def test_static_score_or_recent_release_beats_year_old_same_provider():
now = 1_750_000_000
fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))
assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now)
# ---------------------------------------------------------------------------
# static_score_yaml
# ---------------------------------------------------------------------------
def test_static_score_yaml_includes_operator_bonus():
cfg = {
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"litellm_params": {"base_model": "azure/gpt-5"},
}
score = static_score_yaml(cfg)
assert score >= _OPERATOR_TRUST_BONUS
def test_static_score_yaml_unknown_provider_still_carries_bonus():
cfg = {
"provider": "SOME_NEW_PROVIDER",
"model_name": "weird-model",
}
score = static_score_yaml(cfg)
assert score >= _OPERATOR_TRUST_BONUS
def test_static_score_yaml_clamped_0_to_100():
cfg = {
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"litellm_params": {"base_model": "azure/gpt-5"},
}
assert 0 <= static_score_yaml(cfg) <= 100
# ---------------------------------------------------------------------------
# aggregate_health
# ---------------------------------------------------------------------------
def test_aggregate_health_gates_when_uptime_below_threshold():
"""Live data showed Venice-routed cfgs at 53-68%; this guards that the
90% gate excludes them."""
venice_endpoints = [
{
"status": 0,
"uptime_last_30m": 0.55,
"uptime_last_1d": 0.60,
"uptime_last_5m": 0.50,
},
{
"status": 0,
"uptime_last_30m": 0.65,
"uptime_last_1d": 0.68,
"uptime_last_5m": 0.62,
},
]
gated, score = aggregate_health(venice_endpoints)
assert gated is True
assert score is None
def test_aggregate_health_passes_for_healthy_provider():
healthy = [
{
"status": 0,
"uptime_last_30m": 0.99,
"uptime_last_1d": 0.995,
"uptime_last_5m": 0.99,
},
]
gated, score = aggregate_health(healthy)
assert gated is False
assert score is not None
assert score >= _HEALTH_GATE_UPTIME_PCT
def test_aggregate_health_picks_best_endpoint_across_multiple():
"""Multi-endpoint aggregation should reward the best non-null uptime."""
mixed = [
{"status": 0, "uptime_last_30m": 0.55},
{"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate
]
gated, score = aggregate_health(mixed)
assert gated is False
assert score is not None
def test_aggregate_health_empty_endpoints_gated():
gated, score = aggregate_health([])
assert gated is True
assert score is None
def test_aggregate_health_no_status_zero_gated():
"""Even with high uptime, no OK status means the cfg is broken upstream."""
endpoints = [
{"status": 1, "uptime_last_30m": 0.99},
{"status": 2, "uptime_last_30m": 0.98},
]
gated, score = aggregate_health(endpoints)
assert gated is True
assert score is None
def test_aggregate_health_all_uptime_null_gated():
endpoints = [
{"status": 0, "uptime_last_30m": None, "uptime_last_1d": None},
]
gated, score = aggregate_health(endpoints)
assert gated is True
assert score is None
def test_aggregate_health_pct_normalisation():
"""OpenRouter returns 0-1 fractions; some endpoints surface 0-100%
percentages. Both should reach the same gate decision."""
fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
pct_form = [{"status": 0, "uptime_last_30m": 95.0}]
g1, s1 = aggregate_health(fraction_form)
g2, s2 = aggregate_health(pct_form)
assert g1 == g2 == False # noqa: E712
assert s1 is not None and s2 is not None
assert abs(s1 - s2) < 0.5

View file

@ -14,6 +14,7 @@ from app.tasks.chat.stream_new_chat import (
_classify_stream_exception,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
_extract_resolved_file_path,
_log_chat_stream_error,
_tool_output_has_error,
)
@ -28,6 +29,39 @@ def test_tool_output_error_detection():
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
def test_extract_resolved_file_path_prefers_structured_path():
assert (
_extract_resolved_file_path(
tool_name="write_file",
tool_output={"status": "completed", "path": "/docs/note.md"},
tool_input=None,
)
== "/docs/note.md"
)
def test_extract_resolved_file_path_falls_back_to_tool_input():
assert (
_extract_resolved_file_path(
tool_name="edit_file",
tool_output={"status": "completed", "result": "updated"},
tool_input={"file_path": "/docs/edited.md"},
)
== "/docs/edited.md"
)
def test_extract_resolved_file_path_does_not_parse_result_text():
assert (
_extract_resolved_file_path(
tool_name="write_file",
tool_output={"result": "Updated file /docs/from-text.md"},
tool_input=None,
)
is None
)
def test_file_write_contract_outcome_reasons():
result = StreamResult(intent_detected="file_write")
passed, reason = _evaluate_file_contract_outcome(result)
@ -159,6 +193,84 @@ def test_stream_exception_classifies_rate_limited():
assert extra is None
def test_stream_exception_classifies_openrouter_429_payload():
exc = Exception(
'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
'"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
)
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "rate_limited"
assert code == "RATE_LIMITED"
assert severity == "warn"
assert is_expected is True
assert "temporarily rate-limited" in user_message
assert extra is None
@pytest.mark.asyncio
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
"""``_preflight_llm`` is best-effort.
- On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
caller can drive the cooldown/repin branch.
- On any other transient failure it MUST swallow the error so the normal
stream path continues without surfacing preflight noise to the user.
"""
from types import SimpleNamespace
from app.tasks.chat.stream_new_chat import _preflight_llm
class _RateLimitedError(Exception):
"""Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
rate_calls: list[dict] = []
other_calls: list[dict] = []
async def _fake_acompletion_429(**kwargs):
rate_calls.append(kwargs)
raise _RateLimitedError("simulated 429")
async def _fake_acompletion_other(**kwargs):
other_calls.append(kwargs)
raise RuntimeError("some unrelated transient failure")
fake_llm = SimpleNamespace(
model="openrouter/google/gemma-4-31b-it:free",
api_key="test",
api_base=None,
)
import litellm # type: ignore[import-not-found]
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
with pytest.raises(_RateLimitedError):
await _preflight_llm(fake_llm)
assert len(rate_calls) == 1
assert rate_calls[0]["max_tokens"] == 1
assert rate_calls[0]["stream"] is False
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
# MUST NOT raise: non-rate-limit failures are swallowed.
await _preflight_llm(fake_llm)
assert len(other_calls) == 1
@pytest.mark.asyncio
async def test_preflight_skipped_for_auto_router_model():
"""Router-mode ``model='auto'`` has no single deployment to ping; the
LiteLLM router itself owns per-deployment rate-limit accounting, so the
preflight helper must short-circuit instead of issuing a probe."""
from types import SimpleNamespace
from app.tasks.chat.stream_new_chat import _preflight_llm
fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
# Should return without raising or making any LiteLLM call.
await _preflight_llm(fake_llm)
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(

View file

@ -1,11 +1,8 @@
"use client";
import { useQueryClient } from "@tanstack/react-query";
import { CheckCircle2 } from "lucide-react";
import Link from "next/link";
import { useParams } from "next/navigation";
import { useEffect } from "react";
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
import { Button } from "@/components/ui/button";
import {
Card,
@ -18,14 +15,8 @@ import {
export default function PurchaseSuccessPage() {
const params = useParams();
const queryClient = useQueryClient();
const searchSpaceId = String(params.search_space_id ?? "");
useEffect(() => {
void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY });
void queryClient.invalidateQueries({ queryKey: ["token-status"] });
}, [queryClient]);
return (
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
<Card className="w-full max-w-lg">

View file

@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe();
export const currentUserAtom = atomWithQuery(() => {
return {
queryKey: USER_QUERY_KEY,
staleTime: 5 * 60 * 1000,
// Live-changing numeric fields (pages_*, premium_tokens_*) are now
// pushed via Zero (queries.user.me()), so /users/me only needs to
// fire once per session for the static profile fields.
staleTime: Infinity,
enabled: !!getBearerToken(),
queryFn: userQueryFn,
};

View file

@ -19,6 +19,7 @@ import remarkMath from "remark-math";
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image";
import "katex/dist/katex.min.css";
import { toast } from "sonner";
import { processChildrenWithCitations } from "@/components/citations/citation-renderer";
import { Skeleton } from "@/components/ui/skeleton";
import {
@ -30,6 +31,7 @@ import {
TableRow,
} from "@/components/ui/table";
import { useElectronAPI } from "@/hooks/use-platform";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser";
import { cn } from "@/lib/utils";
@ -194,6 +196,85 @@ function isVirtualFilePathToken(value: string): boolean {
return segments.length >= 2;
}
function isStandaloneDocumentsPathText(node: ReactNode): string | null {
if (typeof node !== "string") return null;
const value = node.trim();
if (!value.startsWith("/documents/")) return null;
if (value.includes(" ")) return null;
const normalized = value.replace(/\/+$/, "");
const leaf = normalized.split("/").filter(Boolean).at(-1) ?? "";
if (!leaf || !leaf.includes(".")) return null;
return value;
}
function FilePathLink({ path, className }: { path: string; className?: string }) {
const openEditorPanel = useSetAtom(openEditorPanelAtom);
const params = useParams();
const electronAPI = useElectronAPI();
const searchSpaceIdParam = params?.search_space_id;
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
? Number(searchSpaceIdParam[0])
: Number(searchSpaceIdParam);
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
? parsedSearchSpaceId
: undefined;
return (
<button
type="button"
className={cn(
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80",
className
)}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
void (async () => {
if (electronAPI) {
let resolvedLocalPath = path;
if (electronAPI.getAgentFilesystemMounts) {
try {
const mounts = (await electronAPI.getAgentFilesystemMounts(
resolvedSearchSpaceId
)) as AgentFilesystemMount[];
resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts);
} catch {
// Fall back to the raw path if mount lookup fails.
}
}
openEditorPanel({
kind: "local_file",
localFilePath: resolvedLocalPath,
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
searchSpaceId: resolvedSearchSpaceId,
});
return;
}
if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return;
try {
const doc = await documentsApiService.getDocumentByVirtualPath({
search_space_id: resolvedSearchSpaceId,
virtual_path: path,
});
openEditorPanel({
kind: "document",
documentId: doc.id,
searchSpaceId: resolvedSearchSpaceId,
title: doc.title,
});
} catch {
toast.error("Document not found in knowledge base.");
}
})();
}}
title="Open in editor panel"
>
{path}
</button>
);
}
function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
if (!src) return null;
@ -311,9 +392,14 @@ const defaultComponents = memoizeMarkdownComponents({
},
p: function P({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
const standalonePath = isStandaloneDocumentsPathText(children);
return (
<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
{processChildrenWithCitations(children, urlMap)}
{standalonePath ? (
<FilePathLink path={standalonePath} />
) : (
processChildrenWithCitations(children, urlMap)
)}
</p>
);
},
@ -400,8 +486,6 @@ const defaultComponents = memoizeMarkdownComponents({
code: function Code({ className, children, ...props }) {
const isCodeBlock = useIsMarkdownCodeBlock();
const { resolvedTheme } = useTheme();
const openEditorPanel = useSetAtom(openEditorPanelAtom);
const params = useParams();
const electronAPI = useElectronAPI();
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
const codeString = String(children).replace(/\n$/, "");
@ -418,53 +502,17 @@ const defaultComponents = memoizeMarkdownComponents({
const isLikelyFolder =
inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes(".");
const isLocalPath =
!!electronAPI &&
isVirtualFilePathToken(inlineValue) &&
!inlineValue.startsWith("//") &&
!isLikelyFolder;
const displayLocalPath = inlineValue.replace(/^\/+/, "");
const searchSpaceIdParam = params?.search_space_id;
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
? Number(searchSpaceIdParam[0])
: Number(searchSpaceIdParam);
(isVirtualFilePathToken(inlineValue) &&
!inlineValue.startsWith("//") &&
!isLikelyFolder &&
!!electronAPI) ||
(isVirtualFilePathToken(inlineValue) &&
!inlineValue.startsWith("//") &&
!isLikelyFolder &&
!electronAPI &&
inlineValue.startsWith("/documents/"));
if (isLocalPath) {
return (
<button
type="button"
className={cn(
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80"
)}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
void (async () => {
let resolvedLocalPath = inlineValue;
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
? parsedSearchSpaceId
: undefined;
if (electronAPI?.getAgentFilesystemMounts) {
try {
const mounts = (await electronAPI.getAgentFilesystemMounts(
resolvedSearchSpaceId
)) as AgentFilesystemMount[];
resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts);
} catch {
// Fall back to the raw inline path if mount lookup fails.
}
}
openEditorPanel({
kind: "local_file",
localFilePath: resolvedLocalPath,
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
searchSpaceId: resolvedSearchSpaceId,
});
})();
}}
title="Open in editor panel"
>
{displayLocalPath}
</button>
);
return <FilePathLink path={inlineValue} className="text-[0.9em]" />;
}
return (
<code

View file

@ -681,14 +681,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
}
}, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]);
// Page usage
const pageUsage = user
? {
pagesUsed: user.pages_used,
pagesLimit: user.pages_limit,
}
: undefined;
// Detect if we're on the chat page (needs overflow-hidden for chat's own scroll)
const isChatPage = pathname?.includes("/new-chat") ?? false;
@ -723,7 +715,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
onManageMembers={handleManageMembers}
onUserSettings={handleUserSettings}
onLogout={handleLogout}
pageUsage={pageUsage}
theme={theme}
setTheme={setTheme}
isChatPage={isChatPage}

View file

@ -132,7 +132,7 @@ function MainContentPanel({
const isDocumentTab = activeTab?.type === "document";
return (
<div className="relative flex flex-1 flex-col min-w-0">
<div className="relative isolate flex flex-1 flex-col min-w-0">
<TabBar
onTabSwitch={onTabSwitch}
onNewChat={onNewChat}

View file

@ -0,0 +1,15 @@
"use client";
import { useQuery } from "@rocicorp/zero/react";
import { useIsAnonymous } from "@/contexts/anonymous-mode";
import { queries } from "@/zero/queries";
import { PageUsageDisplay } from "./PageUsageDisplay";
export function AuthenticatedPageUsageDisplay() {
const isAnonymous = useIsAnonymous();
const [me] = useQuery(queries.user.me({}));
if (isAnonymous || !me) return null;
return <PageUsageDisplay pagesUsed={me.pagesUsed} pagesLimit={me.pagesLimit} />;
}

View file

@ -1,23 +1,18 @@
"use client";
import { useQuery } from "@tanstack/react-query";
import { useQuery } from "@rocicorp/zero/react";
import { Progress } from "@/components/ui/progress";
import { useIsAnonymous } from "@/contexts/anonymous-mode";
import { stripeApiService } from "@/lib/apis/stripe-api.service";
import { queries } from "@/zero/queries";
export function PremiumTokenUsageDisplay() {
const isAnonymous = useIsAnonymous();
const { data: tokenStatus } = useQuery({
queryKey: ["token-status"],
queryFn: () => stripeApiService.getTokenStatus(),
staleTime: 60_000,
enabled: !isAnonymous,
});
const [me] = useQuery(queries.user.me({}));
if (!tokenStatus) return null;
if (isAnonymous || !me) return null;
const usagePercentage = Math.min(
(tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100,
(me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100,
100
);
@ -31,8 +26,7 @@ export function PremiumTokenUsageDisplay() {
<div className="space-y-1.5">
<div className="flex justify-between items-center text-xs">
<span className="text-muted-foreground">
{formatTokens(tokenStatus.premium_tokens_used)} /{" "}
{formatTokens(tokenStatus.premium_tokens_limit)} tokens
{formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens
</span>
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
</div>

View file

@ -12,9 +12,9 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode";
import { cn } from "@/lib/utils";
import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize";
import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types";
import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay";
import { ChatListItem } from "./ChatListItem";
import { NavSection } from "./NavSection";
import { PageUsageDisplay } from "./PageUsageDisplay";
import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay";
import { SidebarButton } from "./SidebarButton";
import { SidebarCollapseButton } from "./SidebarCollapseButton";
@ -338,9 +338,7 @@ function SidebarUsageFooter({
return (
<div className="px-3 py-3 border-t space-y-3">
<PremiumTokenUsageDisplay />
{pageUsage && (
<PageUsageDisplay pagesUsed={pageUsage.pagesUsed} pagesLimit={pageUsage.pagesLimit} />
)}
<AuthenticatedPageUsageDisplay />
<div className="space-y-0.5">
<Link
href={`/dashboard/${searchSpaceId}/more-pages`}

View file

@ -1,5 +1,6 @@
"use client";
import { useQuery as useZeroQuery } from "@rocicorp/zero/react";
import { useMutation, useQuery } from "@tanstack/react-query";
import { Minus, Plus } from "lucide-react";
import { useParams } from "next/navigation";
@ -11,6 +12,7 @@ import { Spinner } from "@/components/ui/spinner";
import { stripeApiService } from "@/lib/apis/stripe-api.service";
import { AppError } from "@/lib/error";
import { cn } from "@/lib/utils";
import { queries } from "@/zero/queries";
const TOKEN_PACK_SIZE = 1_000_000;
const PRICE_PER_PACK_USD = 1;
@ -21,11 +23,15 @@ export function BuyTokensContent() {
const searchSpaceId = Number(params?.search_space_id);
const [quantity, setQuantity] = useState(1);
// Server config flag: stays on REST, not per-user.
const { data: tokenStatus } = useQuery({
queryKey: ["token-status"],
queryFn: () => stripeApiService.getTokenStatus(),
});
// Live per-user usage via Zero.
const [me] = useZeroQuery(queries.user.me({}));
const purchaseMutation = useMutation({
mutationFn: stripeApiService.createTokenCheckoutSession,
onSuccess: (response) => {
@ -54,12 +60,11 @@ export function BuyTokensContent() {
);
}
const usagePercentage = tokenStatus
? Math.min(
(tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100,
100
)
: 0;
const used = me?.premiumTokensUsed ?? 0;
const limit = me?.premiumTokensLimit ?? 0;
// Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)).
const remaining = Math.max(0, limit - used);
const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
return (
<div className="w-full space-y-5">
@ -68,18 +73,17 @@ export function BuyTokensContent() {
<p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p>
</div>
{tokenStatus && (
{me && (
<div className="rounded-lg border bg-muted/20 p-3 space-y-1.5">
<div className="flex justify-between items-center text-xs">
<span className="text-muted-foreground">
{tokenStatus.premium_tokens_used.toLocaleString()} /{" "}
{tokenStatus.premium_tokens_limit.toLocaleString()} premium tokens
{used.toLocaleString()} / {limit.toLocaleString()} premium tokens
</span>
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
</div>
<Progress value={usagePercentage} className="h-1.5" />
<p className="text-[11px] text-muted-foreground">
{tokenStatus.premium_tokens_remaining.toLocaleString()} tokens remaining
{remaining.toLocaleString()} tokens remaining
</p>
</div>
)}

View file

@ -5,6 +5,7 @@ import {
type DeleteDocumentRequest,
deleteDocumentRequest,
deleteDocumentResponse,
documentTitleRead,
type GetDocumentByChunkRequest,
type GetDocumentChunksRequest,
type GetDocumentRequest,
@ -269,6 +270,17 @@ class DocumentsApiService {
);
};
getDocumentByVirtualPath = async (request: { search_space_id: number; virtual_path: string }) => {
const params = new URLSearchParams({
search_space_id: String(request.search_space_id),
virtual_path: request.virtual_path,
});
return baseApiService.get(
`/api/v1/documents/by-virtual-path?${params.toString()}`,
documentTitleRead
);
};
/**
* Get document type counts
*/

View file

@ -3,6 +3,7 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat";
import { connectorQueries, documentQueries } from "./documents";
import { folderQueries } from "./folders";
import { notificationQueries } from "./inbox";
import { userQueries } from "./user";
export const queries = defineQueries({
notifications: notificationQueries,
@ -12,4 +13,5 @@ export const queries = defineQueries({
messages: messageQueries,
comments: commentQueries,
chatSession: chatSessionQueries,
user: userQueries,
});

View file

@ -0,0 +1,11 @@
import { defineQuery } from "@rocicorp/zero";
import { z } from "zod";
import { zql } from "../schema/index";
export const userQueries = {
me: defineQuery(z.object({}), ({ ctx }) => {
const userId = ctx?.userId;
if (!userId) return zql.user.where("id", "__none__").one();
return zql.user.where("id", userId).one();
}),
};

View file

@ -3,6 +3,7 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./
import { documentTable, searchSourceConnectorTable } from "./documents";
import { folderTable } from "./folders";
import { notificationTable } from "./inbox";
import { userTable } from "./user";
const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({
message: one({
@ -34,6 +35,7 @@ export const schema = createSchema({
newChatMessageTable,
chatCommentTable,
chatSessionStateTable,
userTable,
],
relationships: [chatCommentRelationships, newChatMessageRelationships],
});

View file

@ -0,0 +1,11 @@
import { number, string, table } from "@rocicorp/zero";
export const userTable = table("user")
.columns({
id: string(),
pagesLimit: number().from("pages_limit"),
pagesUsed: number().from("pages_used"),
premiumTokensLimit: number().from("premium_tokens_limit"),
premiumTokensUsed: number().from("premium_tokens_used"),
})
.primaryKey("id");