mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-16 21:05:20 +02:00
Merge pull request #1332 from AnishSarkar22/feat/model-pinnning-mode
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
feat: Auto-pin quality scoring, OpenRouter tier refactor and live usage sidebar
This commit is contained in:
commit
451a98936e
35 changed files with 3975 additions and 319 deletions
|
|
@ -4,10 +4,12 @@ Revision ID: 138
|
|||
Revises: 137
|
||||
Create Date: 2026-04-30
|
||||
|
||||
Add thread-level fields to persist Auto (Fastest) model pinning metadata:
|
||||
- pinned_llm_config_id: concrete resolved config id used for this thread
|
||||
- pinned_auto_mode: auto policy identifier (currently "auto_fastest")
|
||||
- pinned_at: timestamp when the pin was created/refreshed
|
||||
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||
|
||||
The column is unindexed: all reads are by new_chat_threads.id (primary key),
|
||||
so a secondary index would be dead write amplification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -27,29 +29,14 @@ def upgrade() -> None:
|
|||
"ALTER TABLE new_chat_threads "
|
||||
"ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads "
|
||||
"ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE new_chat_threads "
|
||||
"ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id "
|
||||
"ON new_chat_threads (pinned_llm_config_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode "
|
||||
"ON new_chat_threads (pinned_auto_mode)"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop any shape the thread row may be carrying. The extra columns and
|
||||
# indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
|
||||
# makes each statement a safe no-op on the lean shape.
|
||||
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
|
||||
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
|
||||
|
||||
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
|
||||
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
|
||||
op.execute(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,160 @@
|
|||
"""add user table to zero_publication with column list
|
||||
|
||||
Adds the "user" table to zero_publication with a column-list publication
|
||||
so that only the 5 fields driving the live usage meters are replicated
|
||||
through WAL -> zero-cache -> browser IndexedDB:
|
||||
|
||||
id, pages_limit, pages_used,
|
||||
premium_tokens_limit, premium_tokens_used
|
||||
|
||||
Sensitive columns (hashed_password, email, oauth_account, display_name,
|
||||
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
|
||||
included in the publication, so they never enter WAL replication.
|
||||
|
||||
Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
|
||||
(it is already DEFAULT today since "user" was never in the
|
||||
TABLES_WITH_FULL_IDENTITY list of migration 117).
|
||||
|
||||
IMPORTANT - before AND after running this migration:
|
||||
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||
2. Run: alembic upgrade head
|
||||
3. Delete / reset the zero-cache data volume
|
||||
4. Restart zero-cache (it will do a fresh initial sync)
|
||||
|
||||
Revision ID: 139
|
||||
Revises: 138
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "139"
|
||||
down_revision: str | None = "138"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
PUBLICATION_NAME = "zero_publication"
|
||||
|
||||
# Document column list as left by migration 117. Must match exactly.
|
||||
DOCUMENT_COLS = [
|
||||
"id",
|
||||
"title",
|
||||
"document_type",
|
||||
"search_space_id",
|
||||
"folder_id",
|
||||
"created_by_id",
|
||||
"status",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
# Five fields needed by the live usage meters (sidebar Tokens/Pages,
|
||||
# Buy Tokens content). Keep this list narrow on purpose: anything added
|
||||
# here flows into WAL and IndexedDB for every connected browser.
|
||||
USER_COLS = [
|
||||
"id",
|
||||
"pages_limit",
|
||||
"pages_used",
|
||||
"premium_tokens_limit",
|
||||
"premium_tokens_used",
|
||||
]
|
||||
|
||||
|
||||
def _terminate_blocked_pids(conn, table: str) -> None:
|
||||
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT pg_terminate_backend(l.pid) "
|
||||
"FROM pg_locks l "
|
||||
"JOIN pg_class c ON c.oid = l.relation "
|
||||
"WHERE c.relname = :tbl "
|
||||
" AND l.pid != pg_backend_pid()"
|
||||
),
|
||||
{"tbl": table},
|
||||
)
|
||||
|
||||
|
||||
def _has_zero_version(conn, table: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = :tbl AND column_name = '_0_version'"
|
||||
),
|
||||
{"tbl": table},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _build_publication_ddl(
|
||||
documents_has_zero_ver: bool, user_has_zero_ver: bool
|
||||
) -> str:
|
||||
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
|
||||
user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
|
||||
doc_col_list = ", ".join(doc_cols)
|
||||
user_col_list = ", ".join(user_cols)
|
||||
return (
|
||||
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||
f"notifications, "
|
||||
f"documents ({doc_col_list}), "
|
||||
f"folders, "
|
||||
f"search_source_connectors, "
|
||||
f"new_chat_messages, "
|
||||
f"chat_comments, "
|
||||
f"chat_session_state, "
|
||||
f'"user" ({user_col_list})'
|
||||
)
|
||||
|
||||
|
||||
def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
|
||||
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
|
||||
doc_col_list = ", ".join(doc_cols)
|
||||
return (
|
||||
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||
f"notifications, "
|
||||
f"documents ({doc_col_list}), "
|
||||
f"folders, "
|
||||
f"search_source_connectors, "
|
||||
f"new_chat_messages, "
|
||||
f"chat_comments, "
|
||||
f"chat_session_state"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
# asyncpg requires LOCK TABLE inside a transaction block. Alembic already
|
||||
# opened one via context.begin_transaction(), but the driver still errors
|
||||
# unless we use an explicit SAVEPOINT (nested transaction) for this block.
|
||||
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||
with tx:
|
||||
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||
|
||||
_terminate_blocked_pids(conn, "user")
|
||||
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
|
||||
|
||||
# Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
|
||||
# migration 117, so this is already DEFAULT. Re-assert anyway so
|
||||
# the column-list publication stays valid (DEFAULT identity only
|
||||
# requires the PK to be in the column list).
|
||||
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
|
||||
|
||||
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||
|
||||
documents_has_zero_ver = _has_zero_version(conn, "documents")
|
||||
user_has_zero_ver = _has_zero_version(conn, "user")
|
||||
|
||||
conn.execute(
|
||||
sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||
documents_has_zero_ver = _has_zero_version(conn, "documents")
|
||||
conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))
|
||||
|
|
@ -61,6 +61,9 @@ class _ThreadLockManager:
|
|||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
# Monotonic per-thread epoch used to prevent stale middleware
|
||||
# teardown from releasing a newer turn's lock.
|
||||
self._turn_epoch: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -107,6 +110,14 @@ class _ThreadLockManager:
|
|||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||
self._turn_epoch[thread_id] = epoch
|
||||
return epoch
|
||||
|
||||
def current_turn_epoch(self, thread_id: str) -> int:
|
||||
return self._turn_epoch.get(thread_id, 0)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
|
|
@ -114,6 +125,10 @@ class _ThreadLockManager:
|
|||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||
# retry that already acquired the lock for the same thread.
|
||||
self.bump_turn_epoch(thread_id)
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
|
|
@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call locks owned by this middleware. We track them as
|
||||
# an instance attribute so ``aafter_agent`` knows which lock
|
||||
# to release.
|
||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||
# only releases when its epoch still matches the manager's current
|
||||
# epoch for the thread, preventing stale unlock races.
|
||||
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
|
|
@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
self._held_locks[thread_id] = lock
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
|
@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
lock = self._held_locks.pop(thread_id, None)
|
||||
if lock is not None and lock.locked():
|
||||
held = self._held_locks.pop(thread_id, None)
|
||||
if held is None:
|
||||
return None
|
||||
lock, held_epoch = held
|
||||
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||
# already advanced epoch). Do not touch current lock/cancel state.
|
||||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
|
|
|
|||
|
|
@ -63,6 +63,27 @@ def load_global_llm_configs():
|
|||
else:
|
||||
seen_slugs[slug] = cfg.get("id", 0)
|
||||
|
||||
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||
# Tier A — operator-curated, locked first when premium-eligible.
|
||||
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||
# whose provider == "OPENROUTER" via _enrich_health.
|
||||
try:
|
||||
from app.services.quality_score import static_score_yaml
|
||||
|
||||
for cfg in configs:
|
||||
cfg["auto_pin_tier"] = "A"
|
||||
static_q = static_score_yaml(cfg)
|
||||
cfg["quality_score_static"] = static_q
|
||||
cfg["quality_score"] = static_q
|
||||
cfg["quality_score_health"] = None
|
||||
# YAML cfgs whose provider is OPENROUTER are also subject
|
||||
# to health gating against their own /endpoints data — a
|
||||
# hand-picked dead OR model is still dead. _enrich_health
|
||||
# re-stamps health_gated for them on the next refresh tick.
|
||||
cfg["health_gated"] = False
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to score global LLM configs: {e}")
|
||||
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global LLM configs: {e}")
|
||||
|
|
@ -194,6 +215,9 @@ def load_openrouter_integration_settings() -> dict | None:
|
|||
"""
|
||||
Load OpenRouter integration settings from the YAML config.
|
||||
|
||||
Emits startup warnings for deprecated keys (``billing_tier``,
|
||||
``anonymous_enabled``) and seeds their replacements for back-compat.
|
||||
|
||||
Returns:
|
||||
dict with settings if present and enabled, None otherwise
|
||||
"""
|
||||
|
|
@ -206,9 +230,31 @@ def load_openrouter_integration_settings() -> dict | None:
|
|||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("openrouter_integration")
|
||||
if settings and settings.get("enabled"):
|
||||
return settings
|
||||
return None
|
||||
if not settings or not settings.get("enabled"):
|
||||
return None
|
||||
|
||||
if "billing_tier" in settings:
|
||||
print(
|
||||
"Warning: openrouter_integration.billing_tier is deprecated; "
|
||||
"tier is now derived per model from OpenRouter data "
|
||||
"(':free' suffix or zero pricing). Remove this key."
|
||||
)
|
||||
|
||||
if "anonymous_enabled" in settings:
|
||||
print(
|
||||
"Warning: openrouter_integration.anonymous_enabled is "
|
||||
"deprecated; use anonymous_enabled_paid and/or "
|
||||
"anonymous_enabled_free instead. Both new flags have been "
|
||||
"seeded from the legacy value for back-compat."
|
||||
)
|
||||
settings.setdefault(
|
||||
"anonymous_enabled_paid", settings["anonymous_enabled"]
|
||||
)
|
||||
settings.setdefault(
|
||||
"anonymous_enabled_free", settings["anonymous_enabled"]
|
||||
)
|
||||
|
||||
return settings
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||
return None
|
||||
|
|
@ -217,9 +263,14 @@ def load_openrouter_integration_settings() -> dict | None:
|
|||
def initialize_openrouter_integration():
|
||||
"""
|
||||
If enabled, fetch all OpenRouter models and append them to
|
||||
config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
|
||||
Should be called BEFORE initialize_llm_router() so the router
|
||||
correctly excludes premium models from Auto mode.
|
||||
config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
|
||||
is derived per-model from OpenRouter's API signals (``:free`` suffix or
|
||||
zero pricing), so free OpenRouter models correctly skip premium quota.
|
||||
|
||||
Should be called BEFORE initialize_llm_router(). Dynamic entries are
|
||||
tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
|
||||
by title-gen / sub-agent flows) remains scoped to curated YAML configs,
|
||||
while user-facing Auto-mode thread pinning still considers them.
|
||||
"""
|
||||
settings = load_openrouter_integration_settings()
|
||||
if not settings:
|
||||
|
|
@ -235,9 +286,13 @@ def initialize_openrouter_integration():
|
|||
|
||||
if new_configs:
|
||||
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
|
||||
free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
|
||||
premium_count = sum(
|
||||
1 for c in new_configs if c.get("billing_tier") == "premium"
|
||||
)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(new_configs)} models "
|
||||
f"(billing_tier={settings.get('billing_tier', 'premium')})"
|
||||
f"(free={free_count}, premium={premium_count})"
|
||||
)
|
||||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
|
|
|||
|
|
@ -245,31 +245,53 @@ global_llm_configs:
|
|||
# =============================================================================
|
||||
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
||||
# and injects them as global configs. This gives premium users access to any model
|
||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
|
||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
|
||||
# while free-tier OpenRouter models show up with a green Free badge and do NOT
|
||||
# consume premium quota.
|
||||
# Models are fetched at startup and refreshed periodically in the background.
|
||||
# All calls go through LiteLLM with the openrouter/ prefix.
|
||||
openrouter_integration:
|
||||
enabled: false
|
||||
api_key: "sk-or-your-openrouter-api-key"
|
||||
# billing_tier: "premium" or "free". Controls whether users need premium tokens.
|
||||
billing_tier: "premium"
|
||||
# anonymous_enabled: set true to also show OpenRouter models to no-login users
|
||||
anonymous_enabled: false
|
||||
|
||||
# Tier is derived PER MODEL from OpenRouter's own API signals:
|
||||
# - id ends with ":free" -> billing_tier=free
|
||||
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
# No global billing_tier knob is honored; any legacy value emits a startup warning.
|
||||
|
||||
# Anonymous access is split by tier so operators can expose only free
|
||||
# models to no-login users without leaking paid inference.
|
||||
anonymous_enabled_paid: false
|
||||
anonymous_enabled_free: false
|
||||
|
||||
seo_enabled: false
|
||||
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
||||
quota_reserve_tokens: 4000
|
||||
# id_offset: starting negative ID for dynamically generated configs.
|
||||
# Must not overlap with your static global_llm_configs IDs above.
|
||||
# id_offset: base negative ID for dynamically generated configs.
|
||||
# Model IDs are derived deterministically via BLAKE2b so they survive
|
||||
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
|
||||
id_offset: -10000
|
||||
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
||||
refresh_interval_hours: 24
|
||||
# rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
|
||||
# OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
|
||||
# upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
|
||||
# These values only matter if you set billing_tier to "free" (adding them to Auto mode).
|
||||
# For premium-only models they are cosmetic. Set conservatively or match your account tier.
|
||||
|
||||
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
|
||||
# for per-deployment accounting when OR premium models participate in the
|
||||
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
|
||||
# real account limits live at https://openrouter.ai/settings/limits.
|
||||
rpm: 200
|
||||
tpm: 1000000
|
||||
|
||||
# Rate limits for FREE OpenRouter models. Informational only: free OR
|
||||
# models are intentionally kept OUT of the LiteLLM Router pool, because
|
||||
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
|
||||
# 50-1000 daily requests across every ":free" model combined) —
|
||||
# per-deployment router accounting can't represent a shared bucket
|
||||
# correctly. Free OR models stay fully available in the model selector
|
||||
# and for user-facing Auto thread pinning.
|
||||
free_rpm: 20
|
||||
free_tpm: 100000
|
||||
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
system_instructions: ""
|
||||
|
|
|
|||
|
|
@ -638,13 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
default=False,
|
||||
server_default="false",
|
||||
)
|
||||
# Auto model pinning metadata:
|
||||
# - pinned_llm_config_id stores the concrete resolved model config id.
|
||||
# - pinned_auto_mode indicates which auto policy produced the pin.
|
||||
# This allows Auto (Fastest) to resolve once per thread and stay stable.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True, index=True)
|
||||
pinned_auto_mode = Column(String(32), nullable=True, index=True)
|
||||
pinned_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
|
||||
# config id. NULL means no pin; Auto will resolve on the next turn.
|
||||
# Single-writer invariant: only app.services.auto_model_pin_service sets
|
||||
# or clears this column (plus bulk clears when a search space's
|
||||
# agent_llm_id changes). Unindexed: all reads are by primary key.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||
|
|
|
|||
|
|
@ -745,6 +745,51 @@ async def search_document_titles(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
|
||||
async def get_document_by_virtual_path(
|
||||
search_space_id: int,
|
||||
virtual_path: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Resolve a knowledge-base document id by exact virtual path."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(
|
||||
Document.id,
|
||||
Document.title,
|
||||
Document.document_type,
|
||||
).filter(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_metadata["virtual_path"].as_string() == virtual_path,
|
||||
)
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentTitleRead(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
document_type=row.document_type,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to resolve document by virtual path: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
|
||||
async def get_documents_status(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -803,11 +803,7 @@ async def update_llm_preferences(
|
|||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(
|
||||
pinned_llm_config_id=None,
|
||||
pinned_auto_mode=None,
|
||||
pinned_at=None,
|
||||
)
|
||||
.values(pinned_llm_config_id=None)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||
|
|
|
|||
|
|
@ -2,16 +2,23 @@
|
|||
|
||||
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||
persist the chosen config id on ``new_chat_threads`` so subsequent turns are
|
||||
stable.
|
||||
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
|
||||
subsequent turns are stable.
|
||||
|
||||
Single-writer invariant: this module is the only writer of
|
||||
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
|
||||
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
|
||||
Therefore a non-NULL value unambiguously means "this thread has an
|
||||
Auto-resolved pin"; no separate source/policy column is needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
|
@ -19,12 +26,28 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.config import config
|
||||
from app.db import NewChatThread
|
||||
from app.services.quality_score import _QUALITY_TOP_K
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_FASTEST_ID = 0
|
||||
AUTO_FASTEST_MODE = "auto_fastest"
|
||||
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||
_HEALTHY_TTL_SECONDS = 45
|
||||
|
||||
# In-memory runtime cooldown map for configs that recently hard-failed at
|
||||
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
|
||||
# the same unhealthy config from being reselected immediately during repair.
|
||||
_runtime_cooldown_until: dict[int, float] = {}
|
||||
_runtime_cooldown_lock = threading.Lock()
|
||||
|
||||
# Short-TTL "recently healthy" cache for configs that just passed a runtime
|
||||
# preflight ping. Lets back-to-back turns on the same model skip the probe
|
||||
# without eroding correctness — entries auto-expire and are wiped any time
|
||||
# the same config is cooled down or the OR catalogue is refreshed.
|
||||
_healthy_until: dict[int, float] = {}
|
||||
_healthy_lock = threading.Lock()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -43,9 +66,117 @@ def _is_usable_global_config(cfg: dict) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
|
||||
now = time.time() if now_ts is None else now_ts
|
||||
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
|
||||
for cid in stale:
|
||||
_runtime_cooldown_until.pop(cid, None)
|
||||
|
||||
|
||||
def _is_runtime_cooled_down(config_id: int) -> bool:
|
||||
with _runtime_cooldown_lock:
|
||||
_prune_runtime_cooldowns()
|
||||
return config_id in _runtime_cooldown_until
|
||||
|
||||
|
||||
def mark_runtime_cooldown(
|
||||
config_id: int,
|
||||
*,
|
||||
reason: str = "rate_limited",
|
||||
cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
|
||||
) -> None:
|
||||
"""Temporarily suppress a config from Auto selection.
|
||||
|
||||
Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
|
||||
config that is currently unhealthy does not get immediately reused on the
|
||||
same thread during repair.
|
||||
"""
|
||||
if cooldown_seconds <= 0:
|
||||
cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
|
||||
until = time.time() + int(cooldown_seconds)
|
||||
with _runtime_cooldown_lock:
|
||||
_runtime_cooldown_until[int(config_id)] = until
|
||||
_prune_runtime_cooldowns()
|
||||
# A cooled cfg can never be "recently healthy"; drop any stale credit so
|
||||
# the next turn that resolves to it (after cooldown) re-runs preflight.
|
||||
clear_healthy(int(config_id))
|
||||
logger.info(
|
||||
"auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
|
||||
config_id,
|
||||
reason,
|
||||
cooldown_seconds,
|
||||
)
|
||||
|
||||
|
||||
def clear_runtime_cooldown(config_id: int | None = None) -> None:
|
||||
"""Test/ops helper to clear runtime cooldown entries."""
|
||||
with _runtime_cooldown_lock:
|
||||
if config_id is None:
|
||||
_runtime_cooldown_until.clear()
|
||||
return
|
||||
_runtime_cooldown_until.pop(int(config_id), None)
|
||||
|
||||
|
||||
def _prune_healthy(now_ts: float | None = None) -> None:
|
||||
now = time.time() if now_ts is None else now_ts
|
||||
stale = [cid for cid, until in _healthy_until.items() if until <= now]
|
||||
for cid in stale:
|
||||
_healthy_until.pop(cid, None)
|
||||
|
||||
|
||||
def is_recently_healthy(config_id: int) -> bool:
|
||||
"""Return True if ``config_id`` passed preflight within the TTL window."""
|
||||
with _healthy_lock:
|
||||
_prune_healthy()
|
||||
return int(config_id) in _healthy_until
|
||||
|
||||
|
||||
def mark_healthy(
|
||||
config_id: int,
|
||||
*,
|
||||
ttl_seconds: int = _HEALTHY_TTL_SECONDS,
|
||||
) -> None:
|
||||
"""Record that ``config_id`` just passed a preflight probe.
|
||||
|
||||
Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The
|
||||
healthy state is intentionally process-local — it's a latency hint, not a
|
||||
correctness primitive — so multi-worker drift is acceptable.
|
||||
"""
|
||||
if ttl_seconds <= 0:
|
||||
ttl_seconds = _HEALTHY_TTL_SECONDS
|
||||
until = time.time() + int(ttl_seconds)
|
||||
with _healthy_lock:
|
||||
_healthy_until[int(config_id)] = until
|
||||
_prune_healthy()
|
||||
|
||||
|
||||
def clear_healthy(config_id: int | None = None) -> None:
|
||||
"""Drop one (or all) healthy-cache entries.
|
||||
|
||||
Called from runtime cooldown and OR catalogue refresh so a freshly cooled
|
||||
or replaced config never carries stale "healthy" credit.
|
||||
"""
|
||||
with _healthy_lock:
|
||||
if config_id is None:
|
||||
_healthy_until.clear()
|
||||
return
|
||||
_healthy_until.pop(int(config_id), None)
|
||||
|
||||
|
||||
def _global_candidates() -> list[dict]:
|
||||
"""Return Auto-eligible global cfgs.
|
||||
|
||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||
can't be picked as the thread's pin. Also excludes configs currently
|
||||
in runtime cooldown (e.g. temporary 429 bursts).
|
||||
"""
|
||||
candidates = [
|
||||
cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)
|
||||
cfg
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||
if _is_usable_global_config(cfg)
|
||||
and not cfg.get("health_gated")
|
||||
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||
]
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
|
@ -54,10 +185,26 @@ def _tier_of(cfg: dict) -> str:
|
|||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
||||
def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
|
||||
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
||||
"""Pick a config with quality-first ranking + deterministic spread.
|
||||
|
||||
Tier policy is lock-first: prefer Tier A (operator-curated YAML)
|
||||
cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
|
||||
Tier A cfg is eligible after upstream filters. Within the locked
|
||||
pool, sort by ``quality_score`` and pick from the top-K via
|
||||
``SHA256(thread_id)`` so different new threads spread across the
|
||||
best models without ever picking a low-ranked one.
|
||||
|
||||
Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
|
||||
structured logging in the caller.
|
||||
"""
|
||||
tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
|
||||
pool = tier_a if tier_a else eligible
|
||||
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
|
||||
top_k = pool[:_QUALITY_TOP_K]
|
||||
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
|
||||
idx = int.from_bytes(digest[:8], "big") % len(candidates)
|
||||
return candidates[idx]
|
||||
idx = int.from_bytes(digest[:8], "big") % len(top_k)
|
||||
return top_k[idx], len(top_k)
|
||||
|
||||
|
||||
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||
|
|
@ -89,11 +236,12 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
user_id: str | UUID | None,
|
||||
selected_llm_config_id: int,
|
||||
force_repin_free: bool = False,
|
||||
exclude_config_ids: set[int] | None = None,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||
|
||||
For non-auto selections, this function clears existing auto pin metadata and
|
||||
returns the selected id as-is.
|
||||
For non-auto selections, this function clears any existing pin and returns
|
||||
the selected id as-is.
|
||||
"""
|
||||
thread = (
|
||||
(
|
||||
|
|
@ -113,16 +261,10 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
f"Thread {thread_id} does not belong to search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Explicit model selected: clear stale auto pin metadata.
|
||||
# Explicit model selected: clear any stale pin.
|
||||
if selected_llm_config_id != AUTO_FASTEST_ID:
|
||||
if (
|
||||
thread.pinned_llm_config_id is not None
|
||||
or thread.pinned_auto_mode is not None
|
||||
or thread.pinned_at is not None
|
||||
):
|
||||
if thread.pinned_llm_config_id is not None:
|
||||
thread.pinned_llm_config_id = None
|
||||
thread.pinned_auto_mode = None
|
||||
thread.pinned_at = None
|
||||
await session.commit()
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=selected_llm_config_id,
|
||||
|
|
@ -130,17 +272,19 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
from_existing_pin=False,
|
||||
)
|
||||
|
||||
candidates = _global_candidates()
|
||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||
candidates = [
|
||||
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
|
||||
]
|
||||
if not candidates:
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse existing valid pin without re-checking current quota (no silent tier switch),
|
||||
# unless the caller explicitly requests a forced repin to free.
|
||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||
# tier switch), unless the caller explicitly requests a forced repin to free.
|
||||
pinned_id = thread.pinned_llm_config_id
|
||||
if (
|
||||
not force_repin_free
|
||||
and thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
and pinned_id is not None
|
||||
and int(pinned_id) in candidate_by_id
|
||||
):
|
||||
|
|
@ -152,6 +296,15 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
pinned_id,
|
||||
_tier_of(pinned_cfg),
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
|
||||
"auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
|
||||
thread_id,
|
||||
pinned_id,
|
||||
_tier_of(pinned_cfg),
|
||||
pinned_cfg.get("auto_pin_tier", "?"),
|
||||
int(pinned_cfg.get("quality_score") or 0),
|
||||
)
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=int(pinned_id),
|
||||
resolved_tier=_tier_of(pinned_cfg),
|
||||
|
|
@ -159,11 +312,10 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
)
|
||||
if pinned_id is not None:
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
thread.pinned_auto_mode,
|
||||
)
|
||||
|
||||
premium_eligible = (
|
||||
|
|
@ -179,13 +331,11 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||
)
|
||||
|
||||
selected_cfg = _deterministic_pick(eligible, thread_id)
|
||||
selected_cfg, top_k_size = _select_pin(eligible, thread_id)
|
||||
selected_id = int(selected_cfg["id"])
|
||||
selected_tier = _tier_of(selected_cfg)
|
||||
|
||||
thread.pinned_llm_config_id = selected_id
|
||||
thread.pinned_auto_mode = AUTO_FASTEST_MODE
|
||||
thread.pinned_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
if force_repin_free:
|
||||
|
|
@ -216,6 +366,18 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
selected_tier,
|
||||
premium_eligible,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
|
||||
"auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
|
||||
thread_id,
|
||||
selected_id,
|
||||
selected_tier,
|
||||
selected_cfg.get("auto_pin_tier", "?"),
|
||||
int(selected_cfg.get("quality_score") or 0),
|
||||
top_k_size,
|
||||
)
|
||||
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=selected_id,
|
||||
resolved_tier=selected_tier,
|
||||
|
|
|
|||
|
|
@ -208,6 +208,12 @@ class LLMRouterService:
|
|||
"""
|
||||
Initialize the router with global LLM configurations.
|
||||
|
||||
Configs with ``router_pool_eligible=False`` are skipped so that
|
||||
dynamic OpenRouter entries stay out of the shared router pool used
|
||||
by title-gen / sub-agent ``model="auto"`` flows. Those dynamic
|
||||
entries are still available for user-facing Auto-mode thread pinning
|
||||
via ``auto_model_pin_service``.
|
||||
|
||||
Args:
|
||||
global_configs: List of global LLM config dictionaries from YAML
|
||||
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
||||
|
|
@ -221,6 +227,8 @@ class LLMRouterService:
|
|||
model_list = []
|
||||
premium_models: set[str] = set()
|
||||
for config in global_configs:
|
||||
if config.get("router_pool_eligible") is False:
|
||||
continue
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
|
@ -309,10 +317,45 @@ class LLMRouterService:
|
|||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def rebuild(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
"""Reset the router and re-run ``initialize`` with fresh configs.
|
||||
|
||||
``initialize`` short-circuits once it has run to avoid re-creating the
|
||||
LiteLLM Router on every request; ``rebuild`` deliberately clears
|
||||
``_initialized`` so a caller (e.g. background OpenRouter refresh)
|
||||
can force the pool to be rebuilt after catalogue changes.
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
instance._initialized = False
|
||||
instance._router = None
|
||||
instance._model_list = []
|
||||
instance._premium_model_strings = set()
|
||||
cls.initialize(global_configs, router_settings)
|
||||
|
||||
@classmethod
|
||||
def is_premium_model(cls, model_string: str) -> bool:
|
||||
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
|
||||
premium-tier deployment in the router pool."""
|
||||
"""Return True if *model_string* belongs to a premium-tier deployment
|
||||
in the LiteLLM router pool.
|
||||
|
||||
Scope: only covers configs with ``router_pool_eligible`` truthy. That
|
||||
includes static YAML premium configs AND dynamic OpenRouter *premium*
|
||||
entries (which opt in at generation time). Dynamic OpenRouter *free*
|
||||
entries are deliberately kept out of the router pool — OpenRouter
|
||||
enforces free-tier limits globally per account, so per-deployment
|
||||
router accounting can't represent them correctly — and therefore
|
||||
return ``False`` here, which matches their ``billing_tier="free"``
|
||||
(no premium quota).
|
||||
|
||||
For per-request premium checks on an arbitrary config (static or
|
||||
dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
|
||||
that reflects the per-config ``billing_tier`` directly and is what
|
||||
user-facing Auto-mode thread pinning uses to bill correctly.
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
return model_string in instance._premium_model_strings
|
||||
|
||||
|
|
|
|||
|
|
@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.services.quality_score import (
|
||||
_HEALTH_BLEND_WEIGHT,
|
||||
_HEALTH_ENRICH_CONCURRENCY,
|
||||
_HEALTH_ENRICH_TOP_N_FREE,
|
||||
_HEALTH_ENRICH_TOP_N_PREMIUM,
|
||||
_HEALTH_FAIL_RATIO_FALLBACK,
|
||||
_HEALTH_FETCH_TIMEOUT_SEC,
|
||||
aggregate_health,
|
||||
static_score_or,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
|
||||
"https://openrouter.ai/api/v1/models/{model_id}/endpoints"
|
||||
)
|
||||
|
||||
# Sentinel value stored on each generated config so we can distinguish
|
||||
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
|
||||
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
|
||||
|
||||
# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
|
||||
# enough headroom to avoid frequent collisions for OpenRouter's catalogue
|
||||
# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
|
||||
_STABLE_ID_HASH_WIDTH = 9_000_000
|
||||
|
||||
|
||||
def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
|
||||
"""Derive a deterministic negative config ID from ``model_id``.
|
||||
|
||||
The same ``model_id`` always hashes to the same base value so thread pins
|
||||
survive catalogue churn (models appearing/disappearing/reordering between
|
||||
refreshes). On collision we decrement until we find an unused slot; this
|
||||
keeps the mapping stable for the first config that claimed a slot and
|
||||
only shifts collisions, which is much less disruptive than the legacy
|
||||
index-based scheme that reshuffled every ID when the catalogue changed.
|
||||
"""
|
||||
digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
|
||||
base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
|
||||
cid = base
|
||||
while cid in taken:
|
||||
cid -= 1
|
||||
taken.add(cid)
|
||||
return cid
|
||||
|
||||
|
||||
def _openrouter_tier(model: dict) -> str:
|
||||
"""Classify an OpenRouter model as ``"free"`` or ``"premium"``.
|
||||
|
||||
Per OpenRouter's API contract, a model is free if:
|
||||
- Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
|
||||
- Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.
|
||||
|
||||
Anything else (missing pricing, non-zero pricing) falls through to
|
||||
``"premium"`` so we never under-charge users. This derivation runs off the
|
||||
already-cached /api/v1/models payload, so it adds no network cost.
|
||||
"""
|
||||
if model.get("id", "").endswith(":free"):
|
||||
return "free"
|
||||
pricing = model.get("pricing") or {}
|
||||
prompt = str(pricing.get("prompt", "")).strip()
|
||||
completion = str(pricing.get("completion", "")).strip()
|
||||
if prompt == "0" and completion == "0":
|
||||
return "free"
|
||||
return "premium"
|
||||
|
||||
|
||||
def _is_text_output_model(model: dict) -> bool:
|
||||
"""Return True if the model produces text output only (skip image/audio generators)."""
|
||||
|
|
@ -56,6 +117,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
|
|||
# Deep-research models reject standard params (temperature, etc.)
|
||||
"openai/o3-deep-research",
|
||||
"openai/o4-mini-deep-research",
|
||||
# OpenRouter's own meta-router over free models. We already enumerate every
|
||||
# concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
|
||||
# pinning handles churn via the repair path, so exposing an additional
|
||||
# indirection layer would only duplicate the capability with an opaque slug.
|
||||
"openrouter/free",
|
||||
}
|
||||
|
||||
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||
|
|
@ -113,20 +179,41 @@ def _generate_configs(
|
|||
raw_models: list[dict],
|
||||
settings: dict[str, Any],
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Convert raw OpenRouter model entries into global LLM config dicts.
|
||||
"""Convert raw OpenRouter model entries into global LLM config dicts.
|
||||
|
||||
Models are sorted by ID for deterministic, stable ID assignment across
|
||||
restarts and refreshes.
|
||||
Tier (``billing_tier``) is derived per-model from OpenRouter's own API
|
||||
signals via ``_openrouter_tier`` — there is no longer a uniform YAML
|
||||
override. Config IDs are derived via ``_stable_config_id`` so they
|
||||
survive catalogue churn across refreshes.
|
||||
|
||||
Router-pool membership is tier-aware:
|
||||
|
||||
- Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
|
||||
so sub-agent ``model="auto"`` flows benefit from load balancing and
|
||||
failover across the curated YAML configs and the OR premium passthrough.
|
||||
- Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
|
||||
Router tracks rate limits per deployment, but OpenRouter enforces a
|
||||
single global free-tier quota (~20 RPM + 50-1000 daily requests
|
||||
account-wide across every ``:free`` model), so rotating across many
|
||||
free deployments would only burn the shared bucket faster. Free OR
|
||||
models remain fully available for user-facing Auto-mode thread pinning
|
||||
via ``auto_model_pin_service``.
|
||||
|
||||
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
|
||||
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
|
||||
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
|
||||
cover the catalogue-churn case.
|
||||
"""
|
||||
id_offset: int = settings.get("id_offset", -10000)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
billing_tier: str = settings.get("billing_tier", "premium")
|
||||
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
|
||||
seo_enabled: bool = settings.get("seo_enabled", False)
|
||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
tpm: int = settings.get("tpm", 1000000)
|
||||
tpm: int = settings.get("tpm", 1_000_000)
|
||||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||
anon_paid: bool = settings.get("anonymous_enabled_paid", False)
|
||||
anon_free: bool = settings.get("anonymous_enabled_free", False)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
system_instructions: str = settings.get("system_instructions", "")
|
||||
use_default: bool = settings.get("use_default_system_instructions", True)
|
||||
|
|
@ -142,19 +229,24 @@ def _generate_configs(
|
|||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
text_models.sort(key=lambda m: m["id"])
|
||||
|
||||
configs: list[dict] = []
|
||||
for idx, model in enumerate(text_models):
|
||||
taken: set[int] = set()
|
||||
now_ts = int(time.time())
|
||||
|
||||
for model in text_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
tier = _openrouter_tier(model)
|
||||
|
||||
static_q = static_score_or(model, now_ts=now_ts)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": id_offset - idx,
|
||||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter",
|
||||
"billing_tier": billing_tier,
|
||||
"anonymous_enabled": anonymous_enabled,
|
||||
"billing_tier": tier,
|
||||
"anonymous_enabled": anon_free if tier == "free" else anon_paid,
|
||||
"seo_enabled": seo_enabled,
|
||||
"seo_slug": None,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
|
|
@ -162,13 +254,28 @@ def _generate_configs(
|
|||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "",
|
||||
"rpm": rpm,
|
||||
"tpm": tpm,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"system_instructions": system_instructions,
|
||||
"use_default_system_instructions": use_default,
|
||||
"citations_enabled": citations_enabled,
|
||||
# Premium OR deployments join the LiteLLM router pool so sub-agent
|
||||
# model="auto" flows can load-balance / fail over across them.
|
||||
# Free OR deployments stay out: OpenRouter's free tier is a single
|
||||
# account-wide quota, so per-deployment routing can't spread load
|
||||
# there — it just drains the shared bucket faster.
|
||||
"router_pool_eligible": tier == "premium",
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
|
||||
# start so startup latency is unchanged).
|
||||
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||
"quality_score_static": static_q,
|
||||
"quality_score_health": None,
|
||||
"quality_score": static_q,
|
||||
"health_gated": False,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
||||
|
|
@ -187,6 +294,12 @@ class OpenRouterIntegrationService:
|
|||
self._configs_by_id: dict[int, dict] = {}
|
||||
self._initialized = False
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
# Last-good per-model health snapshot. Survives across refresh
|
||||
# cycles so a transient OpenRouter /endpoints outage doesn't drop
|
||||
# every cfg back to static-only scoring.
|
||||
# Shape: {model_name: {"gated": bool, "score": float | None}}
|
||||
self._health_cache: dict[str, dict[str, Any]] = {}
|
||||
self._enrich_task: asyncio.Task | None = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
|
|
@ -220,12 +333,27 @@ class OpenRouterIntegrationService:
|
|||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._initialized = True
|
||||
|
||||
tier_counts = self._tier_counts(self._configs)
|
||||
logger.info(
|
||||
"OpenRouter integration: loaded %d models (IDs %d to %d)",
|
||||
"OpenRouter integration: loaded %d models (free=%d, premium=%d)",
|
||||
len(self._configs),
|
||||
self._configs[0]["id"] if self._configs else 0,
|
||||
self._configs[-1]["id"] if self._configs else 0,
|
||||
tier_counts["free"],
|
||||
tier_counts["premium"],
|
||||
)
|
||||
|
||||
# Schedule the first health-enrichment pass as a deferred task so
|
||||
# cold-start latency is unchanged. Only valid when an event loop is
|
||||
# already running (e.g. FastAPI lifespan); Celery worker init is
|
||||
# fully sync so we silently skip — its first refresh tick (or the
|
||||
# next refresh from the web process) will populate health data.
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._enrich_task = loop.create_task(
|
||||
self._enrich_health_safely(self._configs)
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return self._configs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -254,7 +382,225 @@ class OpenRouterIntegrationService:
|
|||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
|
||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||
# earned by the previous turn's preflight. Drop the whole table so
|
||||
# the next turn re-probes against the freshly loaded configs.
|
||||
try:
|
||||
from app.services.auto_model_pin_service import clear_healthy
|
||||
|
||||
clear_healthy()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"OpenRouter refresh: clear_healthy import skipped", exc_info=True
|
||||
)
|
||||
|
||||
tier_counts = self._tier_counts(new_configs)
|
||||
logger.info(
|
||||
"OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
|
||||
len(new_configs),
|
||||
tier_counts["free"],
|
||||
tier_counts["premium"],
|
||||
)
|
||||
|
||||
# Re-blend health scores against the freshly fetched catalogue. Also
|
||||
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
|
||||
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||
|
||||
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
||||
# (dynamic OR premium entries now opt into the pool, free ones stay
|
||||
# out; a refresh also needs to pick up any static-config edits and
|
||||
# reset cached context-window profiles).
|
||||
try:
|
||||
from app.config import config as _app_config
|
||||
from app.services.llm_router_service import (
|
||||
LLMRouterService,
|
||||
_router_instance_cache as _chat_router_cache,
|
||||
)
|
||||
|
||||
LLMRouterService.rebuild(
|
||||
_app_config.GLOBAL_LLM_CONFIGS,
|
||||
getattr(_app_config, "ROUTER_SETTINGS", None),
|
||||
)
|
||||
_chat_router_cache.clear()
|
||||
except Exception as exc:
|
||||
logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
|
||||
|
||||
@staticmethod
|
||||
def _tier_counts(configs: list[dict]) -> dict[str, int]:
|
||||
counts = {"free": 0, "premium": 0}
|
||||
for cfg in configs:
|
||||
tier = str(cfg.get("billing_tier", "")).lower()
|
||||
if tier in counts:
|
||||
counts[tier] += 1
|
||||
return counts
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto (Fastest) health enrichment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _enrich_health_safely(
|
||||
self, configs: list[dict], *, log_summary: bool = True
|
||||
) -> None:
|
||||
"""Wrapper around ``_enrich_health`` that swallows all errors.
|
||||
|
||||
Health enrichment is best-effort: any failure must leave cfgs in
|
||||
their static-only state and never break refresh / startup.
|
||||
"""
|
||||
try:
|
||||
await self._enrich_health(configs, log_summary=log_summary)
|
||||
except Exception:
|
||||
logger.exception("OpenRouter health enrichment failed")
|
||||
|
||||
async def _enrich_health(
|
||||
self, configs: list[dict], *, log_summary: bool = True
|
||||
) -> None:
|
||||
"""Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
|
||||
the resulting health score into ``cfg["quality_score"]``.
|
||||
|
||||
Bounded fan-out: top-N per tier by ``quality_score_static`` only,
|
||||
with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
|
||||
outbound HTTP. Misses fall back to a per-model last-good cache; if
|
||||
the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
|
||||
the entire previous cycle's cache for this run.
|
||||
"""
|
||||
or_cfgs = [
|
||||
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
|
||||
]
|
||||
if not or_cfgs:
|
||||
return
|
||||
|
||||
premium_pool = sorted(
|
||||
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
|
||||
key=lambda c: -int(c.get("quality_score_static") or 0),
|
||||
)[:_HEALTH_ENRICH_TOP_N_PREMIUM]
|
||||
free_pool = sorted(
|
||||
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
|
||||
key=lambda c: -int(c.get("quality_score_static") or 0),
|
||||
)[:_HEALTH_ENRICH_TOP_N_FREE]
|
||||
# De-duplicate while preserving order: a cfg shouldn't fall in both
|
||||
# tiers, but defensive code is cheap here.
|
||||
seen_ids: set[int] = set()
|
||||
selected: list[dict] = []
|
||||
for cfg in premium_pool + free_pool:
|
||||
cid = int(cfg.get("id", 0))
|
||||
if cid in seen_ids:
|
||||
continue
|
||||
seen_ids.add(cid)
|
||||
selected.append(cfg)
|
||||
|
||||
if not selected:
|
||||
return
|
||||
|
||||
api_key = str(self._settings.get("api_key") or "")
|
||||
semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
self._fetch_endpoints(client, semaphore, api_key, cfg)
|
||||
for cfg in selected
|
||||
)
|
||||
)
|
||||
|
||||
fail_count = sum(1 for _, _, err in results if err is not None)
|
||||
fail_ratio = fail_count / len(results) if results else 0.0
|
||||
degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
|
||||
if degraded:
|
||||
logger.warning(
|
||||
"auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
|
||||
"using_last_good_cache=true",
|
||||
fail_ratio,
|
||||
len(results),
|
||||
)
|
||||
|
||||
# Per-cfg health update.
|
||||
for cfg, endpoints, err in results:
|
||||
model_name = str(cfg.get("model_name", ""))
|
||||
if not degraded and err is None and endpoints is not None:
|
||||
gated, h_score = aggregate_health(endpoints)
|
||||
cfg["health_gated"] = bool(gated)
|
||||
cfg["quality_score_health"] = h_score
|
||||
self._health_cache[model_name] = {
|
||||
"gated": bool(gated),
|
||||
"score": h_score,
|
||||
}
|
||||
else:
|
||||
cached = self._health_cache.get(model_name)
|
||||
if cached is not None:
|
||||
cfg["health_gated"] = bool(cached.get("gated", False))
|
||||
cfg["quality_score_health"] = cached.get("score")
|
||||
# else: keep current values (initial defaults from
|
||||
# _generate_configs / load_global_llm_configs).
|
||||
|
||||
# Blend health into the final score for every OR cfg, including
|
||||
# those outside the enriched top-N (they fall through to static).
|
||||
gated_count = 0
|
||||
by_provider: dict[str, int] = {}
|
||||
for cfg in or_cfgs:
|
||||
static_q = int(cfg.get("quality_score_static") or 0)
|
||||
h = cfg.get("quality_score_health")
|
||||
if h is not None and not cfg.get("health_gated"):
|
||||
blended = (
|
||||
_HEALTH_BLEND_WEIGHT * float(h)
|
||||
+ (1 - _HEALTH_BLEND_WEIGHT) * static_q
|
||||
)
|
||||
cfg["quality_score"] = round(blended)
|
||||
else:
|
||||
cfg["quality_score"] = static_q
|
||||
|
||||
if cfg.get("health_gated"):
|
||||
gated_count += 1
|
||||
model_id = str(cfg.get("model_name", ""))
|
||||
provider_slug = (
|
||||
model_id.split("/", 1)[0] if "/" in model_id else "unknown"
|
||||
)
|
||||
by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
|
||||
|
||||
if log_summary:
|
||||
logger.info(
|
||||
"auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
|
||||
"total_enriched=%d",
|
||||
gated_count,
|
||||
dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
|
||||
fail_ratio,
|
||||
len(selected),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_endpoints(
|
||||
client: httpx.AsyncClient,
|
||||
semaphore: asyncio.Semaphore,
|
||||
api_key: str,
|
||||
cfg: dict,
|
||||
) -> tuple[dict, list[dict] | None, Exception | None]:
|
||||
"""Fetch ``/api/v1/models/{id}/endpoints`` for one cfg.
|
||||
|
||||
Returns ``(cfg, endpoints, err)`` so the caller can keep batched
|
||||
results aligned with their cfgs without raising.
|
||||
"""
|
||||
model_id = str(cfg.get("model_name", ""))
|
||||
if not model_id:
|
||||
return cfg, None, ValueError("missing model_name")
|
||||
|
||||
url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id)
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
resp = await client.get(url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception as exc:
|
||||
return cfg, None, exc
|
||||
|
||||
payload = data.get("data") if isinstance(data, dict) else None
|
||||
if not isinstance(payload, dict):
|
||||
return cfg, None, ValueError("malformed endpoints payload")
|
||||
endpoints = payload.get("endpoints")
|
||||
if not isinstance(endpoints, list):
|
||||
return cfg, [], None
|
||||
return cfg, endpoints, None
|
||||
|
||||
async def _refresh_loop(self, interval_hours: float) -> None:
|
||||
interval_sec = interval_hours * 3600
|
||||
|
|
|
|||
380
surfsense_backend/app/services/quality_score.py
Normal file
380
surfsense_backend/app/services/quality_score.py
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
"""Pure-function quality scoring for Auto (Fastest) model selection.
|
||||
|
||||
This module is import-free of any service / request-path dependencies. All
|
||||
numbers are computed once during the OpenRouter refresh tick (or YAML load)
|
||||
and cached on the cfg dict, so the chat hot path only does a precomputed
|
||||
sort and a SHA256 pick.
|
||||
|
||||
Score components (0-100 scale, higher is better):
|
||||
|
||||
* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
|
||||
(provider prestige + ``created`` recency + pricing band + context window
|
||||
+ capabilities + narrow tiny/legacy slug penalty).
|
||||
* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
|
||||
an operator-trust bonus (the operator deliberately picked this model).
|
||||
* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
|
||||
responses; returns ``(gated, score_or_none)``.
|
||||
|
||||
The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
|
||||
:mod:`app.services.openrouter_integration_service` because that's the only
|
||||
caller that sees both halves.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tunables (constants, not flags)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Top-K size for deterministic spread inside the locked tier.
|
||||
_QUALITY_TOP_K: int = 5
|
||||
|
||||
# Hard health gate: any cfg whose best non-null uptime is below this %
|
||||
# is excluded from Auto-mode selection entirely.
|
||||
_HEALTH_GATE_UPTIME_PCT: float = 90.0
|
||||
|
||||
# Health/static blend weight when a cfg has fresh /endpoints data.
|
||||
_HEALTH_BLEND_WEIGHT: float = 0.5
|
||||
|
||||
# Static bonus applied to YAML cfgs because the operator hand-picked them.
|
||||
_OPERATOR_TRUST_BONUS: int = 20
|
||||
|
||||
# /endpoints fan-out is bounded per refresh tick.
|
||||
_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
|
||||
_HEALTH_ENRICH_TOP_N_FREE: int = 30
|
||||
_HEALTH_ENRICH_CONCURRENCY: int = 15
|
||||
_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
|
||||
|
||||
# If at least this fraction of /endpoints fetches fail in a refresh cycle,
|
||||
# fall back to the previous cycle's last-good cache instead of writing
|
||||
# partial / stale health values.
|
||||
_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
|
||||
|
||||
# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
|
||||
# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
|
||||
# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
|
||||
# blanket-penalising them suppresses high-quality picks.
|
||||
_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
|
||||
"-1b-",
|
||||
"-1.2b-",
|
||||
"-1.5b-",
|
||||
"-2b-",
|
||||
"-3b-",
|
||||
"gemma-3n",
|
||||
"lfm-",
|
||||
"-base",
|
||||
"-distill",
|
||||
":nitro",
|
||||
"-preview",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider prestige tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
|
||||
# Tiers are coarse: frontier labs > strong open / fast-moving labs >
|
||||
# specialist labs > everything else.
|
||||
PROVIDER_PRESTIGE_OR: dict[str, int] = {
|
||||
# Frontier labs
|
||||
"openai": 50,
|
||||
"anthropic": 50,
|
||||
"google": 50,
|
||||
"x-ai": 50,
|
||||
# Strong open / fast-moving labs
|
||||
"deepseek": 38,
|
||||
"qwen": 38,
|
||||
"meta-llama": 38,
|
||||
"mistralai": 38,
|
||||
"cohere": 38,
|
||||
"nvidia": 38,
|
||||
"alibaba": 38,
|
||||
# Specialist / regional / strong second-tier
|
||||
"microsoft": 28,
|
||||
"01-ai": 28,
|
||||
"minimax": 28,
|
||||
"moonshot": 28,
|
||||
"z-ai": 28,
|
||||
"nousresearch": 28,
|
||||
"ai21": 28,
|
||||
"perplexity": 28,
|
||||
# Smaller / niche providers
|
||||
"liquid": 18,
|
||||
"cognitivecomputations": 18,
|
||||
"venice": 18,
|
||||
"inflection": 18,
|
||||
}
|
||||
|
||||
# YAML provider field (the upstream API shape the operator selected).
|
||||
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
|
||||
"AZURE_OPENAI": 50,
|
||||
"OPENAI": 50,
|
||||
"ANTHROPIC": 50,
|
||||
"GOOGLE": 50,
|
||||
"VERTEX_AI": 50,
|
||||
"GEMINI": 50,
|
||||
"XAI": 50,
|
||||
"MISTRAL": 38,
|
||||
"DEEPSEEK": 38,
|
||||
"COHERE": 38,
|
||||
"GROQ": 30,
|
||||
"TOGETHER_AI": 28,
|
||||
"FIREWORKS_AI": 28,
|
||||
"PERPLEXITY": 28,
|
||||
"MINIMAX": 28,
|
||||
"BEDROCK": 28,
|
||||
"OPENROUTER": 25,
|
||||
"OLLAMA": 12,
|
||||
"CUSTOM": 12,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure scoring helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Calibrated against the live /api/v1/models bulk dump. Frontier models
|
||||
# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5,
|
||||
# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band;
|
||||
# anything older trails off.
|
||||
_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = (
|
||||
(60, 20),
|
||||
(180, 16),
|
||||
(365, 12),
|
||||
(540, 9),
|
||||
(730, 6),
|
||||
(1095, 3),
|
||||
)
|
||||
|
||||
|
||||
def created_recency_signal(created_ts: int | None, now_ts: int) -> int:
|
||||
"""Return 0-20 based on how recently the model was published.
|
||||
|
||||
Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for
|
||||
YAML cfgs). Models without a usable timestamp get 0 (we don't penalise,
|
||||
we just don't reward).
|
||||
"""
|
||||
if created_ts is None or created_ts <= 0 or now_ts <= 0:
|
||||
return 0
|
||||
age_days = max(0, (now_ts - int(created_ts)) // 86_400)
|
||||
for cutoff, score in _RECENCY_BANDS_DAYS:
|
||||
if age_days <= cutoff:
|
||||
return score
|
||||
return 0
|
||||
|
||||
|
||||
def pricing_band(
|
||||
prompt: str | float | int | None,
|
||||
completion: str | float | int | None,
|
||||
) -> int:
|
||||
"""Return 0-15 based on combined prompt+completion cost per 1M tokens.
|
||||
|
||||
Higher-priced models tend to be the larger / more capable ones. A free
|
||||
model returns 0 (we use other signals to rank free-vs-free instead).
|
||||
Uncoercible inputs are treated as 0 rather than raising.
|
||||
"""
|
||||
|
||||
def _to_float(value) -> float:
|
||||
if value is None:
|
||||
return 0.0
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
|
||||
p = _to_float(prompt)
|
||||
c = _to_float(completion)
|
||||
total_per_million = (p + c) * 1_000_000
|
||||
|
||||
if total_per_million >= 20.0:
|
||||
return 15
|
||||
if total_per_million >= 5.0:
|
||||
return 12
|
||||
if total_per_million >= 1.0:
|
||||
return 9
|
||||
if total_per_million >= 0.3:
|
||||
return 6
|
||||
if total_per_million >= 0.05:
|
||||
return 4
|
||||
if total_per_million > 0.0:
|
||||
return 2
|
||||
return 0
|
||||
|
||||
|
||||
def context_signal(ctx: int | None) -> int:
|
||||
"""Return 0-10 based on the model's context window."""
|
||||
if not ctx or ctx <= 0:
|
||||
return 0
|
||||
if ctx >= 1_000_000:
|
||||
return 10
|
||||
if ctx >= 400_000:
|
||||
return 8
|
||||
if ctx >= 200_000:
|
||||
return 6
|
||||
if ctx >= 128_000:
|
||||
return 4
|
||||
if ctx >= 100_000:
|
||||
return 2
|
||||
return 0
|
||||
|
||||
|
||||
def capabilities_signal(supported_parameters: list[str] | None) -> int:
|
||||
"""Return 0-5 for capabilities that matter for our agent flows."""
|
||||
if not supported_parameters:
|
||||
return 0
|
||||
params = set(supported_parameters)
|
||||
score = 0
|
||||
if "tools" in params:
|
||||
score += 2
|
||||
if "structured_outputs" in params or "response_format" in params:
|
||||
score += 2
|
||||
if "reasoning" in params or "include_reasoning" in params:
|
||||
score += 1
|
||||
return min(score, 5)
|
||||
|
||||
|
||||
def slug_penalty(model_id: str) -> int:
|
||||
"""Return a non-positive number; matches the narrow tiny/legacy patterns."""
|
||||
if not model_id:
|
||||
return 0
|
||||
needle = model_id.lower()
|
||||
for pattern in _TINY_LEGACY_PENALTY_PATTERNS:
|
||||
if pattern in needle:
|
||||
return -10
|
||||
return 0
|
||||
|
||||
|
||||
def _provider_prestige_or(model_id: str) -> int:
|
||||
if "/" not in model_id:
|
||||
return 0
|
||||
slug = model_id.split("/", 1)[0].lower()
|
||||
return PROVIDER_PRESTIGE_OR.get(slug, 15)
|
||||
|
||||
|
||||
def static_score_or(or_model: dict, *, now_ts: int) -> int:
|
||||
"""Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale."""
|
||||
model_id = str(or_model.get("id", ""))
|
||||
pricing = or_model.get("pricing") or {}
|
||||
|
||||
score = (
|
||||
_provider_prestige_or(model_id)
|
||||
+ created_recency_signal(or_model.get("created"), now_ts)
|
||||
+ pricing_band(pricing.get("prompt"), pricing.get("completion"))
|
||||
+ context_signal(or_model.get("context_length"))
|
||||
+ capabilities_signal(or_model.get("supported_parameters"))
|
||||
+ slug_penalty(model_id)
|
||||
)
|
||||
return max(0, min(100, int(score)))
|
||||
|
||||
|
||||
def static_score_yaml(cfg: dict) -> int:
|
||||
"""Score a YAML-curated cfg on a 0-100 scale.
|
||||
|
||||
Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately
|
||||
listed this model. Pricing / context fall through to lazy ``litellm``
|
||||
lookups; failures are silent (we just lose those sub-points).
|
||||
"""
|
||||
provider = str(cfg.get("provider", "")).upper()
|
||||
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
|
||||
|
||||
model_name = cfg.get("model_name") or ""
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
lookup_name = (
|
||||
litellm_params.get("base_model") or litellm_params.get("model") or model_name
|
||||
)
|
||||
|
||||
ctx = 0
|
||||
p_cost: float = 0.0
|
||||
c_cost: float = 0.0
|
||||
try:
|
||||
from litellm import get_model_info # lazy: avoid cold-import cost
|
||||
|
||||
info = get_model_info(lookup_name) or {}
|
||||
ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0)
|
||||
p_cost = float(info.get("input_cost_per_token") or 0.0)
|
||||
c_cost = float(info.get("output_cost_per_token") or 0.0)
|
||||
except Exception:
|
||||
# Unknown to litellm — that's fine for prestige+operator-bonus weighting.
|
||||
pass
|
||||
|
||||
score = (
|
||||
base
|
||||
+ _OPERATOR_TRUST_BONUS
|
||||
+ pricing_band(p_cost, c_cost)
|
||||
+ context_signal(ctx)
|
||||
+ slug_penalty(str(model_name))
|
||||
)
|
||||
return max(0, min(100, int(score)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _coerce_pct(value) -> float | None:
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
f = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if f < 0:
|
||||
return None
|
||||
# OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it
|
||||
# as a 0-100 percentage. Normalise.
|
||||
return f * 100.0 if f <= 1.0 else f
|
||||
|
||||
|
||||
def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]:
|
||||
"""Pick the best (highest) non-null uptime across all endpoints.
|
||||
|
||||
Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` >
|
||||
``uptime_last_5m``. Returns ``(uptime_pct, window_used)``.
|
||||
"""
|
||||
for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"):
|
||||
values = [_coerce_pct(ep.get(window)) for ep in endpoints]
|
||||
values = [v for v in values if v is not None]
|
||||
if values:
|
||||
return max(values), window
|
||||
return None, None
|
||||
|
||||
|
||||
def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
|
||||
"""Aggregate a model's per-endpoint health into ``(gated, score_or_none)``.
|
||||
|
||||
Hard gate (returns ``(True, None)``):
|
||||
* ``endpoints`` empty,
|
||||
* no endpoint reports ``status == 0`` (OK), or
|
||||
* best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``.
|
||||
|
||||
On a pass, returns a 0-100 health score blending uptime, status, and a
|
||||
freshness-weighted recent uptime sample.
|
||||
"""
|
||||
if not endpoints:
|
||||
return True, None
|
||||
|
||||
any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints)
|
||||
if not any_ok:
|
||||
return True, None
|
||||
|
||||
best_uptime, _ = _best_uptime(endpoints)
|
||||
if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT:
|
||||
return True, None
|
||||
|
||||
# Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing.
|
||||
freshness = None
|
||||
for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
|
||||
values = [_coerce_pct(ep.get(window)) for ep in endpoints]
|
||||
values = [v for v in values if v is not None]
|
||||
if values:
|
||||
freshness = max(values)
|
||||
break
|
||||
|
||||
uptime_term = best_uptime
|
||||
status_term = 100.0 if any_ok else 0.0
|
||||
freshness_term = freshness if freshness is not None else best_uptime
|
||||
|
||||
score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term
|
||||
return False, max(0.0, min(100.0, score))
|
||||
|
|
@ -64,7 +64,12 @@ from app.db import (
|
|||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
from app.services.auto_model_pin_service import (
|
||||
is_recently_healthy,
|
||||
mark_healthy,
|
||||
mark_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -299,20 +304,17 @@ def _tool_output_has_error(tool_output: Any) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None:
|
||||
def _extract_resolved_file_path(
|
||||
*, tool_name: str, tool_output: Any, tool_input: Any | None = None
|
||||
) -> str | None:
|
||||
if isinstance(tool_output, dict):
|
||||
path_value = tool_output.get("path")
|
||||
if isinstance(path_value, str) and path_value.strip():
|
||||
return path_value.strip()
|
||||
text = _tool_output_to_text(tool_output)
|
||||
if tool_name == "write_file":
|
||||
match = re.search(r"Updated file\s+(.+)$", text.strip())
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
if tool_name == "edit_file":
|
||||
match = re.search(r"in '([^']+)'", text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict):
|
||||
file_path = tool_input.get("file_path")
|
||||
if isinstance(file_path, str) and file_path.strip():
|
||||
return file_path.strip()
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -414,6 +416,108 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
|
|||
return None
|
||||
|
||||
|
||||
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
candidates: list[Any] = [parsed.get("code")]
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
candidates.append(nested.get("code"))
|
||||
for value in candidates:
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
return int(value)
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _is_provider_rate_limited(exc: BaseException) -> bool:
|
||||
"""Best-effort detection for provider-side runtime throttling.
|
||||
|
||||
Covers LiteLLM/OpenRouter shapes like:
|
||||
- class name contains ``RateLimit``
|
||||
- nested payload ``{"error": {"code": 429}}``
|
||||
- nested payload ``{"error": {"type": "rate_limit_error"}}``
|
||||
"""
|
||||
raw = str(exc)
|
||||
lowered = raw.lower()
|
||||
if "ratelimit" in type(exc).__name__.lower():
|
||||
return True
|
||||
parsed = _parse_error_payload(raw)
|
||||
provider_code = _extract_provider_error_code(parsed)
|
||||
if provider_code == 429:
|
||||
return True
|
||||
|
||||
provider_error_type = ""
|
||||
if parsed:
|
||||
top_type = parsed.get("type")
|
||||
if isinstance(top_type, str):
|
||||
provider_error_type = top_type.lower()
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
nested_type = nested.get("type")
|
||||
if isinstance(nested_type, str):
|
||||
provider_error_type = nested_type.lower()
|
||||
if provider_error_type == "rate_limit_error":
|
||||
return True
|
||||
|
||||
return (
|
||||
"rate limited" in lowered
|
||||
or "rate-limited" in lowered
|
||||
or "temporarily rate-limited upstream" in lowered
|
||||
)
|
||||
|
||||
|
||||
_PREFLIGHT_TIMEOUT_SEC: float = 2.5
|
||||
_PREFLIGHT_MAX_TOKENS: int = 1
|
||||
|
||||
|
||||
async def _preflight_llm(llm: Any) -> None:
|
||||
"""Issue a minimal completion to confirm the pinned model isn't 429'ing.
|
||||
|
||||
Used before agent build / planner / classifier / title-gen so a known-bad
|
||||
free OpenRouter deployment is detected and repinned before it cascades
|
||||
into multiple wasted internal calls. The probe is intentionally cheap:
|
||||
one token, low timeout, tagged ``surfsense:internal`` so token tracking
|
||||
and SSE pipelines treat it as overhead rather than user output.
|
||||
|
||||
Raises the original exception when the provider responds with a
|
||||
rate-limit-shaped error so the caller can drive the cooldown/repin
|
||||
branch via :func:`_is_provider_rate_limited`. Other transient failures
|
||||
are swallowed — the caller continues to the normal stream path and the
|
||||
in-stream recovery loop remains the safety net.
|
||||
"""
|
||||
from litellm import acompletion
|
||||
|
||||
model = getattr(llm, "model", None)
|
||||
if not model or model == "auto":
|
||||
# Auto-mode router doesn't have a single deployment to ping; the
|
||||
# router itself handles per-deployment rate-limit accounting.
|
||||
return
|
||||
|
||||
try:
|
||||
await acompletion(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
max_tokens=_PREFLIGHT_MAX_TOKENS,
|
||||
timeout=_PREFLIGHT_TIMEOUT_SEC,
|
||||
stream=False,
|
||||
metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
|
||||
)
|
||||
except Exception as exc:
|
||||
if _is_provider_rate_limited(exc):
|
||||
raise
|
||||
logging.getLogger(__name__).debug(
|
||||
"auto_pin_preflight non_rate_limit_error model=%s err=%s",
|
||||
model,
|
||||
exc,
|
||||
)
|
||||
|
||||
|
||||
def _classify_stream_exception(
|
||||
exc: Exception,
|
||||
*,
|
||||
|
|
@ -449,19 +553,7 @@ def _classify_stream_exception(
|
|||
None,
|
||||
)
|
||||
|
||||
parsed = _parse_error_payload(raw)
|
||||
provider_error_type = ""
|
||||
if parsed:
|
||||
top_type = parsed.get("type")
|
||||
if isinstance(top_type, str):
|
||||
provider_error_type = top_type.lower()
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
nested_type = nested.get("type")
|
||||
if isinstance(nested_type, str):
|
||||
provider_error_type = nested_type.lower()
|
||||
|
||||
if provider_error_type == "rate_limit_error":
|
||||
if _is_provider_rate_limited(exc):
|
||||
return (
|
||||
"rate_limited",
|
||||
"RATE_LIMITED",
|
||||
|
|
@ -619,6 +711,7 @@ async def _stream_agent_events(
|
|||
# fallback path only and never re-pops a chunk we already streamed.
|
||||
pending_tool_call_chunks: list[dict[str, Any]] = []
|
||||
lc_tool_call_id_by_run: dict[str, str] = {}
|
||||
file_path_by_run: dict[str, str] = {}
|
||||
|
||||
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
|
||||
# is keyed by the chunk's ``index`` field — LangChain
|
||||
|
|
@ -797,6 +890,10 @@ async def _stream_agent_events(
|
|||
tool_input = event.get("data", {}).get("input", {})
|
||||
if tool_name in ("write_file", "edit_file"):
|
||||
result.write_attempted = True
|
||||
if isinstance(tool_input, dict):
|
||||
file_path = tool_input.get("file_path")
|
||||
if isinstance(file_path, str) and file_path.strip() and run_id:
|
||||
file_path_by_run[run_id] = file_path.strip()
|
||||
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
|
|
@ -1203,6 +1300,7 @@ async def _stream_agent_events(
|
|||
run_id = event.get("run_id", "")
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
raw_output = event.get("data", {}).get("output", "")
|
||||
staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None
|
||||
|
||||
if tool_name == "update_memory":
|
||||
called_update_memory = True
|
||||
|
|
@ -1716,6 +1814,9 @@ async def _stream_agent_events(
|
|||
resolved_path = _extract_resolved_file_path(
|
||||
tool_name=tool_name,
|
||||
tool_output=tool_output,
|
||||
tool_input={"file_path": staged_file_path}
|
||||
if staged_file_path
|
||||
else None,
|
||||
)
|
||||
result_text = _tool_output_to_text(tool_output)
|
||||
if _tool_output_has_error(tool_output):
|
||||
|
|
@ -2326,6 +2427,91 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs
|
||||
# (negative ids selected via ``resolve_or_get_pinned_llm_config_id``)
|
||||
# whose health hasn't already been confirmed within the TTL window.
|
||||
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
|
||||
# title-generation LLM calls fan out and each independently hit the
|
||||
# same upstream rate limit.
|
||||
if (
|
||||
requested_llm_config_id == 0
|
||||
and llm_config_id < 0
|
||||
and not is_recently_healthy(llm_config_id)
|
||||
):
|
||||
_t_preflight = time.perf_counter()
|
||||
try:
|
||||
await _preflight_llm(llm)
|
||||
mark_healthy(llm_config_id)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs",
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t_preflight,
|
||||
)
|
||||
except Exception as preflight_exc:
|
||||
if not _is_provider_rate_limited(preflight_exc):
|
||||
raise
|
||||
previous_config_id = llm_config_id
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id, reason="preflight_rate_limited"
|
||||
)
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
||||
llm_config_id
|
||||
)
|
||||
if llm_load_error or not llm:
|
||||
yield _emit_stream_error(
|
||||
message=llm_load_error or "Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
# Trust the freshly-resolved cfg for the remainder of this
|
||||
# turn rather than recursing into another preflight; the
|
||||
# in-stream 429 recovery loop is still in place as the
|
||||
# safety net if even this fallback hits an upstream cap.
|
||||
mark_healthy(llm_config_id)
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model failed preflight; switched to another "
|
||||
"eligible model and continuing."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"preflight": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Create connector service
|
||||
_t0 = time.perf_counter()
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
|
@ -2671,54 +2857,155 @@ async def stream_new_chat(
|
|||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_state,
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking",
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
|
||||
# Inject title update mid-stream as soon as the background task finishes
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
runtime_rate_limit_recovered = False
|
||||
while True:
|
||||
try:
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_state,
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking",
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(
|
||||
chat_id, generated_title
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
|
||||
# Inject title update mid-stream as soon as the background
|
||||
# task finishes.
|
||||
if (
|
||||
title_task is not None
|
||||
and title_task.done()
|
||||
and not title_emitted
|
||||
):
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(
|
||||
NewChatThread.id == chat_id
|
||||
)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(
|
||||
chat_id, generated_title
|
||||
)
|
||||
title_emitted = True
|
||||
break
|
||||
except Exception as stream_exc:
|
||||
can_runtime_recover = (
|
||||
not runtime_rate_limit_recovered
|
||||
and requested_llm_config_id == 0
|
||||
and llm_config_id < 0
|
||||
and not _first_event_logged
|
||||
and _is_provider_rate_limited(stream_exc)
|
||||
)
|
||||
if not can_runtime_recover:
|
||||
raise
|
||||
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
# The failed attempt may still hold the per-thread busy mutex
|
||||
# (middleware teardown can lag behind raised provider errors).
|
||||
# Force release before we retry within the same request.
|
||||
end_turn(str(chat_id))
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id,
|
||||
reason="provider_rate_limited",
|
||||
)
|
||||
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
)
|
||||
title_emitted = True
|
||||
).resolved_llm_config_id
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
||||
llm_config_id
|
||||
)
|
||||
if llm_load_error:
|
||||
raise stream_exc
|
||||
|
||||
# Title generation uses the initial llm object. After a runtime
|
||||
# repin we keep the stream focused on response recovery and skip
|
||||
# title generation for this turn.
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
title_task = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model hit runtime rate limit; switched to "
|
||||
"another eligible model and retried."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
|
|
@ -3187,6 +3474,84 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
|
||||
# one cheap probe before the agent is rebuilt so a 429'd pin gets
|
||||
# repinned without burning planner/classifier/title calls first.
|
||||
if (
|
||||
requested_llm_config_id == 0
|
||||
and llm_config_id < 0
|
||||
and not is_recently_healthy(llm_config_id)
|
||||
):
|
||||
_t_preflight = time.perf_counter()
|
||||
try:
|
||||
await _preflight_llm(llm)
|
||||
mark_healthy(llm_config_id)
|
||||
_perf_log.info(
|
||||
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs",
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t_preflight,
|
||||
)
|
||||
except Exception as preflight_exc:
|
||||
if not _is_provider_rate_limited(preflight_exc):
|
||||
raise
|
||||
previous_config_id = llm_config_id
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id, reason="preflight_rate_limited"
|
||||
)
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
||||
llm_config_id
|
||||
)
|
||||
if llm_load_error or not llm:
|
||||
yield _emit_stream_error(
|
||||
message=llm_load_error or "Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
mark_healthy(llm_config_id)
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model failed preflight; switched to another "
|
||||
"eligible model and continuing."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"preflight": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
||||
|
|
@ -3265,31 +3630,114 @@ async def stream_resume_chat(
|
|||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=Command(resume={"decisions": decisions}),
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
runtime_rate_limit_recovered = False
|
||||
while True:
|
||||
try:
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=Command(resume={"decisions": decisions}),
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
break
|
||||
except Exception as stream_exc:
|
||||
can_runtime_recover = (
|
||||
not runtime_rate_limit_recovered
|
||||
and requested_llm_config_id == 0
|
||||
and llm_config_id < 0
|
||||
and not _first_event_logged
|
||||
and _is_provider_rate_limited(stream_exc)
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
if not can_runtime_recover:
|
||||
raise
|
||||
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
# Ensure the same-request recovery retry does not trip the
|
||||
# BusyMutex lock retained by the failed attempt.
|
||||
end_turn(str(chat_id))
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id,
|
||||
reason="provider_rate_limited",
|
||||
)
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(
|
||||
llm_config_id
|
||||
)
|
||||
if llm_load_error:
|
||||
raise stream_exc
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model hit runtime rate limit; switched to "
|
||||
"another eligible model and retried."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
continue
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
|
|
|
|||
|
|
@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
|
|||
assert not manager.lock_for(thread_id).locked()
|
||||
assert not get_cancel_event(thread_id).is_set()
|
||||
assert is_cancel_requested(thread_id) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
|
||||
"""A stale aafter call from attempt A must not unlock attempt B.
|
||||
|
||||
Repro flow:
|
||||
1) attempt A acquires thread lock
|
||||
2) forced end_turn clears A so retry can proceed
|
||||
3) attempt B acquires same thread lock
|
||||
4) stale attempt-A aafter runs late
|
||||
|
||||
Expected: B lock remains held.
|
||||
"""
|
||||
thread_id = "stale-aafter-lock"
|
||||
runtime = _Runtime(thread_id)
|
||||
attempt_a = BusyMutexMiddleware()
|
||||
attempt_b = BusyMutexMiddleware()
|
||||
|
||||
await attempt_a.abefore_agent({}, runtime)
|
||||
lock = manager.lock_for(thread_id)
|
||||
assert lock.locked()
|
||||
|
||||
end_turn(thread_id)
|
||||
assert not lock.locked()
|
||||
|
||||
await attempt_b.abefore_agent({}, runtime)
|
||||
assert lock.locked()
|
||||
|
||||
# Stale cleanup from attempt A must not release attempt B's lock.
|
||||
await attempt_a.aafter_agent({}, runtime)
|
||||
assert lock.locked()
|
||||
|
||||
await attempt_b.aafter_agent({}, runtime)
|
||||
|
|
|
|||
|
|
@ -6,13 +6,26 @@ from types import SimpleNamespace
|
|||
import pytest
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
AUTO_FASTEST_MODE,
|
||||
clear_healthy,
|
||||
clear_runtime_cooldown,
|
||||
is_recently_healthy,
|
||||
mark_healthy,
|
||||
mark_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_runtime_cooldown_map():
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
yield
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeQuotaResult:
|
||||
allowed: bool
|
||||
|
|
@ -45,14 +58,11 @@ def _thread(
|
|||
*,
|
||||
search_space_id: int = 10,
|
||||
pinned_llm_config_id: int | None = None,
|
||||
pinned_auto_mode: str | None = None,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
search_space_id=search_space_id,
|
||||
pinned_llm_config_id=pinned_llm_config_id,
|
||||
pinned_auto_mode=pinned_auto_mode,
|
||||
pinned_at=None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -93,8 +103,6 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
)
|
||||
assert result.resolved_llm_config_id in {-1, -2}
|
||||
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
|
||||
assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
assert session.thread.pinned_at is not None
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
|
|
@ -102,9 +110,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
async def test_next_turn_reuses_existing_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
|
|
@ -228,9 +234,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
|||
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
|
|
@ -275,9 +279,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
|||
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
|
|
@ -325,9 +327,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
|||
async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-2))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
|
|
@ -345,8 +345,6 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
|||
)
|
||||
assert result.resolved_llm_config_id == 7
|
||||
assert session.thread.pinned_llm_config_id is None
|
||||
assert session.thread.pinned_auto_mode is None
|
||||
assert session.thread.pinned_at is None
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
|
|
@ -354,9 +352,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
|||
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-999))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
|
|
@ -383,3 +379,543 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
assert result.resolved_llm_config_id == -2
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quality-aware pin selection (Auto Fastest upgrade)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
||||
"""A cfg flagged ``health_gated`` must never be picked even if it has
|
||||
the highest score among eligible cfgs."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "venice/dead-model",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 95,
|
||||
"health_gated": True,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-flash",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 60,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
||||
"""Premium-eligible users with Tier A available should never spill to
|
||||
Tier B even if a B cfg ranks higher by ``quality_score``."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k-yaml",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 70,
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-5",
|
||||
"api_key": "k-or",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "B",
|
||||
"quality_score": 95,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch):
|
||||
"""Free-only user with no Tier A free cfg should pick from Tier C."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k-yaml",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 100,
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-flash:free",
|
||||
"api_key": "k-or",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 60,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_top_k_picks_only_high_score_models(monkeypatch):
|
||||
"""Different thread IDs should spread across top-K, never pick the
|
||||
obvious low-quality cfg even when it sits in the candidate list."""
|
||||
from app.config import config
|
||||
|
||||
high_score_cfgs = [
|
||||
{
|
||||
"id": -i,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": f"gpt-x-{i}",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
}
|
||||
for i in range(1, 6) # 5 high-quality Tier A cfgs
|
||||
]
|
||||
low_score_trap = {
|
||||
"id": -99,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "tiny-legacy",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 10,
|
||||
"health_gated": False,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[*high_score_cfgs, low_score_trap],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
high_score_ids = {c["id"] for c in high_score_cfgs}
|
||||
seen = set()
|
||||
for thread_id in range(1, 50):
|
||||
session = _FakeSession(_thread())
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=thread_id,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
seen.add(result.resolved_llm_config_id)
|
||||
assert result.resolved_llm_config_id != -99, (
|
||||
"low-score trap cfg should never be picked"
|
||||
)
|
||||
assert result.resolved_llm_config_id in high_score_ids
|
||||
|
||||
# Spread across at least a couple of top-K cfgs.
|
||||
assert len(seen) > 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
||||
"""An *already* pinned cfg that later flips to ``health_gated`` should
|
||||
still not be reused — gated cfgs are filtered out of the candidate
|
||||
pool, which forces a repair to a healthy cfg.
|
||||
|
||||
This guards the no-silent-tier-switch invariant: we don't keep using
|
||||
a known-broken model just because the thread happened to be pinned
|
||||
to it before the gate fired."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "venice/dead-model",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "B",
|
||||
"quality_score": 50,
|
||||
"health_gated": True,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
||||
"""Existing pin reuse must short-circuit the new tier/score logic."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 50, # lower than -2
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5-pro",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 99,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _must_not_call(*_args, **_kwargs):
|
||||
raise AssertionError("premium_get_usage should not run on pin reuse")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_must_not_call,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
assert session.commit_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
||||
"""A runtime-cooled config should be excluded from candidate reuse.
|
||||
|
||||
This enables one-shot recovery from transient provider 429 bursts: we can
|
||||
mark the pinned cfg as cooled down and force a repair to another eligible
|
||||
cfg on the next resolution.
|
||||
"""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 80,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _must_not_call(*_args, **_kwargs):
|
||||
raise AssertionError("premium_get_usage should not run on healthy pin reuse")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_must_not_call,
|
||||
)
|
||||
|
||||
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
|
||||
clear_runtime_cooldown(-1)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
|
||||
"""Runtime retry should never repin the just-failed config."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 80,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={-1},
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Healthy-status cache (preflight TTL companion)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_mark_healthy_then_is_recently_healthy_true_within_ttl():
|
||||
mark_healthy(-42, ttl_seconds=60)
|
||||
assert is_recently_healthy(-42) is True
|
||||
|
||||
|
||||
def test_healthy_expires_after_ttl(monkeypatch):
|
||||
import app.services.auto_model_pin_service as svc
|
||||
|
||||
real_time = svc.time.time
|
||||
base = real_time()
|
||||
|
||||
monkeypatch.setattr(svc.time, "time", lambda: base)
|
||||
mark_healthy(-7, ttl_seconds=10)
|
||||
assert is_recently_healthy(-7) is True
|
||||
|
||||
monkeypatch.setattr(svc.time, "time", lambda: base + 11)
|
||||
assert is_recently_healthy(-7) is False
|
||||
|
||||
|
||||
def test_mark_runtime_cooldown_invalidates_healthy_cache():
|
||||
mark_healthy(-9, ttl_seconds=60)
|
||||
assert is_recently_healthy(-9) is True
|
||||
|
||||
mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60)
|
||||
assert is_recently_healthy(-9) is False
|
||||
|
||||
|
||||
def test_clear_healthy_removes_single_entry():
|
||||
mark_healthy(-11, ttl_seconds=60)
|
||||
mark_healthy(-12, ttl_seconds=60)
|
||||
clear_healthy(-11)
|
||||
assert is_recently_healthy(-11) is False
|
||||
assert is_recently_healthy(-12) is True
|
||||
|
||||
|
||||
def test_clear_healthy_no_args_drops_all_entries():
|
||||
mark_healthy(-21, ttl_seconds=60)
|
||||
mark_healthy(-22, ttl_seconds=60)
|
||||
clear_healthy()
|
||||
assert is_recently_healthy(-21) is False
|
||||
assert is_recently_healthy(-22) is False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,226 @@
|
|||
"""LLMRouterService pool-filter / rebuild tests.
|
||||
|
||||
These tests focus on the *config plumbing* (which configs enter the router
|
||||
pool, rebuild resets state correctly). They stub out the underlying
|
||||
``litellm.Router`` so we don't need real API keys or network access.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _fake_yaml_config(
|
||||
*,
|
||||
id: int,
|
||||
model_name: str,
|
||||
billing_tier: str = "free",
|
||||
) -> dict:
|
||||
return {
|
||||
"id": id,
|
||||
"name": f"yaml-{id}",
|
||||
"provider": "OPENAI",
|
||||
"model_name": model_name,
|
||||
"api_key": "sk-test",
|
||||
"api_base": "",
|
||||
"billing_tier": billing_tier,
|
||||
"rpm": 100,
|
||||
"tpm": 100_000,
|
||||
"litellm_params": {},
|
||||
}
|
||||
|
||||
|
||||
def _fake_openrouter_config(
|
||||
*,
|
||||
id: int,
|
||||
model_name: str,
|
||||
billing_tier: str,
|
||||
router_pool_eligible: bool | None = None,
|
||||
) -> dict:
|
||||
"""Build a synthetic dynamic-OR config dict for router-pool tests.
|
||||
|
||||
Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
|
||||
out. Callers can override ``router_pool_eligible`` to simulate legacy
|
||||
configs or to regression-test the filter mechanics directly.
|
||||
"""
|
||||
if router_pool_eligible is None:
|
||||
router_pool_eligible = billing_tier == "premium"
|
||||
return {
|
||||
"id": id,
|
||||
"name": f"or-{id}",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": model_name,
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"billing_tier": billing_tier,
|
||||
"rpm": 20 if billing_tier == "free" else 200,
|
||||
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
|
||||
"litellm_params": {},
|
||||
"router_pool_eligible": router_pool_eligible,
|
||||
}
|
||||
|
||||
|
||||
def _reset_router_singleton() -> None:
|
||||
instance = LLMRouterService.get_instance()
|
||||
instance._initialized = False
|
||||
instance._router = None
|
||||
instance._model_list = []
|
||||
instance._premium_model_strings = set()
|
||||
|
||||
|
||||
def test_router_pool_includes_or_premium_excludes_or_free():
|
||||
"""Strategy 3: premium OR joins the pool, free OR stays out.
|
||||
|
||||
Dynamic OpenRouter premium entries opt into load balancing alongside
|
||||
curated YAML configs. Dynamic OR free entries are intentionally kept
|
||||
out because OpenRouter's free tier enforces a single account-global
|
||||
quota bucket that per-deployment router accounting can't represent.
|
||||
"""
|
||||
_reset_router_singleton()
|
||||
configs = [
|
||||
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
|
||||
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
|
||||
_fake_openrouter_config(
|
||||
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
|
||||
),
|
||||
_fake_openrouter_config(
|
||||
id=-10_002,
|
||||
model_name="meta-llama/llama-3.3-70b:free",
|
||||
billing_tier="free",
|
||||
),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("app.services.llm_router_service.Router") as mock_router,
|
||||
patch(
|
||||
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
|
||||
) as mock_ctx_fb,
|
||||
):
|
||||
mock_ctx_fb.side_effect = lambda ml: (ml, None)
|
||||
mock_router.return_value = object()
|
||||
LLMRouterService.initialize(configs)
|
||||
|
||||
pool_models = {
|
||||
dep["litellm_params"]["model"]
|
||||
for dep in LLMRouterService.get_instance()._model_list
|
||||
}
|
||||
# YAML premium + YAML free + dynamic OR premium are all in the pool.
|
||||
# Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
|
||||
assert pool_models == {
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openrouter/openai/gpt-4o",
|
||||
}
|
||||
|
||||
prem = LLMRouterService.get_instance()._premium_model_strings
|
||||
# YAML premium is fingerprinted under both its model_string and its
|
||||
# ``base_model`` form (existing behavior we don't want to regress).
|
||||
assert "openai/gpt-4o" in prem
|
||||
# Dynamic OR premium is now fingerprinted as premium so pool-level
|
||||
# calls through the router are billed against premium quota.
|
||||
assert "openrouter/openai/gpt-4o" in prem
|
||||
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
|
||||
# Dynamic OR free never enters the pool, so it's never counted as premium.
|
||||
assert (
|
||||
LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_router_pool_filter_mechanics_respect_override():
|
||||
"""The ``router_pool_eligible`` filter itself works independently of tier.
|
||||
|
||||
Regression guard: if a future refactor ever sets the flag False on a
|
||||
premium config (e.g. for maintenance), that config MUST be skipped by
|
||||
``initialize`` even though its tier is premium.
|
||||
"""
|
||||
_reset_router_singleton()
|
||||
configs = [
|
||||
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
|
||||
_fake_openrouter_config(
|
||||
id=-10_001,
|
||||
model_name="openai/gpt-4o",
|
||||
billing_tier="premium",
|
||||
router_pool_eligible=False, # opt out despite being premium
|
||||
),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("app.services.llm_router_service.Router") as mock_router,
|
||||
patch(
|
||||
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
|
||||
) as mock_ctx_fb,
|
||||
):
|
||||
mock_ctx_fb.side_effect = lambda ml: (ml, None)
|
||||
mock_router.return_value = object()
|
||||
LLMRouterService.initialize(configs)
|
||||
|
||||
pool_models = {
|
||||
dep["litellm_params"]["model"]
|
||||
for dep in LLMRouterService.get_instance()._model_list
|
||||
}
|
||||
assert pool_models == {"openai/gpt-4o"}
|
||||
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
|
||||
|
||||
|
||||
def test_rebuild_refreshes_pool_after_configs_change():
|
||||
_reset_router_singleton()
|
||||
configs_v1 = [
|
||||
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
|
||||
]
|
||||
configs_v2 = [
|
||||
*configs_v1,
|
||||
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("app.services.llm_router_service.Router") as mock_router,
|
||||
patch(
|
||||
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
|
||||
) as mock_ctx_fb,
|
||||
):
|
||||
mock_ctx_fb.side_effect = lambda ml: (ml, None)
|
||||
mock_router.return_value = object()
|
||||
|
||||
LLMRouterService.initialize(configs_v1)
|
||||
assert len(LLMRouterService.get_instance()._model_list) == 1
|
||||
|
||||
# ``initialize`` should be a no-op here (already initialized).
|
||||
LLMRouterService.initialize(configs_v2)
|
||||
assert len(LLMRouterService.get_instance()._model_list) == 1
|
||||
|
||||
# ``rebuild`` must clear the guard and re-run with the new configs.
|
||||
LLMRouterService.rebuild(configs_v2)
|
||||
assert len(LLMRouterService.get_instance()._model_list) == 2
|
||||
|
||||
|
||||
def test_auto_model_pin_candidates_include_dynamic_openrouter():
|
||||
"""Dynamic OR configs must remain Auto-mode thread-pin candidates.
|
||||
|
||||
Guards against a future regression where someone adds the
|
||||
``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.auto_model_pin_service import _global_candidates
|
||||
|
||||
or_premium = _fake_openrouter_config(
|
||||
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
|
||||
)
|
||||
or_free = _fake_openrouter_config(
|
||||
id=-10_002,
|
||||
model_name="meta-llama/llama-3.3-70b:free",
|
||||
billing_tier="free",
|
||||
)
|
||||
original = config.GLOBAL_LLM_CONFIGS
|
||||
try:
|
||||
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
|
||||
candidate_ids = {c["id"] for c in _global_candidates()}
|
||||
assert candidate_ids == {-10_001, -10_002}
|
||||
finally:
|
||||
config.GLOBAL_LLM_CONFIGS = original
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
"""Unit tests for the dynamic OpenRouter integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openrouter_integration_service import (
|
||||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
_generate_configs,
|
||||
_openrouter_tier,
|
||||
_stable_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _minimal_openrouter_model(
|
||||
*,
|
||||
model_id: str,
|
||||
pricing: dict | None = None,
|
||||
name: str | None = None,
|
||||
) -> dict:
|
||||
"""Return a synthetic OpenRouter /api/v1/models entry.
|
||||
|
||||
The real API payload includes a lot of fields; we only populate what
|
||||
``_generate_configs`` actually inspects (architecture, tool support,
|
||||
context, pricing, id).
|
||||
"""
|
||||
return {
|
||||
"id": model_id,
|
||||
"name": name or model_id,
|
||||
"architecture": {"output_modalities": ["text"]},
|
||||
"supported_parameters": ["tools"],
|
||||
"context_length": 200_000,
|
||||
"pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _openrouter_tier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_openrouter_tier_free_suffix():
|
||||
assert _openrouter_tier({"id": "foo/bar:free"}) == "free"
|
||||
|
||||
|
||||
def test_openrouter_tier_zero_pricing():
|
||||
model = {
|
||||
"id": "foo/bar",
|
||||
"pricing": {"prompt": "0", "completion": "0"},
|
||||
}
|
||||
assert _openrouter_tier(model) == "free"
|
||||
|
||||
|
||||
def test_openrouter_tier_paid():
|
||||
model = {
|
||||
"id": "foo/bar",
|
||||
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
|
||||
}
|
||||
assert _openrouter_tier(model) == "premium"
|
||||
|
||||
|
||||
def test_openrouter_tier_missing_pricing_is_premium():
|
||||
assert _openrouter_tier({"id": "foo/bar"}) == "premium"
|
||||
assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _stable_config_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stable_config_id_deterministic():
|
||||
taken1: set[int] = set()
|
||||
taken2: set[int] = set()
|
||||
a = _stable_config_id("openai/gpt-4o", -10_000, taken1)
|
||||
b = _stable_config_id("openai/gpt-4o", -10_000, taken2)
|
||||
assert a == b
|
||||
assert a < 0
|
||||
|
||||
|
||||
def test_stable_config_id_collision_decrements():
|
||||
"""When two model_ids hash to the same slot, the second should decrement."""
|
||||
taken: set[int] = set()
|
||||
a = _stable_config_id("openai/gpt-4o", -10_000, taken)
|
||||
# Force a collision by pre-populating ``taken`` with a slot we know will be
|
||||
# picked.
|
||||
taken_forced = {a}
|
||||
b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced)
|
||||
assert b != a
|
||||
assert b == a - 1
|
||||
assert b in taken_forced
|
||||
|
||||
|
||||
def test_stable_config_id_different_models_different_ids():
|
||||
taken: set[int] = set()
|
||||
ids = {
|
||||
_stable_config_id("openai/gpt-4o", -10_000, taken),
|
||||
_stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken),
|
||||
_stable_config_id("google/gemini-2.0-flash", -10_000, taken),
|
||||
}
|
||||
assert len(ids) == 3
|
||||
|
||||
|
||||
def test_stable_config_id_survives_catalogue_churn():
|
||||
"""Removing a model should not shift other models' IDs (the bug we fix)."""
|
||||
taken1: set[int] = set()
|
||||
id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1)
|
||||
_ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1)
|
||||
id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1)
|
||||
|
||||
taken2: set[int] = set()
|
||||
id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2)
|
||||
id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2)
|
||||
|
||||
assert id_a1 == id_a2
|
||||
assert id_c1 == id_c2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_SETTINGS_BASE: dict = {
|
||||
"api_key": "sk-or-test",
|
||||
"id_offset": -10_000,
|
||||
"rpm": 200,
|
||||
"tpm": 1_000_000,
|
||||
"free_rpm": 20,
|
||||
"free_tpm": 100_000,
|
||||
"anonymous_enabled_paid": False,
|
||||
"anonymous_enabled_free": True,
|
||||
"quota_reserve_tokens": 4000,
|
||||
}
|
||||
|
||||
|
||||
def test_generate_configs_respects_tier():
|
||||
"""Premium OR models opt into the router pool; free OR models stay out.
|
||||
|
||||
Strategy-3 split: premium participates in LiteLLM Router load balancing,
|
||||
free stays excluded because OpenRouter enforces a shared global free-tier
|
||||
bucket that per-deployment router accounting can't represent.
|
||||
"""
|
||||
raw = [
|
||||
_minimal_openrouter_model(model_id="openai/gpt-4o"),
|
||||
_minimal_openrouter_model(
|
||||
model_id="meta-llama/llama-3.3-70b-instruct:free",
|
||||
pricing={"prompt": "0", "completion": "0"},
|
||||
),
|
||||
]
|
||||
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||
by_model = {c["model_name"]: c for c in cfgs}
|
||||
|
||||
paid = by_model["openai/gpt-4o"]
|
||||
assert paid["billing_tier"] == "premium"
|
||||
assert paid["rpm"] == 200
|
||||
assert paid["tpm"] == 1_000_000
|
||||
assert paid["anonymous_enabled"] is False
|
||||
assert paid["router_pool_eligible"] is True
|
||||
assert paid[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
|
||||
free = by_model["meta-llama/llama-3.3-70b-instruct:free"]
|
||||
assert free["billing_tier"] == "free"
|
||||
assert free["rpm"] == 20
|
||||
assert free["tpm"] == 100_000
|
||||
assert free["anonymous_enabled"] is True
|
||||
assert free["router_pool_eligible"] is False
|
||||
|
||||
|
||||
def test_generate_configs_excludes_upstream_openrouter_free_router():
|
||||
"""OpenRouter's own ``openrouter/free`` meta-router must never become a card.
|
||||
|
||||
The upstream API returns this as a first-class zero-priced model, so
|
||||
without an explicit blocklist entry it would slip through every other
|
||||
filter (text output, tool calling, 200k context, non-Amazon) and land
|
||||
in the selector as a duplicate of the concrete ``:free`` cards. The
|
||||
exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that.
|
||||
"""
|
||||
raw = [
|
||||
_minimal_openrouter_model(model_id="openai/gpt-4o"),
|
||||
_minimal_openrouter_model(
|
||||
model_id="openrouter/free",
|
||||
pricing={"prompt": "0", "completion": "0"},
|
||||
),
|
||||
]
|
||||
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||
model_names = {c["model_name"] for c in cfgs}
|
||||
assert "openrouter/free" not in model_names
|
||||
assert "openai/gpt-4o" in model_names
|
||||
|
||||
|
||||
def test_generate_configs_drops_non_text_and_non_tool_models():
|
||||
raw = [
|
||||
_minimal_openrouter_model(model_id="openai/gpt-4o"),
|
||||
{ # image-output model
|
||||
"id": "openai/dall-e",
|
||||
"architecture": {"output_modalities": ["image"]},
|
||||
"supported_parameters": ["tools"],
|
||||
"context_length": 200_000,
|
||||
"pricing": {"prompt": "0.01", "completion": "0.01"},
|
||||
},
|
||||
{ # text but no tool calling
|
||||
"id": "openai/completion-only",
|
||||
"architecture": {"output_modalities": ["text"]},
|
||||
"supported_parameters": [],
|
||||
"context_length": 200_000,
|
||||
"pricing": {"prompt": "0.01", "completion": "0.01"},
|
||||
},
|
||||
]
|
||||
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||
model_names = [c["model_name"] for c in cfgs]
|
||||
assert "openai/gpt-4o" in model_names
|
||||
assert "openai/dall-e" not in model_names
|
||||
assert "openai/completion-only" not in model_names
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
"""Tests for deprecated-key warnings and back-compat in
|
||||
``load_openrouter_integration_settings``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _write_yaml(tmp_path: Path, body: str) -> Path:
|
||||
cfg_dir = tmp_path / "app" / "config"
|
||||
cfg_dir.mkdir(parents=True)
|
||||
cfg_path = cfg_dir / "global_llm_config.yaml"
|
||||
cfg_path.write_text(body, encoding="utf-8")
|
||||
return cfg_path
|
||||
|
||||
|
||||
def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
from app import config as config_module
|
||||
|
||||
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
|
||||
|
||||
|
||||
def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
|
||||
_write_yaml(
|
||||
tmp_path,
|
||||
"""
|
||||
openrouter_integration:
|
||||
enabled: true
|
||||
api_key: "sk-or-test"
|
||||
billing_tier: "premium"
|
||||
""".lstrip(),
|
||||
)
|
||||
_patch_base_dir(monkeypatch, tmp_path)
|
||||
|
||||
from app.config import load_openrouter_integration_settings
|
||||
|
||||
settings = load_openrouter_integration_settings()
|
||||
captured = capsys.readouterr().out
|
||||
assert settings is not None
|
||||
assert "billing_tier is deprecated" in captured
|
||||
|
||||
|
||||
def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
|
||||
_write_yaml(
|
||||
tmp_path,
|
||||
"""
|
||||
openrouter_integration:
|
||||
enabled: true
|
||||
api_key: "sk-or-test"
|
||||
anonymous_enabled: true
|
||||
""".lstrip(),
|
||||
)
|
||||
_patch_base_dir(monkeypatch, tmp_path)
|
||||
|
||||
from app.config import load_openrouter_integration_settings
|
||||
|
||||
settings = load_openrouter_integration_settings()
|
||||
captured = capsys.readouterr().out
|
||||
assert settings is not None
|
||||
assert settings["anonymous_enabled_paid"] is True
|
||||
assert settings["anonymous_enabled_free"] is True
|
||||
assert "anonymous_enabled is" in captured
|
||||
assert "deprecated" in captured
|
||||
|
||||
|
||||
def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
|
||||
"""If both legacy and new keys are present, new keys win (setdefault)."""
|
||||
_write_yaml(
|
||||
tmp_path,
|
||||
"""
|
||||
openrouter_integration:
|
||||
enabled: true
|
||||
api_key: "sk-or-test"
|
||||
anonymous_enabled: true
|
||||
anonymous_enabled_paid: false
|
||||
anonymous_enabled_free: false
|
||||
""".lstrip(),
|
||||
)
|
||||
_patch_base_dir(monkeypatch, tmp_path)
|
||||
|
||||
from app.config import load_openrouter_integration_settings
|
||||
|
||||
settings = load_openrouter_integration_settings()
|
||||
capsys.readouterr()
|
||||
assert settings is not None
|
||||
assert settings["anonymous_enabled_paid"] is False
|
||||
assert settings["anonymous_enabled_free"] is False
|
||||
|
||||
|
||||
def test_disabled_integration_returns_none(monkeypatch, tmp_path):
|
||||
_write_yaml(
|
||||
tmp_path,
|
||||
"""
|
||||
openrouter_integration:
|
||||
enabled: false
|
||||
api_key: "sk-or-test"
|
||||
""".lstrip(),
|
||||
)
|
||||
_patch_base_dir(monkeypatch, tmp_path)
|
||||
|
||||
from app.config import load_openrouter_integration_settings
|
||||
|
||||
assert load_openrouter_integration_settings() is None
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
"""Unit tests for the OpenRouter ``_enrich_health`` background task."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openrouter_integration_service import (
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
from app.services.quality_score import (
|
||||
_HEALTH_FAIL_RATIO_FALLBACK,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _or_cfg(
|
||||
*,
|
||||
cid: int,
|
||||
model_name: str,
|
||||
tier: str = "premium",
|
||||
static_score: int = 50,
|
||||
) -> dict:
|
||||
return {
|
||||
"id": cid,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": model_name,
|
||||
"billing_tier": tier,
|
||||
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||
"quality_score_static": static_score,
|
||||
"quality_score_health": None,
|
||||
"quality_score": static_score,
|
||||
"health_gated": False,
|
||||
}
|
||||
|
||||
|
||||
class _StubResponse:
|
||||
def __init__(self, *, payload: dict, status_code: int = 200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self.status_code >= 400:
|
||||
raise RuntimeError(f"HTTP {self.status_code}")
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._payload
|
||||
|
||||
|
||||
class _StubAsyncClient:
|
||||
"""Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""
|
||||
|
||||
def __init__(self, responder):
|
||||
self._responder = responder
|
||||
self.requests: list[str] = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
|
||||
self.requests.append(url)
|
||||
return self._responder(url)
|
||||
|
||||
|
||||
def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
|
||||
"""Replace ``httpx.AsyncClient`` for the duration of the test."""
|
||||
client = _StubAsyncClient(responder)
|
||||
monkeypatch.setattr(
|
||||
"app.services.openrouter_integration_service.httpx.AsyncClient",
|
||||
lambda *_args, **_kwargs: client,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def _healthy_payload() -> dict:
|
||||
return {
|
||||
"data": {
|
||||
"endpoints": [
|
||||
{
|
||||
"status": 0,
|
||||
"uptime_last_30m": 0.99,
|
||||
"uptime_last_1d": 0.995,
|
||||
"uptime_last_5m": 0.99,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _unhealthy_payload() -> dict:
|
||||
return {
|
||||
"data": {
|
||||
"endpoints": [
|
||||
{
|
||||
"status": 0,
|
||||
"uptime_last_30m": 0.55,
|
||||
"uptime_last_1d": 0.62,
|
||||
"uptime_last_5m": 0.50,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bounded fan-out + happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
|
||||
cfgs = [
|
||||
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
|
||||
_or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
|
||||
]
|
||||
|
||||
def responder(url: str) -> _StubResponse:
|
||||
if "anthropic" in url:
|
||||
return _StubResponse(payload=_healthy_payload())
|
||||
return _StubResponse(payload=_unhealthy_payload())
|
||||
|
||||
_patch_async_client(monkeypatch, responder)
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {"api_key": ""}
|
||||
await service._enrich_health(cfgs)
|
||||
|
||||
healthy = next(c for c in cfgs if c["id"] == -1)
|
||||
gated = next(c for c in cfgs if c["id"] == -2)
|
||||
|
||||
assert healthy["health_gated"] is False
|
||||
assert healthy["quality_score_health"] is not None
|
||||
assert healthy["quality_score"] >= healthy["quality_score_static"]
|
||||
|
||||
assert gated["health_gated"] is True
|
||||
assert gated["quality_score"] == gated["quality_score_static"]
|
||||
|
||||
|
||||
async def test_enrich_health_only_touches_or_provider(monkeypatch):
|
||||
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
|
||||
yaml_cfg = {
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score_static": 80,
|
||||
"quality_score": 80,
|
||||
"health_gated": False,
|
||||
}
|
||||
or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")
|
||||
|
||||
requests: list[str] = []
|
||||
|
||||
def responder(url: str) -> _StubResponse:
|
||||
requests.append(url)
|
||||
return _StubResponse(payload=_healthy_payload())
|
||||
|
||||
_patch_async_client(monkeypatch, responder)
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {}
|
||||
await service._enrich_health([yaml_cfg, or_cfg])
|
||||
|
||||
assert all("anthropic/claude-haiku" in r for r in requests)
|
||||
# YAML cfg is untouched.
|
||||
assert yaml_cfg["quality_score"] == 80
|
||||
assert yaml_cfg["health_gated"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failure ratio fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
|
||||
monkeypatch,
|
||||
):
|
||||
"""If >= 25% of fetches fail, keep last-good cache instead of writing
|
||||
partial data."""
|
||||
cfgs = [
|
||||
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
|
||||
_or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80),
|
||||
_or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65),
|
||||
_or_cfg(cid=-4, model_name="venice/something", static_score=50),
|
||||
]
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {}
|
||||
# Pre-seed last-good cache with a known-healthy snapshot.
|
||||
service._health_cache = {
|
||||
"anthropic/claude-haiku": {"gated": False, "score": 95.0},
|
||||
}
|
||||
|
||||
def all_fail(_url: str) -> _StubResponse:
|
||||
return _StubResponse(payload={}, status_code=500)
|
||||
|
||||
_patch_async_client(monkeypatch, all_fail)
|
||||
await service._enrich_health(cfgs)
|
||||
|
||||
# Above threshold ⇒ degraded; last-good cache wins for the cached cfg.
|
||||
cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku")
|
||||
assert cached_hit["quality_score_health"] == 95.0
|
||||
assert cached_hit["health_gated"] is False
|
||||
# Confirm the threshold constant we're testing against is real.
|
||||
assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0
|
||||
|
||||
|
||||
async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
|
||||
monkeypatch,
|
||||
):
|
||||
"""If a fetch fails and there's no last-good cache, the cfg keeps its
|
||||
static-only ``quality_score`` and is *not* gated by default."""
|
||||
cfgs = [
|
||||
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
|
||||
]
|
||||
|
||||
def fail(_url: str) -> _StubResponse:
|
||||
return _StubResponse(payload={}, status_code=500)
|
||||
|
||||
_patch_async_client(monkeypatch, fail)
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {}
|
||||
await service._enrich_health(cfgs)
|
||||
|
||||
cfg = cfgs[0]
|
||||
assert cfg["health_gated"] is False
|
||||
assert cfg["quality_score"] == cfg["quality_score_static"]
|
||||
assert cfg["quality_score_health"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Last-good cache: success populates, next failure reuses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
|
||||
monkeypatch,
|
||||
):
|
||||
cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {}
|
||||
|
||||
def healthy(_url: str) -> _StubResponse:
|
||||
return _StubResponse(payload=_healthy_payload())
|
||||
|
||||
_patch_async_client(monkeypatch, healthy)
|
||||
await service._enrich_health([cfg])
|
||||
|
||||
assert "anthropic/claude-haiku" in service._health_cache
|
||||
cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
|
||||
assert cached_score is not None
|
||||
|
||||
# Next cycle: enough other healthy cfgs so failure ratio stays below
|
||||
# the 25% threshold even when this one fails individually.
|
||||
other_cfgs = [
|
||||
_or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
|
||||
for i in range(10)
|
||||
]
|
||||
cfg["quality_score_health"] = None
|
||||
cfg["quality_score"] = cfg["quality_score_static"]
|
||||
|
||||
def mixed(url: str) -> _StubResponse:
|
||||
if "anthropic" in url:
|
||||
return _StubResponse(payload={}, status_code=500)
|
||||
return _StubResponse(payload=_healthy_payload())
|
||||
|
||||
_patch_async_client(monkeypatch, mixed)
|
||||
await service._enrich_health([cfg, *other_cfgs])
|
||||
|
||||
assert cfg["quality_score_health"] == cached_score
|
||||
assert cfg["health_gated"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bounded fan-out: respects top-N caps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_enrich_health_bounds_premium_fanout(monkeypatch):
|
||||
"""Top-N premium cap is honoured even when many cfgs are present."""
|
||||
from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM
|
||||
|
||||
cfgs = [
|
||||
_or_cfg(
|
||||
cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i
|
||||
)
|
||||
for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20)
|
||||
]
|
||||
|
||||
seen: list[str] = []
|
||||
|
||||
def responder(url: str) -> _StubResponse:
|
||||
seen.append(url)
|
||||
return _StubResponse(payload=_healthy_payload())
|
||||
|
||||
_patch_async_client(monkeypatch, responder)
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {}
|
||||
await service._enrich_health(cfgs)
|
||||
|
||||
assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM
|
||||
|
||||
|
||||
async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
|
||||
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
|
||||
yaml_cfg: dict[str, Any] = {
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
requests: list[str] = []
|
||||
|
||||
def responder(url: str) -> _StubResponse:
|
||||
requests.append(url)
|
||||
return _StubResponse(payload=_healthy_payload())
|
||||
|
||||
_patch_async_client(monkeypatch, responder)
|
||||
|
||||
service = OpenRouterIntegrationService()
|
||||
service._settings = {}
|
||||
await service._enrich_health([yaml_cfg])
|
||||
assert requests == []
|
||||
345
surfsense_backend/tests/unit/services/test_quality_score.py
Normal file
345
surfsense_backend/tests/unit/services/test_quality_score.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""Unit tests for the Auto (Fastest) quality scoring module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.quality_score import (
|
||||
_HEALTH_GATE_UPTIME_PCT,
|
||||
_OPERATOR_TRUST_BONUS,
|
||||
aggregate_health,
|
||||
capabilities_signal,
|
||||
context_signal,
|
||||
created_recency_signal,
|
||||
pricing_band,
|
||||
slug_penalty,
|
||||
static_score_or,
|
||||
static_score_yaml,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# created_recency_signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_created_recency_signal_recent_model_scores_high():
|
||||
now = 1_750_000_000 # ~mid-2025
|
||||
one_month_ago = now - (30 * 86_400)
|
||||
assert created_recency_signal(one_month_ago, now) == 20
|
||||
|
||||
|
||||
def test_created_recency_signal_old_model_scores_zero():
|
||||
now = 1_750_000_000
|
||||
five_years_ago = now - (5 * 365 * 86_400)
|
||||
assert created_recency_signal(five_years_ago, now) == 0
|
||||
|
||||
|
||||
def test_created_recency_signal_missing_timestamp_is_neutral():
|
||||
now = 1_750_000_000
|
||||
assert created_recency_signal(None, now) == 0
|
||||
assert created_recency_signal(0, now) == 0
|
||||
|
||||
|
||||
def test_created_recency_signal_monotonic_decay():
|
||||
now = 1_750_000_000
|
||||
scores = [
|
||||
created_recency_signal(now - days * 86_400, now)
|
||||
for days in (30, 120, 300, 500, 700, 1000, 1500)
|
||||
]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pricing_band
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_pricing_band_free_returns_zero():
|
||||
assert pricing_band("0", "0") == 0
|
||||
assert pricing_band(0.0, 0.0) == 0
|
||||
assert pricing_band(None, None) == 0
|
||||
|
||||
|
||||
def test_pricing_band_handles_unparseable():
|
||||
assert pricing_band("not-a-number", "0") == 0
|
||||
assert pricing_band({}, []) == 0 # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_pricing_band_premium_tiers_increase_with_price():
|
||||
cheap = pricing_band("0.0000003", "0.0000005")
|
||||
mid = pricing_band("0.000003", "0.000015")
|
||||
flagship = pricing_band("0.00001", "0.00005")
|
||||
assert 0 < cheap < mid < flagship
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# context_signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ctx,expected",
|
||||
[
|
||||
(1_500_000, 10),
|
||||
(1_000_000, 10),
|
||||
(500_000, 8),
|
||||
(200_000, 6),
|
||||
(128_000, 4),
|
||||
(100_000, 2),
|
||||
(50_000, 0),
|
||||
(0, 0),
|
||||
(None, 0),
|
||||
],
|
||||
)
|
||||
def test_context_signal_bands(ctx, expected):
|
||||
assert context_signal(ctx) == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# capabilities_signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_capabilities_signal_caps_at_five():
|
||||
assert (
|
||||
capabilities_signal(
|
||||
["tools", "structured_outputs", "reasoning", "include_reasoning"]
|
||||
)
|
||||
<= 5
|
||||
)
|
||||
|
||||
|
||||
def test_capabilities_signal_tools_only():
|
||||
assert capabilities_signal(["tools"]) == 2
|
||||
|
||||
|
||||
def test_capabilities_signal_empty():
|
||||
assert capabilities_signal(None) == 0
|
||||
assert capabilities_signal([]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# slug_penalty
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_slug_penalty_demotes_tiny_models():
|
||||
assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0
|
||||
assert slug_penalty("liquid/lfm-7b") < 0
|
||||
assert slug_penalty("google/gemma-3n-e4b-it") < 0
|
||||
|
||||
|
||||
def test_slug_penalty_skips_capable_mini_nano_lite_models():
|
||||
"""Critical Option C+ regression: don't penalise modern frontier
|
||||
models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.)."""
|
||||
assert slug_penalty("openai/gpt-5-mini") == 0
|
||||
assert slug_penalty("openai/gpt-5-nano") == 0
|
||||
assert slug_penalty("google/gemini-2.5-flash-lite") == 0
|
||||
assert slug_penalty("anthropic/claude-haiku-4.5") == 0
|
||||
|
||||
|
||||
def test_slug_penalty_demotes_legacy_variants():
|
||||
assert slug_penalty("openai/o1-preview") < 0
|
||||
assert slug_penalty("foo/bar-base") < 0
|
||||
assert slug_penalty("foo/bar-distill") < 0
|
||||
|
||||
|
||||
def test_slug_penalty_empty_input():
|
||||
assert slug_penalty("") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# static_score_or
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _or_model(
|
||||
*,
|
||||
model_id: str,
|
||||
created: int | None = None,
|
||||
prompt: str = "0.000003",
|
||||
completion: str = "0.000015",
|
||||
context: int = 200_000,
|
||||
params: list[str] | None = None,
|
||||
) -> dict:
|
||||
return {
|
||||
"id": model_id,
|
||||
"created": created,
|
||||
"pricing": {"prompt": prompt, "completion": completion},
|
||||
"context_length": context,
|
||||
"supported_parameters": params if params is not None else ["tools"],
|
||||
}
|
||||
|
||||
|
||||
def test_static_score_or_frontier_premium_beats_free_tiny():
|
||||
now = 1_750_000_000
|
||||
frontier = _or_model(
|
||||
model_id="openai/gpt-5",
|
||||
created=now - (60 * 86_400),
|
||||
prompt="0.000005",
|
||||
completion="0.000020",
|
||||
context=400_000,
|
||||
params=["tools", "structured_outputs", "reasoning"],
|
||||
)
|
||||
tiny_free = _or_model(
|
||||
model_id="meta-llama/llama-3.2-1b-instruct:free",
|
||||
created=now - (5 * 365 * 86_400),
|
||||
prompt="0",
|
||||
completion="0",
|
||||
context=128_000,
|
||||
params=["tools"],
|
||||
)
|
||||
assert static_score_or(frontier, now_ts=now) > static_score_or(
|
||||
tiny_free, now_ts=now
|
||||
)
|
||||
|
||||
|
||||
def test_static_score_or_score_is_clamped_0_to_100():
|
||||
now = int(time.time())
|
||||
score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now)
|
||||
assert 0 <= score <= 100
|
||||
|
||||
|
||||
def test_static_score_or_unknown_provider_is_neutral_not_zero():
|
||||
now = int(time.time())
|
||||
score = static_score_or(
|
||||
_or_model(model_id="some-new-lab/some-model"),
|
||||
now_ts=now,
|
||||
)
|
||||
assert score > 0
|
||||
|
||||
|
||||
def test_static_score_or_recent_release_beats_year_old_same_provider():
|
||||
now = 1_750_000_000
|
||||
fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
|
||||
old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))
|
||||
assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# static_score_yaml
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_static_score_yaml_includes_operator_bonus():
|
||||
cfg = {
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"litellm_params": {"base_model": "azure/gpt-5"},
|
||||
}
|
||||
score = static_score_yaml(cfg)
|
||||
assert score >= _OPERATOR_TRUST_BONUS
|
||||
|
||||
|
||||
def test_static_score_yaml_unknown_provider_still_carries_bonus():
|
||||
cfg = {
|
||||
"provider": "SOME_NEW_PROVIDER",
|
||||
"model_name": "weird-model",
|
||||
}
|
||||
score = static_score_yaml(cfg)
|
||||
assert score >= _OPERATOR_TRUST_BONUS
|
||||
|
||||
|
||||
def test_static_score_yaml_clamped_0_to_100():
|
||||
cfg = {
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5",
|
||||
"litellm_params": {"base_model": "azure/gpt-5"},
|
||||
}
|
||||
assert 0 <= static_score_yaml(cfg) <= 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# aggregate_health
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_aggregate_health_gates_when_uptime_below_threshold():
|
||||
"""Live data showed Venice-routed cfgs at 53-68%; this guards that the
|
||||
90% gate excludes them."""
|
||||
venice_endpoints = [
|
||||
{
|
||||
"status": 0,
|
||||
"uptime_last_30m": 0.55,
|
||||
"uptime_last_1d": 0.60,
|
||||
"uptime_last_5m": 0.50,
|
||||
},
|
||||
{
|
||||
"status": 0,
|
||||
"uptime_last_30m": 0.65,
|
||||
"uptime_last_1d": 0.68,
|
||||
"uptime_last_5m": 0.62,
|
||||
},
|
||||
]
|
||||
gated, score = aggregate_health(venice_endpoints)
|
||||
assert gated is True
|
||||
assert score is None
|
||||
|
||||
|
||||
def test_aggregate_health_passes_for_healthy_provider():
|
||||
healthy = [
|
||||
{
|
||||
"status": 0,
|
||||
"uptime_last_30m": 0.99,
|
||||
"uptime_last_1d": 0.995,
|
||||
"uptime_last_5m": 0.99,
|
||||
},
|
||||
]
|
||||
gated, score = aggregate_health(healthy)
|
||||
assert gated is False
|
||||
assert score is not None
|
||||
assert score >= _HEALTH_GATE_UPTIME_PCT
|
||||
|
||||
|
||||
def test_aggregate_health_picks_best_endpoint_across_multiple():
|
||||
"""Multi-endpoint aggregation should reward the best non-null uptime."""
|
||||
mixed = [
|
||||
{"status": 0, "uptime_last_30m": 0.55},
|
||||
{"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate
|
||||
]
|
||||
gated, score = aggregate_health(mixed)
|
||||
assert gated is False
|
||||
assert score is not None
|
||||
|
||||
|
||||
def test_aggregate_health_empty_endpoints_gated():
|
||||
gated, score = aggregate_health([])
|
||||
assert gated is True
|
||||
assert score is None
|
||||
|
||||
|
||||
def test_aggregate_health_no_status_zero_gated():
|
||||
"""Even with high uptime, no OK status means the cfg is broken upstream."""
|
||||
endpoints = [
|
||||
{"status": 1, "uptime_last_30m": 0.99},
|
||||
{"status": 2, "uptime_last_30m": 0.98},
|
||||
]
|
||||
gated, score = aggregate_health(endpoints)
|
||||
assert gated is True
|
||||
assert score is None
|
||||
|
||||
|
||||
def test_aggregate_health_all_uptime_null_gated():
|
||||
endpoints = [
|
||||
{"status": 0, "uptime_last_30m": None, "uptime_last_1d": None},
|
||||
]
|
||||
gated, score = aggregate_health(endpoints)
|
||||
assert gated is True
|
||||
assert score is None
|
||||
|
||||
|
||||
def test_aggregate_health_pct_normalisation():
|
||||
"""OpenRouter returns 0-1 fractions; some endpoints surface 0-100%
|
||||
percentages. Both should reach the same gate decision."""
|
||||
fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
|
||||
pct_form = [{"status": 0, "uptime_last_30m": 95.0}]
|
||||
g1, s1 = aggregate_health(fraction_form)
|
||||
g2, s2 = aggregate_health(pct_form)
|
||||
assert g1 == g2 == False # noqa: E712
|
||||
assert s1 is not None and s2 is not None
|
||||
assert abs(s1 - s2) < 0.5
|
||||
|
|
@ -14,6 +14,7 @@ from app.tasks.chat.stream_new_chat import (
|
|||
_classify_stream_exception,
|
||||
_contract_enforcement_active,
|
||||
_evaluate_file_contract_outcome,
|
||||
_extract_resolved_file_path,
|
||||
_log_chat_stream_error,
|
||||
_tool_output_has_error,
|
||||
)
|
||||
|
|
@ -28,6 +29,39 @@ def test_tool_output_error_detection():
|
|||
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
||||
|
||||
|
||||
def test_extract_resolved_file_path_prefers_structured_path():
|
||||
assert (
|
||||
_extract_resolved_file_path(
|
||||
tool_name="write_file",
|
||||
tool_output={"status": "completed", "path": "/docs/note.md"},
|
||||
tool_input=None,
|
||||
)
|
||||
== "/docs/note.md"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_resolved_file_path_falls_back_to_tool_input():
|
||||
assert (
|
||||
_extract_resolved_file_path(
|
||||
tool_name="edit_file",
|
||||
tool_output={"status": "completed", "result": "updated"},
|
||||
tool_input={"file_path": "/docs/edited.md"},
|
||||
)
|
||||
== "/docs/edited.md"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_resolved_file_path_does_not_parse_result_text():
|
||||
assert (
|
||||
_extract_resolved_file_path(
|
||||
tool_name="write_file",
|
||||
tool_output={"result": "Updated file /docs/from-text.md"},
|
||||
tool_input=None,
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_file_write_contract_outcome_reasons():
|
||||
result = StreamResult(intent_detected="file_write")
|
||||
passed, reason = _evaluate_file_contract_outcome(result)
|
||||
|
|
@ -159,6 +193,84 @@ def test_stream_exception_classifies_rate_limited():
|
|||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_openrouter_429_payload():
|
||||
exc = Exception(
|
||||
'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
|
||||
'"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
|
||||
)
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "rate_limited"
|
||||
assert code == "RATE_LIMITED"
|
||||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "temporarily rate-limited" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
|
||||
"""``_preflight_llm`` is best-effort.
|
||||
|
||||
- On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
|
||||
caller can drive the cooldown/repin branch.
|
||||
- On any other transient failure it MUST swallow the error so the normal
|
||||
stream path continues without surfacing preflight noise to the user.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.tasks.chat.stream_new_chat import _preflight_llm
|
||||
|
||||
class _RateLimitedError(Exception):
|
||||
"""Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
|
||||
|
||||
rate_calls: list[dict] = []
|
||||
other_calls: list[dict] = []
|
||||
|
||||
async def _fake_acompletion_429(**kwargs):
|
||||
rate_calls.append(kwargs)
|
||||
raise _RateLimitedError("simulated 429")
|
||||
|
||||
async def _fake_acompletion_other(**kwargs):
|
||||
other_calls.append(kwargs)
|
||||
raise RuntimeError("some unrelated transient failure")
|
||||
|
||||
fake_llm = SimpleNamespace(
|
||||
model="openrouter/google/gemma-4-31b-it:free",
|
||||
api_key="test",
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
import litellm # type: ignore[import-not-found]
|
||||
|
||||
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
|
||||
with pytest.raises(_RateLimitedError):
|
||||
await _preflight_llm(fake_llm)
|
||||
assert len(rate_calls) == 1
|
||||
assert rate_calls[0]["max_tokens"] == 1
|
||||
assert rate_calls[0]["stream"] is False
|
||||
|
||||
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
|
||||
# MUST NOT raise: non-rate-limit failures are swallowed.
|
||||
await _preflight_llm(fake_llm)
|
||||
assert len(other_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_skipped_for_auto_router_model():
|
||||
"""Router-mode ``model='auto'`` has no single deployment to ping; the
|
||||
LiteLLM router itself owns per-deployment rate-limit accounting, so the
|
||||
preflight helper must short-circuit instead of issuing a probe."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.tasks.chat.stream_new_chat import _preflight_llm
|
||||
|
||||
fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
|
||||
# Should return without raising or making any LiteLLM call.
|
||||
await _preflight_llm(fake_llm)
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy():
|
||||
exc = BusyError(request_id="thread-123")
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
"use client";
|
||||
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { CheckCircle2 } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
|
|
@ -18,14 +15,8 @@ import {
|
|||
|
||||
export default function PurchaseSuccessPage() {
|
||||
const params = useParams();
|
||||
const queryClient = useQueryClient();
|
||||
const searchSpaceId = String(params.search_space_id ?? "");
|
||||
|
||||
useEffect(() => {
|
||||
void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY });
|
||||
void queryClient.invalidateQueries({ queryKey: ["token-status"] });
|
||||
}, [queryClient]);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
|
||||
<Card className="w-full max-w-lg">
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe();
|
|||
export const currentUserAtom = atomWithQuery(() => {
|
||||
return {
|
||||
queryKey: USER_QUERY_KEY,
|
||||
staleTime: 5 * 60 * 1000,
|
||||
// Live-changing numeric fields (pages_*, premium_tokens_*) are now
|
||||
// pushed via Zero (queries.user.me()), so /users/me only needs to
|
||||
// fire once per session for the static profile fields.
|
||||
staleTime: Infinity,
|
||||
enabled: !!getBearerToken(),
|
||||
queryFn: userQueryFn,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import remarkMath from "remark-math";
|
|||
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
||||
import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image";
|
||||
import "katex/dist/katex.min.css";
|
||||
import { toast } from "sonner";
|
||||
import { processChildrenWithCitations } from "@/components/citations/citation-renderer";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import {
|
||||
|
|
@ -30,6 +31,7 @@ import {
|
|||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
|
|
@ -194,6 +196,85 @@ function isVirtualFilePathToken(value: string): boolean {
|
|||
return segments.length >= 2;
|
||||
}
|
||||
|
||||
function isStandaloneDocumentsPathText(node: ReactNode): string | null {
|
||||
if (typeof node !== "string") return null;
|
||||
const value = node.trim();
|
||||
if (!value.startsWith("/documents/")) return null;
|
||||
if (value.includes(" ")) return null;
|
||||
const normalized = value.replace(/\/+$/, "");
|
||||
const leaf = normalized.split("/").filter(Boolean).at(-1) ?? "";
|
||||
if (!leaf || !leaf.includes(".")) return null;
|
||||
return value;
|
||||
}
|
||||
|
||||
function FilePathLink({ path, className }: { path: string; className?: string }) {
|
||||
const openEditorPanel = useSetAtom(openEditorPanelAtom);
|
||||
const params = useParams();
|
||||
const electronAPI = useElectronAPI();
|
||||
const searchSpaceIdParam = params?.search_space_id;
|
||||
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
|
||||
? Number(searchSpaceIdParam[0])
|
||||
: Number(searchSpaceIdParam);
|
||||
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
|
||||
? parsedSearchSpaceId
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80",
|
||||
className
|
||||
)}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
void (async () => {
|
||||
if (electronAPI) {
|
||||
let resolvedLocalPath = path;
|
||||
if (electronAPI.getAgentFilesystemMounts) {
|
||||
try {
|
||||
const mounts = (await electronAPI.getAgentFilesystemMounts(
|
||||
resolvedSearchSpaceId
|
||||
)) as AgentFilesystemMount[];
|
||||
resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts);
|
||||
} catch {
|
||||
// Fall back to the raw path if mount lookup fails.
|
||||
}
|
||||
}
|
||||
openEditorPanel({
|
||||
kind: "local_file",
|
||||
localFilePath: resolvedLocalPath,
|
||||
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
|
||||
searchSpaceId: resolvedSearchSpaceId,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return;
|
||||
try {
|
||||
const doc = await documentsApiService.getDocumentByVirtualPath({
|
||||
search_space_id: resolvedSearchSpaceId,
|
||||
virtual_path: path,
|
||||
});
|
||||
openEditorPanel({
|
||||
kind: "document",
|
||||
documentId: doc.id,
|
||||
searchSpaceId: resolvedSearchSpaceId,
|
||||
title: doc.title,
|
||||
});
|
||||
} catch {
|
||||
toast.error("Document not found in knowledge base.");
|
||||
}
|
||||
})();
|
||||
}}
|
||||
title="Open in editor panel"
|
||||
>
|
||||
{path}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
|
||||
if (!src) return null;
|
||||
|
||||
|
|
@ -311,9 +392,14 @@ const defaultComponents = memoizeMarkdownComponents({
|
|||
},
|
||||
p: function P({ className, children, ...props }) {
|
||||
const urlMap = useCitationUrlMap();
|
||||
const standalonePath = isStandaloneDocumentsPathText(children);
|
||||
return (
|
||||
<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
|
||||
{processChildrenWithCitations(children, urlMap)}
|
||||
{standalonePath ? (
|
||||
<FilePathLink path={standalonePath} />
|
||||
) : (
|
||||
processChildrenWithCitations(children, urlMap)
|
||||
)}
|
||||
</p>
|
||||
);
|
||||
},
|
||||
|
|
@ -400,8 +486,6 @@ const defaultComponents = memoizeMarkdownComponents({
|
|||
code: function Code({ className, children, ...props }) {
|
||||
const isCodeBlock = useIsMarkdownCodeBlock();
|
||||
const { resolvedTheme } = useTheme();
|
||||
const openEditorPanel = useSetAtom(openEditorPanelAtom);
|
||||
const params = useParams();
|
||||
const electronAPI = useElectronAPI();
|
||||
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
|
||||
const codeString = String(children).replace(/\n$/, "");
|
||||
|
|
@ -418,53 +502,17 @@ const defaultComponents = memoizeMarkdownComponents({
|
|||
const isLikelyFolder =
|
||||
inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes(".");
|
||||
const isLocalPath =
|
||||
!!electronAPI &&
|
||||
isVirtualFilePathToken(inlineValue) &&
|
||||
!inlineValue.startsWith("//") &&
|
||||
!isLikelyFolder;
|
||||
const displayLocalPath = inlineValue.replace(/^\/+/, "");
|
||||
const searchSpaceIdParam = params?.search_space_id;
|
||||
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
|
||||
? Number(searchSpaceIdParam[0])
|
||||
: Number(searchSpaceIdParam);
|
||||
(isVirtualFilePathToken(inlineValue) &&
|
||||
!inlineValue.startsWith("//") &&
|
||||
!isLikelyFolder &&
|
||||
!!electronAPI) ||
|
||||
(isVirtualFilePathToken(inlineValue) &&
|
||||
!inlineValue.startsWith("//") &&
|
||||
!isLikelyFolder &&
|
||||
!electronAPI &&
|
||||
inlineValue.startsWith("/documents/"));
|
||||
if (isLocalPath) {
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80"
|
||||
)}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
void (async () => {
|
||||
let resolvedLocalPath = inlineValue;
|
||||
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
|
||||
? parsedSearchSpaceId
|
||||
: undefined;
|
||||
if (electronAPI?.getAgentFilesystemMounts) {
|
||||
try {
|
||||
const mounts = (await electronAPI.getAgentFilesystemMounts(
|
||||
resolvedSearchSpaceId
|
||||
)) as AgentFilesystemMount[];
|
||||
resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts);
|
||||
} catch {
|
||||
// Fall back to the raw inline path if mount lookup fails.
|
||||
}
|
||||
}
|
||||
openEditorPanel({
|
||||
kind: "local_file",
|
||||
localFilePath: resolvedLocalPath,
|
||||
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
|
||||
searchSpaceId: resolvedSearchSpaceId,
|
||||
});
|
||||
})();
|
||||
}}
|
||||
title="Open in editor panel"
|
||||
>
|
||||
{displayLocalPath}
|
||||
</button>
|
||||
);
|
||||
return <FilePathLink path={inlineValue} className="text-[0.9em]" />;
|
||||
}
|
||||
return (
|
||||
<code
|
||||
|
|
|
|||
|
|
@ -681,14 +681,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
|
|||
}
|
||||
}, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]);
|
||||
|
||||
// Page usage
|
||||
const pageUsage = user
|
||||
? {
|
||||
pagesUsed: user.pages_used,
|
||||
pagesLimit: user.pages_limit,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
// Detect if we're on the chat page (needs overflow-hidden for chat's own scroll)
|
||||
const isChatPage = pathname?.includes("/new-chat") ?? false;
|
||||
|
||||
|
|
@ -723,7 +715,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
|
|||
onManageMembers={handleManageMembers}
|
||||
onUserSettings={handleUserSettings}
|
||||
onLogout={handleLogout}
|
||||
pageUsage={pageUsage}
|
||||
theme={theme}
|
||||
setTheme={setTheme}
|
||||
isChatPage={isChatPage}
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ function MainContentPanel({
|
|||
const isDocumentTab = activeTab?.type === "document";
|
||||
|
||||
return (
|
||||
<div className="relative flex flex-1 flex-col min-w-0">
|
||||
<div className="relative isolate flex flex-1 flex-col min-w-0">
|
||||
<TabBar
|
||||
onTabSwitch={onTabSwitch}
|
||||
onNewChat={onNewChat}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@rocicorp/zero/react";
|
||||
import { useIsAnonymous } from "@/contexts/anonymous-mode";
|
||||
import { queries } from "@/zero/queries";
|
||||
import { PageUsageDisplay } from "./PageUsageDisplay";
|
||||
|
||||
export function AuthenticatedPageUsageDisplay() {
|
||||
const isAnonymous = useIsAnonymous();
|
||||
const [me] = useQuery(queries.user.me({}));
|
||||
|
||||
if (isAnonymous || !me) return null;
|
||||
|
||||
return <PageUsageDisplay pagesUsed={me.pagesUsed} pagesLimit={me.pagesLimit} />;
|
||||
}
|
||||
|
|
@ -1,23 +1,18 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useQuery } from "@rocicorp/zero/react";
|
||||
import { Progress } from "@/components/ui/progress";
|
||||
import { useIsAnonymous } from "@/contexts/anonymous-mode";
|
||||
import { stripeApiService } from "@/lib/apis/stripe-api.service";
|
||||
import { queries } from "@/zero/queries";
|
||||
|
||||
export function PremiumTokenUsageDisplay() {
|
||||
const isAnonymous = useIsAnonymous();
|
||||
const { data: tokenStatus } = useQuery({
|
||||
queryKey: ["token-status"],
|
||||
queryFn: () => stripeApiService.getTokenStatus(),
|
||||
staleTime: 60_000,
|
||||
enabled: !isAnonymous,
|
||||
});
|
||||
const [me] = useQuery(queries.user.me({}));
|
||||
|
||||
if (!tokenStatus) return null;
|
||||
if (isAnonymous || !me) return null;
|
||||
|
||||
const usagePercentage = Math.min(
|
||||
(tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100,
|
||||
(me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100,
|
||||
100
|
||||
);
|
||||
|
||||
|
|
@ -31,8 +26,7 @@ export function PremiumTokenUsageDisplay() {
|
|||
<div className="space-y-1.5">
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-muted-foreground">
|
||||
{formatTokens(tokenStatus.premium_tokens_used)} /{" "}
|
||||
{formatTokens(tokenStatus.premium_tokens_limit)} tokens
|
||||
{formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens
|
||||
</span>
|
||||
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode";
|
|||
import { cn } from "@/lib/utils";
|
||||
import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize";
|
||||
import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types";
|
||||
import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay";
|
||||
import { ChatListItem } from "./ChatListItem";
|
||||
import { NavSection } from "./NavSection";
|
||||
import { PageUsageDisplay } from "./PageUsageDisplay";
|
||||
import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay";
|
||||
import { SidebarButton } from "./SidebarButton";
|
||||
import { SidebarCollapseButton } from "./SidebarCollapseButton";
|
||||
|
|
@ -338,9 +338,7 @@ function SidebarUsageFooter({
|
|||
return (
|
||||
<div className="px-3 py-3 border-t space-y-3">
|
||||
<PremiumTokenUsageDisplay />
|
||||
{pageUsage && (
|
||||
<PageUsageDisplay pagesUsed={pageUsage.pagesUsed} pagesLimit={pageUsage.pagesLimit} />
|
||||
)}
|
||||
<AuthenticatedPageUsageDisplay />
|
||||
<div className="space-y-0.5">
|
||||
<Link
|
||||
href={`/dashboard/${searchSpaceId}/more-pages`}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery as useZeroQuery } from "@rocicorp/zero/react";
|
||||
import { useMutation, useQuery } from "@tanstack/react-query";
|
||||
import { Minus, Plus } from "lucide-react";
|
||||
import { useParams } from "next/navigation";
|
||||
|
|
@ -11,6 +12,7 @@ import { Spinner } from "@/components/ui/spinner";
|
|||
import { stripeApiService } from "@/lib/apis/stripe-api.service";
|
||||
import { AppError } from "@/lib/error";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { queries } from "@/zero/queries";
|
||||
|
||||
const TOKEN_PACK_SIZE = 1_000_000;
|
||||
const PRICE_PER_PACK_USD = 1;
|
||||
|
|
@ -21,11 +23,15 @@ export function BuyTokensContent() {
|
|||
const searchSpaceId = Number(params?.search_space_id);
|
||||
const [quantity, setQuantity] = useState(1);
|
||||
|
||||
// Server config flag: stays on REST, not per-user.
|
||||
const { data: tokenStatus } = useQuery({
|
||||
queryKey: ["token-status"],
|
||||
queryFn: () => stripeApiService.getTokenStatus(),
|
||||
});
|
||||
|
||||
// Live per-user usage via Zero.
|
||||
const [me] = useZeroQuery(queries.user.me({}));
|
||||
|
||||
const purchaseMutation = useMutation({
|
||||
mutationFn: stripeApiService.createTokenCheckoutSession,
|
||||
onSuccess: (response) => {
|
||||
|
|
@ -54,12 +60,11 @@ export function BuyTokensContent() {
|
|||
);
|
||||
}
|
||||
|
||||
const usagePercentage = tokenStatus
|
||||
? Math.min(
|
||||
(tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100,
|
||||
100
|
||||
)
|
||||
: 0;
|
||||
const used = me?.premiumTokensUsed ?? 0;
|
||||
const limit = me?.premiumTokensLimit ?? 0;
|
||||
// Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)).
|
||||
const remaining = Math.max(0, limit - used);
|
||||
const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
|
||||
|
||||
return (
|
||||
<div className="w-full space-y-5">
|
||||
|
|
@ -68,18 +73,17 @@ export function BuyTokensContent() {
|
|||
<p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p>
|
||||
</div>
|
||||
|
||||
{tokenStatus && (
|
||||
{me && (
|
||||
<div className="rounded-lg border bg-muted/20 p-3 space-y-1.5">
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-muted-foreground">
|
||||
{tokenStatus.premium_tokens_used.toLocaleString()} /{" "}
|
||||
{tokenStatus.premium_tokens_limit.toLocaleString()} premium tokens
|
||||
{used.toLocaleString()} / {limit.toLocaleString()} premium tokens
|
||||
</span>
|
||||
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
|
||||
</div>
|
||||
<Progress value={usagePercentage} className="h-1.5" />
|
||||
<p className="text-[11px] text-muted-foreground">
|
||||
{tokenStatus.premium_tokens_remaining.toLocaleString()} tokens remaining
|
||||
{remaining.toLocaleString()} tokens remaining
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import {
|
|||
type DeleteDocumentRequest,
|
||||
deleteDocumentRequest,
|
||||
deleteDocumentResponse,
|
||||
documentTitleRead,
|
||||
type GetDocumentByChunkRequest,
|
||||
type GetDocumentChunksRequest,
|
||||
type GetDocumentRequest,
|
||||
|
|
@ -269,6 +270,17 @@ class DocumentsApiService {
|
|||
);
|
||||
};
|
||||
|
||||
getDocumentByVirtualPath = async (request: { search_space_id: number; virtual_path: string }) => {
|
||||
const params = new URLSearchParams({
|
||||
search_space_id: String(request.search_space_id),
|
||||
virtual_path: request.virtual_path,
|
||||
});
|
||||
return baseApiService.get(
|
||||
`/api/v1/documents/by-virtual-path?${params.toString()}`,
|
||||
documentTitleRead
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get document type counts
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat";
|
|||
import { connectorQueries, documentQueries } from "./documents";
|
||||
import { folderQueries } from "./folders";
|
||||
import { notificationQueries } from "./inbox";
|
||||
import { userQueries } from "./user";
|
||||
|
||||
export const queries = defineQueries({
|
||||
notifications: notificationQueries,
|
||||
|
|
@ -12,4 +13,5 @@ export const queries = defineQueries({
|
|||
messages: messageQueries,
|
||||
comments: commentQueries,
|
||||
chatSession: chatSessionQueries,
|
||||
user: userQueries,
|
||||
});
|
||||
|
|
|
|||
11
surfsense_web/zero/queries/user.ts
Normal file
11
surfsense_web/zero/queries/user.ts
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import { defineQuery } from "@rocicorp/zero";
|
||||
import { z } from "zod";
|
||||
import { zql } from "../schema/index";
|
||||
|
||||
export const userQueries = {
|
||||
me: defineQuery(z.object({}), ({ ctx }) => {
|
||||
const userId = ctx?.userId;
|
||||
if (!userId) return zql.user.where("id", "__none__").one();
|
||||
return zql.user.where("id", userId).one();
|
||||
}),
|
||||
};
|
||||
|
|
@ -3,6 +3,7 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./
|
|||
import { documentTable, searchSourceConnectorTable } from "./documents";
|
||||
import { folderTable } from "./folders";
|
||||
import { notificationTable } from "./inbox";
|
||||
import { userTable } from "./user";
|
||||
|
||||
const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({
|
||||
message: one({
|
||||
|
|
@ -34,6 +35,7 @@ export const schema = createSchema({
|
|||
newChatMessageTable,
|
||||
chatCommentTable,
|
||||
chatSessionStateTable,
|
||||
userTable,
|
||||
],
|
||||
relationships: [chatCommentRelationships, newChatMessageRelationships],
|
||||
});
|
||||
|
|
|
|||
11
surfsense_web/zero/schema/user.ts
Normal file
11
surfsense_web/zero/schema/user.ts
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import { number, string, table } from "@rocicorp/zero";
|
||||
|
||||
export const userTable = table("user")
|
||||
.columns({
|
||||
id: string(),
|
||||
pagesLimit: number().from("pages_limit"),
|
||||
pagesUsed: number().from("pages_used"),
|
||||
premiumTokensLimit: number().from("premium_tokens_limit"),
|
||||
premiumTokensUsed: number().from("premium_tokens_used"),
|
||||
})
|
||||
.primaryKey("id");
|
||||
Loading…
Add table
Add a link
Reference in a new issue