mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 05:12:38 +02:00
feat: unified credits and its cost calculations
This commit is contained in:
parent
451a98936e
commit
ae9d36d77f
61 changed files with 5835 additions and 272 deletions
|
|
@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
|
||||||
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||||
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||||
|
|
||||||
# Premium token purchases ($1 per 1M tokens for premium-tier models)
|
# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
|
||||||
|
# credit; premium turns debit the actual per-call provider cost
|
||||||
|
# reported by LiteLLM, so cheap and expensive models bill proportionally)
|
||||||
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||||
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||||
# STRIPE_TOKENS_PER_UNIT=1000000
|
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||||
|
# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||||
|
|
@ -315,9 +318,24 @@ STT_SERVICE=local/base
|
||||||
# Pages limit per user for ETL (default: unlimited)
|
# Pages limit per user for ETL (default: unlimited)
|
||||||
# PAGES_LIMIT=500
|
# PAGES_LIMIT=500
|
||||||
|
|
||||||
# Premium token quota per registered user (default: 5M)
|
# Premium credit quota per registered user, in micro-USD (default: $5).
|
||||||
# Only applies to models with billing_tier=premium in global_llm_config.yaml
|
# Premium turns are debited at the actual per-call provider cost reported
|
||||||
# PREMIUM_TOKEN_LIMIT=5000000
|
# by LiteLLM. Only applies to models with billing_tier=premium.
|
||||||
|
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||||
|
# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
|
||||||
|
|
||||||
|
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
|
||||||
|
# QUOTA_MAX_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
|
# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
|
||||||
|
# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||||
|
|
||||||
|
# Per-podcast reservation for the podcast Celery task ($0.20 default).
|
||||||
|
# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||||
|
|
||||||
|
# Per-video-presentation reservation for the video Celery task ($1.00 default).
|
||||||
|
# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||||
|
# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
# No-login (anonymous) mode — public users can chat without an account
|
# No-login (anonymous) mode — public users can chat without an account
|
||||||
# Set TRUE to enable /free pages and anonymous chat API
|
# Set TRUE to enable /free pages and anonymous chat API
|
||||||
|
|
|
||||||
|
|
@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
|
||||||
# Set FALSE to disable new checkout session creation temporarily
|
# Set FALSE to disable new checkout session creation temporarily
|
||||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||||
|
|
||||||
# Premium token purchases via Stripe (for premium-tier model usage)
|
# Premium credit purchases via Stripe (for premium-tier model usage).
|
||||||
# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
|
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||||
|
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
|
||||||
|
# per-call provider cost reported by LiteLLM.
|
||||||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||||
STRIPE_TOKENS_PER_UNIT=1000000
|
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||||
|
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
|
||||||
|
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||||
|
|
||||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||||
|
|
@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
||||||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||||
PAGES_LIMIT=500
|
PAGES_LIMIT=500
|
||||||
|
|
||||||
# Premium token quota per registered user (default: 3,000,000)
|
# Premium credit quota per registered user, in micro-USD
|
||||||
# Applies only to models with billing_tier=premium in global_llm_config.yaml
|
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
|
||||||
PREMIUM_TOKEN_LIMIT=3000000
|
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
|
||||||
|
# models bill proportionally. Applies only to models with
|
||||||
|
# billing_tier=premium in global_llm_config.yaml.
|
||||||
|
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||||
|
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
|
||||||
|
# PREMIUM_TOKEN_LIMIT=5000000
|
||||||
|
|
||||||
|
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||||
|
# stream_new_chat estimates an upper-bound cost from the model's
|
||||||
|
# litellm-published per-token rates × the config's quota_reserve_tokens
|
||||||
|
# and clamps to this value so a misconfigured model can't lock the
|
||||||
|
# user's whole balance on one call. Default $1.00.
|
||||||
|
QUOTA_MAX_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
|
# Per-image reservation (in micro-USD) for the POST /image-generations
|
||||||
|
# endpoint. Bypassed for free configs. Default $0.05.
|
||||||
|
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||||
|
|
||||||
|
# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
|
||||||
|
# Single envelope covers one transcript-generation LLM call. Default $0.20.
|
||||||
|
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||||
|
|
||||||
|
# Per-video-presentation reservation (in micro-USD) used by the video
|
||||||
|
# presentation Celery task. Covers worst-case fan-out of N slide-scene
|
||||||
|
# generations + refines. Default $1.00. NOTE: tasks using the override
|
||||||
|
# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||||
|
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
# No-login (anonymous) mode — allows public users to chat without an account
|
# No-login (anonymous) mode — allows public users to chat without an account
|
||||||
# Set TRUE to enable /free pages and anonymous chat API
|
# Set TRUE to enable /free pages and anonymous chat API
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,291 @@
|
||||||
|
"""rename premium token columns to credit micros and add cost_micros to token_usage
|
||||||
|
|
||||||
|
Migrates the premium quota system from a flat token counter to a USD-cost
|
||||||
|
based credit system, where 1 credit = 1 micro-USD ($0.000001).
|
||||||
|
|
||||||
|
Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
|
||||||
|
price means every existing value is already correct in the new unit, no
|
||||||
|
data transformation needed):
|
||||||
|
|
||||||
|
user.premium_tokens_limit -> premium_credit_micros_limit
|
||||||
|
user.premium_tokens_used -> premium_credit_micros_used
|
||||||
|
user.premium_tokens_reserved -> premium_credit_micros_reserved
|
||||||
|
|
||||||
|
premium_token_purchases.tokens_granted -> credit_micros_granted
|
||||||
|
|
||||||
|
New column for cost auditing per turn:
|
||||||
|
|
||||||
|
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
|
||||||
|
|
||||||
|
The "user" table is in zero_publication's column list (added in 139), so
|
||||||
|
this migration must drop and recreate the publication with the renamed
|
||||||
|
column names, otherwise zero-cache will replicate stale column names and
|
||||||
|
the FE Zero schema will fail to bind.
|
||||||
|
|
||||||
|
IMPORTANT - before AND after running this migration:
|
||||||
|
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||||
|
2. Run: alembic upgrade head
|
||||||
|
3. Delete / reset the zero-cache data volume
|
||||||
|
4. Restart zero-cache (it will do a fresh initial sync)
|
||||||
|
|
||||||
|
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
|
||||||
|
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
|
||||||
|
column-not-found errors from a stale catalog snapshot.
|
||||||
|
|
||||||
|
Revision ID: 140
|
||||||
|
Revises: 139
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "140"
|
||||||
|
down_revision: str | None = "139"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
PUBLICATION_NAME = "zero_publication"
|
||||||
|
|
||||||
|
# Replicates 139's document column list verbatim — must stay in sync.
|
||||||
|
DOCUMENT_COLS = [
|
||||||
|
"id",
|
||||||
|
"title",
|
||||||
|
"document_type",
|
||||||
|
"search_space_id",
|
||||||
|
"folder_id",
|
||||||
|
"created_by_id",
|
||||||
|
"status",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Same five live-meter fields as 139, with the renamed column names.
|
||||||
|
USER_COLS = [
|
||||||
|
"id",
|
||||||
|
"pages_limit",
|
||||||
|
"pages_used",
|
||||||
|
"premium_credit_micros_limit",
|
||||||
|
"premium_credit_micros_used",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate_blocked_pids(conn, table: str) -> None:
|
||||||
|
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT pg_terminate_backend(l.pid) "
|
||||||
|
"FROM pg_locks l "
|
||||||
|
"JOIN pg_class c ON c.oid = l.relation "
|
||||||
|
"WHERE c.relname = :tbl "
|
||||||
|
" AND l.pid != pg_backend_pid()"
|
||||||
|
),
|
||||||
|
{"tbl": table},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_zero_version(conn, table: str) -> bool:
|
||||||
|
return (
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT 1 FROM information_schema.columns "
|
||||||
|
"WHERE table_name = :tbl AND column_name = '_0_version'"
|
||||||
|
),
|
||||||
|
{"tbl": table},
|
||||||
|
).fetchone()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _column_exists(conn, table: str, column: str) -> bool:
|
||||||
|
return (
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT 1 FROM information_schema.columns "
|
||||||
|
"WHERE table_name = :tbl AND column_name = :col"
|
||||||
|
),
|
||||||
|
{"tbl": table, "col": column},
|
||||||
|
).fetchone()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publication_ddl(
|
||||||
|
user_cols: list[str],
|
||||||
|
*,
|
||||||
|
documents_has_zero_ver: bool,
|
||||||
|
user_has_zero_ver: bool,
|
||||||
|
) -> str:
|
||||||
|
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
|
||||||
|
user_col_list_with_meta = user_cols + (
|
||||||
|
['"_0_version"'] if user_has_zero_ver else []
|
||||||
|
)
|
||||||
|
doc_col_list = ", ".join(doc_cols)
|
||||||
|
user_col_list = ", ".join(user_col_list_with_meta)
|
||||||
|
return (
|
||||||
|
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||||
|
f"notifications, "
|
||||||
|
f"documents ({doc_col_list}), "
|
||||||
|
f"folders, "
|
||||||
|
f"search_source_connectors, "
|
||||||
|
f"new_chat_messages, "
|
||||||
|
f"chat_comments, "
|
||||||
|
f"chat_session_state, "
|
||||||
|
f'"user" ({user_col_list})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
|
||||||
|
# dev environments are safe.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
if not _column_exists(conn, "token_usage", "cost_micros"):
|
||||||
|
op.add_column(
|
||||||
|
"token_usage",
|
||||||
|
sa.Column(
|
||||||
|
"cost_micros",
|
||||||
|
sa.BigInteger(),
|
||||||
|
nullable=False,
|
||||||
|
server_default="0",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
if _column_exists(
|
||||||
|
conn, "premium_token_purchases", "tokens_granted"
|
||||||
|
) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
|
||||||
|
op.alter_column(
|
||||||
|
"premium_token_purchases",
|
||||||
|
"tokens_granted",
|
||||||
|
new_column_name="credit_micros_granted",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
|
||||||
|
#
|
||||||
|
# We must drop the publication first (it references the old column
|
||||||
|
# names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
|
||||||
|
# in a transaction block; alembic's outer transaction already holds
|
||||||
|
# one, but a SAVEPOINT keeps the LOCK + DDL atomic.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||||
|
with tx:
|
||||||
|
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||||
|
|
||||||
|
_terminate_blocked_pids(conn, "user")
|
||||||
|
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
|
||||||
|
|
||||||
|
# Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
|
||||||
|
# publications require at least the PK to be in the column list,
|
||||||
|
# which is true for both the old and new shape.
|
||||||
|
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
|
||||||
|
|
||||||
|
# Drop the publication BEFORE renaming columns, otherwise Postgres
|
||||||
|
# rejects the rename: "cannot drop column ... referenced by
|
||||||
|
# publication".
|
||||||
|
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||||
|
|
||||||
|
for old, new in (
|
||||||
|
("premium_tokens_limit", "premium_credit_micros_limit"),
|
||||||
|
("premium_tokens_used", "premium_credit_micros_used"),
|
||||||
|
("premium_tokens_reserved", "premium_credit_micros_reserved"),
|
||||||
|
):
|
||||||
|
if _column_exists(conn, "user", old) and not _column_exists(
|
||||||
|
conn, "user", new
|
||||||
|
):
|
||||||
|
op.alter_column("user", old, new_column_name=new)
|
||||||
|
|
||||||
|
# Update the server_default on the renamed limit column so newly
|
||||||
|
# inserted users get $5 of credit (== 5_000_000 micros) by
|
||||||
|
# default. Existing rows are unaffected.
|
||||||
|
op.alter_column(
|
||||||
|
"user",
|
||||||
|
"premium_credit_micros_limit",
|
||||||
|
server_default="5000000",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recreate the publication with the new column names.
|
||||||
|
documents_has_zero_ver = _has_zero_version(conn, "documents")
|
||||||
|
user_has_zero_ver = _has_zero_version(conn, "user")
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
_build_publication_ddl(
|
||||||
|
USER_COLS,
|
||||||
|
documents_has_zero_ver=documents_has_zero_ver,
|
||||||
|
user_has_zero_ver=user_has_zero_ver,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Revert the rename and drop ``cost_micros``.
|
||||||
|
|
||||||
|
Mirrors ``upgrade``: drop the publication, rename columns back, drop
|
||||||
|
the new column, recreate the publication with the old column list.
|
||||||
|
Same zero-cache stop/reset runbook applies in reverse.
|
||||||
|
"""
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||||
|
with tx:
|
||||||
|
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||||
|
|
||||||
|
_terminate_blocked_pids(conn, "user")
|
||||||
|
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
|
||||||
|
|
||||||
|
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||||
|
|
||||||
|
for new, old in (
|
||||||
|
("premium_credit_micros_limit", "premium_tokens_limit"),
|
||||||
|
("premium_credit_micros_used", "premium_tokens_used"),
|
||||||
|
("premium_credit_micros_reserved", "premium_tokens_reserved"),
|
||||||
|
):
|
||||||
|
if _column_exists(conn, "user", new) and not _column_exists(
|
||||||
|
conn, "user", old
|
||||||
|
):
|
||||||
|
op.alter_column("user", new, new_column_name=old)
|
||||||
|
|
||||||
|
op.alter_column(
|
||||||
|
"user",
|
||||||
|
"premium_tokens_limit",
|
||||||
|
server_default="5000000",
|
||||||
|
)
|
||||||
|
|
||||||
|
legacy_user_cols = [
|
||||||
|
"id",
|
||||||
|
"pages_limit",
|
||||||
|
"pages_used",
|
||||||
|
"premium_tokens_limit",
|
||||||
|
"premium_tokens_used",
|
||||||
|
]
|
||||||
|
documents_has_zero_ver = _has_zero_version(conn, "documents")
|
||||||
|
user_has_zero_ver = _has_zero_version(conn, "user")
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
_build_publication_ddl(
|
||||||
|
legacy_user_cols,
|
||||||
|
documents_has_zero_ver=documents_has_zero_ver,
|
||||||
|
user_has_zero_ver=user_has_zero_ver,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if _column_exists(
|
||||||
|
conn, "premium_token_purchases", "credit_micros_granted"
|
||||||
|
) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
|
||||||
|
op.alter_column(
|
||||||
|
"premium_token_purchases",
|
||||||
|
"credit_micros_granted",
|
||||||
|
new_column_name="tokens_granted",
|
||||||
|
)
|
||||||
|
|
||||||
|
if _column_exists(conn, "token_usage", "cost_micros"):
|
||||||
|
op.drop_column("token_usage", "cost_micros")
|
||||||
|
|
@ -31,6 +31,7 @@ from app.config import (
|
||||||
initialize_image_gen_router,
|
initialize_image_gen_router,
|
||||||
initialize_llm_router,
|
initialize_llm_router,
|
||||||
initialize_openrouter_integration,
|
initialize_openrouter_integration,
|
||||||
|
initialize_pricing_registration,
|
||||||
initialize_vision_llm_router,
|
initialize_vision_llm_router,
|
||||||
)
|
)
|
||||||
from app.db import User, create_db_and_tables, get_async_session
|
from app.db import User, create_db_and_tables, get_async_session
|
||||||
|
|
@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
|
||||||
await setup_checkpointer_tables()
|
await setup_checkpointer_tables()
|
||||||
initialize_openrouter_integration()
|
initialize_openrouter_integration()
|
||||||
_start_openrouter_background_refresh()
|
_start_openrouter_background_refresh()
|
||||||
|
initialize_pricing_registration()
|
||||||
initialize_llm_router()
|
initialize_llm_router()
|
||||||
initialize_image_gen_router()
|
initialize_image_gen_router()
|
||||||
initialize_vision_llm_router()
|
initialize_vision_llm_router()
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,12 @@ def init_worker(**kwargs):
|
||||||
initialize_image_gen_router,
|
initialize_image_gen_router,
|
||||||
initialize_llm_router,
|
initialize_llm_router,
|
||||||
initialize_openrouter_integration,
|
initialize_openrouter_integration,
|
||||||
|
initialize_pricing_registration,
|
||||||
initialize_vision_llm_router,
|
initialize_vision_llm_router,
|
||||||
)
|
)
|
||||||
|
|
||||||
initialize_openrouter_integration()
|
initialize_openrouter_integration()
|
||||||
|
initialize_pricing_registration()
|
||||||
initialize_llm_router()
|
initialize_llm_router()
|
||||||
initialize_image_gen_router()
|
initialize_image_gen_router()
|
||||||
initialize_vision_llm_router()
|
initialize_vision_llm_router()
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,11 @@ def load_global_image_gen_configs():
|
||||||
try:
|
try:
|
||||||
with open(global_config_file, encoding="utf-8") as f:
|
with open(global_config_file, encoding="utf-8") as f:
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
return data.get("global_image_generation_configs", [])
|
configs = data.get("global_image_generation_configs", []) or []
|
||||||
|
for cfg in configs:
|
||||||
|
if isinstance(cfg, dict):
|
||||||
|
cfg.setdefault("billing_tier", "free")
|
||||||
|
return configs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load global image generation configs: {e}")
|
print(f"Warning: Failed to load global image generation configs: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
|
||||||
try:
|
try:
|
||||||
with open(global_config_file, encoding="utf-8") as f:
|
with open(global_config_file, encoding="utf-8") as f:
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
return data.get("global_vision_llm_configs", [])
|
configs = data.get("global_vision_llm_configs", []) or []
|
||||||
|
for cfg in configs:
|
||||||
|
if isinstance(cfg, dict):
|
||||||
|
cfg.setdefault("billing_tier", "free")
|
||||||
|
return configs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
|
||||||
"anonymous_enabled_free", settings["anonymous_enabled"]
|
"anonymous_enabled_free", settings["anonymous_enabled"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Image generation + vision LLM emission are opt-in (issue L).
|
||||||
|
# OpenRouter's catalogue contains hundreds of image / vision
|
||||||
|
# capable models; auto-injecting all of them into every
|
||||||
|
# deployment would explode the model selector and surprise
|
||||||
|
# operators upgrading from prior versions. Default to False so
|
||||||
|
# admins must explicitly turn them on.
|
||||||
|
settings.setdefault("image_generation_enabled", False)
|
||||||
|
settings.setdefault("vision_enabled", False)
|
||||||
|
|
||||||
return settings
|
return settings
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||||
|
|
@ -296,10 +313,60 @@ def initialize_openrouter_integration():
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("Info: OpenRouter integration enabled but no models fetched")
|
print("Info: OpenRouter integration enabled but no models fetched")
|
||||||
|
|
||||||
|
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||||
|
# Both reuse the catalogue already cached by ``service.initialize``
|
||||||
|
# so we don't make additional network calls here.
|
||||||
|
if settings.get("image_generation_enabled"):
|
||||||
|
try:
|
||||||
|
image_configs = service.get_image_generation_configs()
|
||||||
|
if image_configs:
|
||||||
|
config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
|
||||||
|
print(
|
||||||
|
f"Info: OpenRouter integration added {len(image_configs)} "
|
||||||
|
f"image-generation models"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||||
|
|
||||||
|
if settings.get("vision_enabled"):
|
||||||
|
try:
|
||||||
|
vision_configs = service.get_vision_llm_configs()
|
||||||
|
if vision_configs:
|
||||||
|
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||||
|
print(
|
||||||
|
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||||
|
f"vision LLM models"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_pricing_registration():
|
||||||
|
"""
|
||||||
|
Teach LiteLLM the per-token cost of every deployment in
|
||||||
|
``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
|
||||||
|
from the OpenRouter catalogue + any operator-declared YAML pricing).
|
||||||
|
|
||||||
|
Must run AFTER ``initialize_openrouter_integration()`` so the
|
||||||
|
OpenRouter catalogue is populated and BEFORE the first LLM call so
|
||||||
|
``response_cost`` is available in ``TokenTrackingCallback``.
|
||||||
|
|
||||||
|
Failures are logged but never raised — startup must not be blocked
|
||||||
|
by a missing pricing entry; the worst-case is the model debits 0.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.services.pricing_registration import (
|
||||||
|
register_pricing_from_global_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to register LiteLLM pricing: {e}")
|
||||||
|
|
||||||
|
|
||||||
def initialize_llm_router():
|
def initialize_llm_router():
|
||||||
"""
|
"""
|
||||||
Initialize the LLM Router service for Auto mode.
|
Initialize the LLM Router service for Auto mode.
|
||||||
|
|
@ -444,14 +511,54 @@ class Config:
|
||||||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Premium token quota settings
|
# Premium credit (micro-USD) quota settings.
|
||||||
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
|
#
|
||||||
|
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
|
||||||
|
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
|
||||||
|
# still honoured for one release as fall-back values — the prior
|
||||||
|
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
|
||||||
|
# to micros, so operators upgrading without changing their .env still
|
||||||
|
# get correct behaviour. A startup deprecation warning fires below if
|
||||||
|
# they're set.
|
||||||
|
PREMIUM_CREDIT_MICROS_LIMIT = int(
|
||||||
|
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||||
|
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
|
||||||
|
)
|
||||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||||
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
|
STRIPE_CREDIT_MICROS_PER_UNIT = int(
|
||||||
|
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
|
||||||
|
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
|
||||||
|
)
|
||||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||||
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
|
||||||
|
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
|
||||||
|
# config's ``quota_reserve_tokens`` and clamps the result to this value
|
||||||
|
# so a misconfigured "$1000/M" model can't lock the user's whole balance
|
||||||
|
# on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
|
||||||
|
# reserve_tokens ≈ $0.36) with headroom.
|
||||||
|
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
|
||||||
|
|
||||||
|
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
|
||||||
|
"PREMIUM_CREDIT_MICROS_LIMIT"
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
|
||||||
|
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
|
||||||
|
"current Stripe price). The old key will be removed in a "
|
||||||
|
"future release."
|
||||||
|
)
|
||||||
|
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
|
||||||
|
"STRIPE_CREDIT_MICROS_PER_UNIT"
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
"Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
|
||||||
|
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
|
||||||
|
"The old key will be removed in a future release."
|
||||||
|
)
|
||||||
|
|
||||||
# Anonymous / no-login mode settings
|
# Anonymous / no-login mode settings
|
||||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||||
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
|
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
|
||||||
|
|
@ -464,6 +571,35 @@ class Config:
|
||||||
# Default quota reserve tokens when not specified per-model
|
# Default quota reserve tokens when not specified per-model
|
||||||
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
|
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
|
||||||
|
|
||||||
|
# Per-image reservation (in micro-USD) used by ``billable_call`` for the
|
||||||
|
# ``POST /image-generations`` endpoint when the global config does not
|
||||||
|
# override it. $0.05 covers realistic worst-cases for current OpenAI /
|
||||||
|
# OpenRouter image-gen pricing. Bypassed entirely for free configs.
|
||||||
|
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
|
||||||
|
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-podcast reservation (in micro-USD). One agent LLM call generating
|
||||||
|
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
|
||||||
|
# premium-model run. Tune via env.
|
||||||
|
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
|
||||||
|
os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-video-presentation reservation (in micro-USD). Fan-out of N
|
||||||
|
# slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
|
||||||
|
# plus refine retries; can produce many premium completions. $1.00
|
||||||
|
# covers worst-case. Tune via env.
|
||||||
|
#
|
||||||
|
# NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
|
||||||
|
# 1_000_000. The override path in ``billable_call`` bypasses the
|
||||||
|
# per-call clamp in ``estimate_call_reserve_micros``, so this is the
|
||||||
|
# *actual* hold — raising it via env is fine but means a single video
|
||||||
|
# task can lock $1+ of credit.
|
||||||
|
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
|
||||||
|
os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
|
||||||
|
)
|
||||||
|
|
||||||
# Abuse prevention: concurrent stream cap and CAPTCHA
|
# Abuse prevention: concurrent stream cap and CAPTCHA
|
||||||
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
|
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
|
||||||
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
|
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,24 @@
|
||||||
# Structure matches NewLLMConfig:
|
# Structure matches NewLLMConfig:
|
||||||
# - Model configuration (provider, model_name, api_key, etc.)
|
# - Model configuration (provider, model_name, api_key, etc.)
|
||||||
# - Prompt configuration (system_instructions, citations_enabled)
|
# - Prompt configuration (system_instructions, citations_enabled)
|
||||||
|
#
|
||||||
|
# COST-BASED PREMIUM CREDITS:
|
||||||
|
# Each premium config bills the user's USD-credit balance based on the
|
||||||
|
# actual provider cost reported by LiteLLM. For models LiteLLM already
|
||||||
|
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
|
||||||
|
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
|
||||||
|
# or any model LiteLLM doesn't have in its built-in pricing table, declare
|
||||||
|
# per-token costs inline so they bill correctly:
|
||||||
|
#
|
||||||
|
# litellm_params:
|
||||||
|
# base_model: "my-custom-azure-deploy"
|
||||||
|
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
|
||||||
|
# input_cost_per_token: 0.000003
|
||||||
|
# output_cost_per_token: 0.000015
|
||||||
|
#
|
||||||
|
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
|
||||||
|
# API — no inline declaration needed. Models without resolvable pricing
|
||||||
|
# debit $0 from the user's balance and log a WARNING.
|
||||||
|
|
||||||
# Router Settings for Auto Mode
|
# Router Settings for Auto Mode
|
||||||
# These settings control how the LiteLLM Router distributes requests across models
|
# These settings control how the LiteLLM Router distributes requests across models
|
||||||
|
|
@ -292,6 +310,17 @@ openrouter_integration:
|
||||||
free_rpm: 20
|
free_rpm: 20
|
||||||
free_tpm: 100000
|
free_tpm: 100000
|
||||||
|
|
||||||
|
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
|
||||||
|
# contains hundreds of image- and vision-capable models; turning these on
|
||||||
|
# injects them into the global Image-Generation / Vision-LLM model
|
||||||
|
# selectors alongside any static configs. Tier (free/premium) is derived
|
||||||
|
# per model the same way it is for chat (`:free` suffix or zero pricing).
|
||||||
|
# When a user picks a premium image/vision model the call debits the
|
||||||
|
# shared $5 USD-cost-based premium credit pool — so leaving these off
|
||||||
|
# avoids surprise quota burn on existing deployments. Default: false.
|
||||||
|
image_generation_enabled: false
|
||||||
|
vision_enabled: false
|
||||||
|
|
||||||
litellm_params:
|
litellm_params:
|
||||||
max_tokens: 16384
|
max_tokens: 16384
|
||||||
system_instructions: ""
|
system_instructions: ""
|
||||||
|
|
|
||||||
|
|
@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
|
||||||
prompt_tokens = Column(Integer, nullable=False, default=0)
|
prompt_tokens = Column(Integer, nullable=False, default=0)
|
||||||
completion_tokens = Column(Integer, nullable=False, default=0)
|
completion_tokens = Column(Integer, nullable=False, default=0)
|
||||||
total_tokens = Column(Integer, nullable=False, default=0)
|
total_tokens = Column(Integer, nullable=False, default=0)
|
||||||
|
cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
|
||||||
model_breakdown = Column(JSONB, nullable=True)
|
model_breakdown = Column(JSONB, nullable=True)
|
||||||
call_details = Column(JSONB, nullable=True)
|
call_details = Column(JSONB, nullable=True)
|
||||||
|
|
||||||
|
|
@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
|
||||||
|
|
||||||
|
|
||||||
class PremiumTokenPurchase(Base, TimestampMixin):
|
class PremiumTokenPurchase(Base, TimestampMixin):
|
||||||
"""Tracks Stripe checkout sessions used to grant additional premium token credits."""
|
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
|
||||||
|
|
||||||
|
Note: the table name is preserved (``premium_token_purchases``) for
|
||||||
|
operational continuity even though the unit is now USD micro-credits
|
||||||
|
instead of raw tokens. The ``credit_micros_granted`` column replaced
|
||||||
|
the legacy ``tokens_granted`` in migration 140; the stored values
|
||||||
|
were not transformed because the prior $1 = 1M tokens Stripe price
|
||||||
|
makes the unit conversion 1:1 numerically.
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "premium_token_purchases"
|
__tablename__ = "premium_token_purchases"
|
||||||
__allow_unmapped__ = True
|
__allow_unmapped__ = True
|
||||||
|
|
@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
|
||||||
)
|
)
|
||||||
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
|
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
|
||||||
quantity = Column(Integer, nullable=False)
|
quantity = Column(Integer, nullable=False)
|
||||||
tokens_granted = Column(BigInteger, nullable=False)
|
credit_micros_granted = Column(BigInteger, nullable=False)
|
||||||
amount_total = Column(Integer, nullable=True)
|
amount_total = Column(Integer, nullable=True)
|
||||||
currency = Column(String(10), nullable=True)
|
currency = Column(String(10), nullable=True)
|
||||||
status = Column(
|
status = Column(
|
||||||
|
|
@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
)
|
)
|
||||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||||
|
|
||||||
premium_tokens_limit = Column(
|
premium_credit_micros_limit = Column(
|
||||||
BigInteger,
|
BigInteger,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=config.PREMIUM_TOKEN_LIMIT,
|
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||||
)
|
)
|
||||||
premium_tokens_used = Column(
|
premium_credit_micros_used = Column(
|
||||||
BigInteger, nullable=False, default=0, server_default="0"
|
BigInteger, nullable=False, default=0, server_default="0"
|
||||||
)
|
)
|
||||||
premium_tokens_reserved = Column(
|
premium_credit_micros_reserved = Column(
|
||||||
BigInteger, nullable=False, default=0, server_default="0"
|
BigInteger, nullable=False, default=0, server_default="0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2241,16 +2250,16 @@ else:
|
||||||
)
|
)
|
||||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||||
|
|
||||||
premium_tokens_limit = Column(
|
premium_credit_micros_limit = Column(
|
||||||
BigInteger,
|
BigInteger,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=config.PREMIUM_TOKEN_LIMIT,
|
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||||
)
|
)
|
||||||
premium_tokens_used = Column(
|
premium_credit_micros_used = Column(
|
||||||
BigInteger, nullable=False, default=0, server_default="0"
|
BigInteger, nullable=False, default=0, server_default="0"
|
||||||
)
|
)
|
||||||
premium_tokens_reserved = Column(
|
premium_credit_micros_reserved = Column(
|
||||||
BigInteger, nullable=False, default=0, server_default="0"
|
BigInteger, nullable=False, default=0, server_default="0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -68,12 +68,25 @@ class EtlPipelineService:
|
||||||
etl_service="VISION_LLM",
|
etl_service="VISION_LLM",
|
||||||
content_type="image",
|
content_type="image",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
logging.warning(
|
# Special-case quota exhaustion so we log a clearer message
|
||||||
"Vision LLM failed for %s, falling back to document parser",
|
# — the vision LLM didn't "fail", the user just ran out of
|
||||||
request.filename,
|
# premium credit. Falling through to the document parser
|
||||||
exc_info=True,
|
# is a graceful degradation: OCR/Unstructured still
|
||||||
)
|
# extracts text from the image without burning credit.
|
||||||
|
from app.services.billable_calls import QuotaInsufficientError
|
||||||
|
|
||||||
|
if isinstance(exc, QuotaInsufficientError):
|
||||||
|
logging.info(
|
||||||
|
"Vision LLM quota exhausted for %s; falling back to document parser",
|
||||||
|
request.filename,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
"Vision LLM failed for %s, falling back to document parser",
|
||||||
|
request.filename,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logging.info(
|
logging.info(
|
||||||
"No vision LLM provided, falling back to document parser for %s",
|
"No vision LLM provided, falling back to document parser for %s",
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,11 @@ from app.schemas import (
|
||||||
ImageGenerationListRead,
|
ImageGenerationListRead,
|
||||||
ImageGenerationRead,
|
ImageGenerationRead,
|
||||||
)
|
)
|
||||||
|
from app.services.billable_calls import (
|
||||||
|
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||||
|
QuotaInsufficientError,
|
||||||
|
billable_call,
|
||||||
|
)
|
||||||
from app.services.image_gen_router_service import (
|
from app.services.image_gen_router_service import (
|
||||||
IMAGE_GEN_AUTO_MODE_ID,
|
IMAGE_GEN_AUTO_MODE_ID,
|
||||||
ImageGenRouterService,
|
ImageGenRouterService,
|
||||||
|
|
@ -92,6 +97,50 @@ def _build_model_string(
|
||||||
return f"{prefix}/{model_name}"
|
return f"{prefix}/{model_name}"
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_billing_for_image_gen(
|
||||||
|
session: AsyncSession,
|
||||||
|
config_id: int | None,
|
||||||
|
search_space: SearchSpace,
|
||||||
|
) -> tuple[str, str, int]:
|
||||||
|
"""Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
|
||||||
|
|
||||||
|
The resolution mirrors ``_execute_image_generation``'s lookup tree but
|
||||||
|
only extracts the fields needed for billing — we do this *before*
|
||||||
|
``billable_call`` so the reservation is correctly sized for the
|
||||||
|
config that will actually run, and so we don't open an
|
||||||
|
``ImageGeneration`` row for a request that's about to 402.
|
||||||
|
|
||||||
|
User-owned (positive ID) BYOK configs are always free — they cost
|
||||||
|
the user nothing on our side. Auto mode currently treats as free
|
||||||
|
because the underlying router can dispatch to either premium or
|
||||||
|
free YAML configs and we don't surface the resolved deployment up
|
||||||
|
here yet. Bringing Auto under premium billing would require
|
||||||
|
threading the chosen deployment back from ``ImageGenRouterService``.
|
||||||
|
"""
|
||||||
|
resolved_id = config_id
|
||||||
|
if resolved_id is None:
|
||||||
|
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||||
|
|
||||||
|
if is_image_gen_auto_mode(resolved_id):
|
||||||
|
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||||
|
|
||||||
|
if resolved_id < 0:
|
||||||
|
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||||
|
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||||
|
base_model = _build_model_string(
|
||||||
|
cfg.get("provider", ""),
|
||||||
|
cfg.get("model_name", ""),
|
||||||
|
cfg.get("custom_provider"),
|
||||||
|
)
|
||||||
|
reserve_micros = int(
|
||||||
|
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
)
|
||||||
|
return (billing_tier, base_model, reserve_micros)
|
||||||
|
|
||||||
|
# Positive ID = user-owned BYOK image-gen config — always free.
|
||||||
|
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||||
|
|
||||||
|
|
||||||
async def _execute_image_generation(
|
async def _execute_image_generation(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
image_gen: ImageGeneration,
|
image_gen: ImageGeneration,
|
||||||
|
|
@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
|
||||||
"litellm_params": {},
|
"litellm_params": {},
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
"is_auto_mode": True,
|
"is_auto_mode": True,
|
||||||
|
# Auto mode currently treated as free until per-deployment
|
||||||
|
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||||
|
"billing_tier": "free",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
|
||||||
"api_version": cfg.get("api_version") or None,
|
"api_version": cfg.get("api_version") or None,
|
||||||
"litellm_params": cfg.get("litellm_params", {}),
|
"litellm_params": cfg.get("litellm_params", {}),
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
|
"billing_tier": cfg.get("billing_tier", "free"),
|
||||||
|
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -454,7 +508,26 @@ async def create_image_generation(
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Create and execute an image generation request."""
|
"""Create and execute an image generation request.
|
||||||
|
|
||||||
|
Premium configs are gated by the user's shared premium credit pool.
|
||||||
|
The flow is:
|
||||||
|
|
||||||
|
1. Permission check + load the search space (cheap, no provider call).
|
||||||
|
2. Resolve which config will run so we know its billing tier and the
|
||||||
|
worst-case reservation size *before* opening any DB rows.
|
||||||
|
3. Wrap the entire ImageGeneration row insert + provider call in
|
||||||
|
``billable_call``. If quota is denied, ``billable_call`` raises
|
||||||
|
``QuotaInsufficientError`` *before* we flush a row, which we
|
||||||
|
translate to HTTP 402 (no orphaned rows on the user's account,
|
||||||
|
no inserted error rows for "you ran out of credit").
|
||||||
|
4. On success, the actual ``response_cost`` flows through the
|
||||||
|
LiteLLM callback into the accumulator, and ``billable_call``
|
||||||
|
finalizes the debit at exit. Inner ``try/except`` still catches
|
||||||
|
provider errors and stores them on ``error_message`` (HTTP 200
|
||||||
|
with ``error_message`` set is preserved for failed-but-not-quota
|
||||||
|
scenarios — clients already know how to surface those).
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
await check_permission(
|
await check_permission(
|
||||||
session,
|
session,
|
||||||
|
|
@ -471,33 +544,70 @@ async def create_image_generation(
|
||||||
if not search_space:
|
if not search_space:
|
||||||
raise HTTPException(status_code=404, detail="Search space not found")
|
raise HTTPException(status_code=404, detail="Search space not found")
|
||||||
|
|
||||||
db_image_gen = ImageGeneration(
|
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||||
prompt=data.prompt,
|
session, data.image_generation_config_id, search_space
|
||||||
model=data.model,
|
|
||||||
n=data.n,
|
|
||||||
quality=data.quality,
|
|
||||||
size=data.size,
|
|
||||||
style=data.style,
|
|
||||||
response_format=data.response_format,
|
|
||||||
image_generation_config_id=data.image_generation_config_id,
|
|
||||||
search_space_id=data.search_space_id,
|
|
||||||
created_by_id=user.id,
|
|
||||||
)
|
)
|
||||||
session.add(db_image_gen)
|
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
try:
|
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||||
await _execute_image_generation(session, db_image_gen, search_space)
|
# propagates to the outer ``except QuotaInsufficientError`` handler
|
||||||
except Exception as e:
|
# below as HTTP 402 — it is intentionally NOT swallowed into
|
||||||
logger.exception("Image generation call failed")
|
# ``error_message`` because that would (1) imply a successful row
|
||||||
db_image_gen.error_message = str(e)
|
# exists when none does, and (2) return HTTP 200 to a client
|
||||||
|
# whose request was actively *denied* (issue K).
|
||||||
|
async with billable_call(
|
||||||
|
user_id=search_space.user_id,
|
||||||
|
search_space_id=data.search_space_id,
|
||||||
|
billing_tier=billing_tier,
|
||||||
|
base_model=base_model,
|
||||||
|
quota_reserve_micros_override=reserve_micros,
|
||||||
|
usage_type="image_generation",
|
||||||
|
call_details={"model": base_model, "prompt": data.prompt[:100]},
|
||||||
|
):
|
||||||
|
db_image_gen = ImageGeneration(
|
||||||
|
prompt=data.prompt,
|
||||||
|
model=data.model,
|
||||||
|
n=data.n,
|
||||||
|
quality=data.quality,
|
||||||
|
size=data.size,
|
||||||
|
style=data.style,
|
||||||
|
response_format=data.response_format,
|
||||||
|
image_generation_config_id=data.image_generation_config_id,
|
||||||
|
search_space_id=data.search_space_id,
|
||||||
|
created_by_id=user.id,
|
||||||
|
)
|
||||||
|
session.add(db_image_gen)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
await session.commit()
|
try:
|
||||||
await session.refresh(db_image_gen)
|
await _execute_image_generation(session, db_image_gen, search_space)
|
||||||
return db_image_gen
|
except Exception as e:
|
||||||
|
logger.exception("Image generation call failed")
|
||||||
|
db_image_gen.error_message = str(e)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(db_image_gen)
|
||||||
|
return db_image_gen
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
except QuotaInsufficientError as exc:
|
||||||
|
# The user's premium credit pool is empty. No DB row is created
|
||||||
|
# because ``billable_call`` denies before yielding (issue K).
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={
|
||||||
|
"error_code": "premium_quota_exhausted",
|
||||||
|
"usage_type": exc.usage_type,
|
||||||
|
"used_micros": exc.used_micros,
|
||||||
|
"limit_micros": exc.limit_micros,
|
||||||
|
"remaining_micros": exc.remaining_micros,
|
||||||
|
"message": (
|
||||||
|
"Out of premium credits for image generation. "
|
||||||
|
"Purchase additional credits or switch to a free model."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
) from exc
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -1366,7 +1366,11 @@ async def append_message(
|
||||||
# flush assigns the PK/defaults without a round-trip SELECT
|
# flush assigns the PK/defaults without a round-trip SELECT
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
# Persist token usage if provided (for assistant messages)
|
# Persist token usage if provided (for assistant messages).
|
||||||
|
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||||
|
# forwarded by the FE through the appendMessage round-trip so
|
||||||
|
# the historical TokenUsage row matches the credit debit applied
|
||||||
|
# at finalize time.
|
||||||
token_usage_data = raw_body.get("token_usage")
|
token_usage_data = raw_body.get("token_usage")
|
||||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||||
await record_token_usage(
|
await record_token_usage(
|
||||||
|
|
@ -1377,6 +1381,7 @@ async def append_message(
|
||||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||||
|
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||||
model_breakdown=token_usage_data.get("usage"),
|
model_breakdown=token_usage_data.get("usage"),
|
||||||
call_details=token_usage_data.get("call_details"),
|
call_details=token_usage_data.get("call_details"),
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
|
|
||||||
|
|
@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
|
||||||
"model_name": "auto",
|
"model_name": "auto",
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
"is_auto_mode": True,
|
"is_auto_mode": True,
|
||||||
|
"billing_tier": "free",
|
||||||
}
|
}
|
||||||
|
|
||||||
if config_id < 0:
|
if config_id < 0:
|
||||||
|
|
@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
|
||||||
"api_version": cfg.get("api_version") or None,
|
"api_version": cfg.get("api_version") or None,
|
||||||
"litellm_params": cfg.get("litellm_params", {}),
|
"litellm_params": cfg.get("litellm_params", {}),
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
|
"billing_tier": cfg.get("billing_tier", "free"),
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
|
||||||
"model_name": "auto",
|
"model_name": "auto",
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
"is_auto_mode": True,
|
"is_auto_mode": True,
|
||||||
|
"billing_tier": "free",
|
||||||
}
|
}
|
||||||
|
|
||||||
if config_id < 0:
|
if config_id < 0:
|
||||||
|
|
@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
|
||||||
"api_version": cfg.get("api_version") or None,
|
"api_version": cfg.get("api_version") or None,
|
||||||
"litellm_params": cfg.get("litellm_params", {}),
|
"litellm_params": cfg.get("litellm_params", {}),
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
|
"billing_tier": cfg.get("billing_tier", "free"),
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
|
||||||
metadata = _get_metadata(checkout_session)
|
metadata = _get_metadata(checkout_session)
|
||||||
user_id = metadata.get("user_id")
|
user_id = metadata.get("user_id")
|
||||||
quantity = int(metadata.get("quantity", "0"))
|
quantity = int(metadata.get("quantity", "0"))
|
||||||
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
|
# Read the new metadata key first, fall back to the legacy one so
|
||||||
|
# in-flight checkout sessions created before the cost-credits
|
||||||
|
# release still fulfil correctly (the unit is numerically the
|
||||||
|
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
|
||||||
|
credit_micros_per_unit = int(
|
||||||
|
metadata.get("credit_micros_per_unit")
|
||||||
|
or metadata.get("tokens_per_unit", "0")
|
||||||
|
)
|
||||||
|
|
||||||
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
|
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
||||||
checkout_session_id,
|
checkout_session_id,
|
||||||
|
|
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
|
||||||
getattr(checkout_session, "payment_intent", None)
|
getattr(checkout_session, "payment_intent", None)
|
||||||
),
|
),
|
||||||
quantity=quantity,
|
quantity=quantity,
|
||||||
tokens_granted=quantity * tokens_per_unit,
|
credit_micros_granted=quantity * credit_micros_per_unit,
|
||||||
amount_total=getattr(checkout_session, "amount_total", None),
|
amount_total=getattr(checkout_session, "amount_total", None),
|
||||||
currency=getattr(checkout_session, "currency", None),
|
currency=getattr(checkout_session, "currency", None),
|
||||||
status=PremiumTokenPurchaseStatus.PENDING,
|
status=PremiumTokenPurchaseStatus.PENDING,
|
||||||
|
|
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
|
||||||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||||
getattr(checkout_session, "payment_intent", None)
|
getattr(checkout_session, "payment_intent", None)
|
||||||
)
|
)
|
||||||
user.premium_tokens_limit = (
|
# Top up the user's credit balance by the granted micro-USD amount.
|
||||||
max(user.premium_tokens_used, user.premium_tokens_limit)
|
# ``max(used, limit)`` clamps the case where the legacy code wrote a
|
||||||
+ purchase.tokens_granted
|
# used value above the limit (e.g. underbilling rounding) so adding
|
||||||
|
# ``credit_micros_granted`` always lifts the limit by the full pack
|
||||||
|
# size rather than disappearing into past overuse.
|
||||||
|
user.premium_credit_micros_limit = (
|
||||||
|
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
|
||||||
|
+ purchase.credit_micros_granted
|
||||||
)
|
)
|
||||||
|
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
@ -532,12 +544,18 @@ async def create_token_checkout_session(
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
db_session: AsyncSession = Depends(get_async_session),
|
db_session: AsyncSession = Depends(get_async_session),
|
||||||
):
|
):
|
||||||
"""Create a Stripe Checkout Session for buying premium token packs."""
|
"""Create a Stripe Checkout Session for buying premium credit packs.
|
||||||
|
|
||||||
|
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
|
||||||
|
credit (default 1_000_000 = $1.00). The user's balance is debited
|
||||||
|
at the actual provider cost reported by LiteLLM at finalize time,
|
||||||
|
so $1 of credit always buys $1 worth of provider usage at cost.
|
||||||
|
"""
|
||||||
_ensure_token_buying_enabled()
|
_ensure_token_buying_enabled()
|
||||||
stripe_client = get_stripe_client()
|
stripe_client = get_stripe_client()
|
||||||
price_id = _get_required_token_price_id()
|
price_id = _get_required_token_price_id()
|
||||||
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
||||||
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
|
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||||
|
|
@ -556,8 +574,8 @@ async def create_token_checkout_session(
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"user_id": str(user.id),
|
"user_id": str(user.id),
|
||||||
"quantity": str(body.quantity),
|
"quantity": str(body.quantity),
|
||||||
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
|
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
|
||||||
"purchase_type": "premium_tokens",
|
"purchase_type": "premium_credit",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -583,7 +601,7 @@ async def create_token_checkout_session(
|
||||||
getattr(checkout_session, "payment_intent", None)
|
getattr(checkout_session, "payment_intent", None)
|
||||||
),
|
),
|
||||||
quantity=body.quantity,
|
quantity=body.quantity,
|
||||||
tokens_granted=tokens_granted,
|
credit_micros_granted=credit_micros_granted,
|
||||||
amount_total=getattr(checkout_session, "amount_total", None),
|
amount_total=getattr(checkout_session, "amount_total", None),
|
||||||
currency=getattr(checkout_session, "currency", None),
|
currency=getattr(checkout_session, "currency", None),
|
||||||
status=PremiumTokenPurchaseStatus.PENDING,
|
status=PremiumTokenPurchaseStatus.PENDING,
|
||||||
|
|
@ -598,14 +616,19 @@ async def create_token_checkout_session(
|
||||||
async def get_token_status(
|
async def get_token_status(
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Return token-buying availability and current premium quota for frontend."""
|
"""Return token-buying availability and current premium credit quota for frontend.
|
||||||
used = user.premium_tokens_used
|
|
||||||
limit = user.premium_tokens_limit
|
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
|
||||||
|
when displaying. The route name is preserved for back-compat with
|
||||||
|
pinned client deployments.
|
||||||
|
"""
|
||||||
|
used = user.premium_credit_micros_used
|
||||||
|
limit = user.premium_credit_micros_limit
|
||||||
return TokenStripeStatusResponse(
|
return TokenStripeStatusResponse(
|
||||||
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
||||||
premium_tokens_used=used,
|
premium_credit_micros_used=used,
|
||||||
premium_tokens_limit=limit,
|
premium_credit_micros_limit=limit,
|
||||||
premium_tokens_remaining=max(0, limit - used),
|
premium_credit_micros_remaining=max(0, limit - used),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
|
||||||
"litellm_params": {},
|
"litellm_params": {},
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
"is_auto_mode": True,
|
"is_auto_mode": True,
|
||||||
|
# Auto mode treated as free until per-deployment billing-tier
|
||||||
|
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||||
|
"billing_tier": "free",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
|
||||||
"api_version": cfg.get("api_version") or None,
|
"api_version": cfg.get("api_version") or None,
|
||||||
"litellm_params": cfg.get("litellm_params", {}),
|
"litellm_params": cfg.get("litellm_params", {}),
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
|
"billing_tier": cfg.get("billing_tier", "free"),
|
||||||
|
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||||
|
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||||
|
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
|
||||||
Schema for reading global image generation configs from YAML.
|
Schema for reading global image generation configs from YAML.
|
||||||
Global configs have negative IDs. API key is hidden.
|
Global configs have negative IDs. API key is hidden.
|
||||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||||
|
|
||||||
|
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||||
|
badge and (more importantly) tells the backend whether to debit the
|
||||||
|
user's premium credit pool when this config is used. ``"free"`` is
|
||||||
|
the default for backward compatibility — admins must explicitly opt
|
||||||
|
a global config into ``"premium"``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: int = Field(
|
id: int = Field(
|
||||||
|
|
@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel):
|
||||||
litellm_params: dict[str, Any] | None = None
|
litellm_params: dict[str, Any] | None = None
|
||||||
is_global: bool = True
|
is_global: bool = True
|
||||||
is_auto_mode: bool = False
|
is_auto_mode: bool = False
|
||||||
|
billing_tier: str = Field(
|
||||||
|
default="free",
|
||||||
|
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||||
|
)
|
||||||
|
quota_reserve_micros: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Optional override for the reservation amount (in micro-USD) used when "
|
||||||
|
"this image generation is premium. Falls back to "
|
||||||
|
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
|
||||||
prompt_tokens: int = 0
|
prompt_tokens: int = 0
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
total_tokens: int = 0
|
total_tokens: int = 0
|
||||||
|
cost_micros: int = 0
|
||||||
model_breakdown: dict | None = None
|
model_breakdown: dict | None = None
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TokenPurchaseRead(BaseModel):
|
class TokenPurchaseRead(BaseModel):
|
||||||
"""Serialized premium token purchase record."""
|
"""Serialized premium credit purchase record.
|
||||||
|
|
||||||
|
``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
|
||||||
|
schema name kept ``Token`` for API back-compat with pinned clients.
|
||||||
|
"""
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
stripe_checkout_session_id: str
|
stripe_checkout_session_id: str
|
||||||
stripe_payment_intent_id: str | None = None
|
stripe_payment_intent_id: str | None = None
|
||||||
quantity: int
|
quantity: int
|
||||||
tokens_granted: int
|
credit_micros_granted: int
|
||||||
amount_total: int | None = None
|
amount_total: int | None = None
|
||||||
currency: str | None = None
|
currency: str | None = None
|
||||||
status: str
|
status: str
|
||||||
|
|
@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TokenPurchaseHistoryResponse(BaseModel):
|
class TokenPurchaseHistoryResponse(BaseModel):
|
||||||
"""Response containing the user's premium token purchases."""
|
"""Response containing the user's premium credit purchases."""
|
||||||
|
|
||||||
purchases: list[TokenPurchaseRead]
|
purchases: list[TokenPurchaseRead]
|
||||||
|
|
||||||
|
|
||||||
class TokenStripeStatusResponse(BaseModel):
|
class TokenStripeStatusResponse(BaseModel):
|
||||||
"""Response describing token-buying availability and current quota."""
|
"""Response describing premium-credit-buying availability and balance.
|
||||||
|
|
||||||
|
All ``premium_credit_micros_*`` fields are in micro-USD; the FE
|
||||||
|
divides by 1_000_000 to display USD.
|
||||||
|
"""
|
||||||
|
|
||||||
token_buying_enabled: bool
|
token_buying_enabled: bool
|
||||||
premium_tokens_used: int = 0
|
premium_credit_micros_used: int = 0
|
||||||
premium_tokens_limit: int = 0
|
premium_credit_micros_limit: int = 0
|
||||||
premium_tokens_remaining: int = 0
|
premium_credit_micros_remaining: int = 0
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class GlobalVisionLLMConfigRead(BaseModel):
|
class GlobalVisionLLMConfigRead(BaseModel):
|
||||||
|
"""Schema for reading global vision LLM configs from YAML.
|
||||||
|
|
||||||
|
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||||
|
badge and (more importantly) tells the backend whether to debit the
|
||||||
|
user's premium credit pool when this config is used. ``"free"`` is
|
||||||
|
the default for backward compatibility — admins must explicitly opt
|
||||||
|
a global config into ``"premium"``.
|
||||||
|
"""
|
||||||
|
|
||||||
id: int = Field(...)
|
id: int = Field(...)
|
||||||
name: str
|
name: str
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
|
@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
|
||||||
litellm_params: dict[str, Any] | None = None
|
litellm_params: dict[str, Any] | None = None
|
||||||
is_global: bool = True
|
is_global: bool = True
|
||||||
is_auto_mode: bool = False
|
is_auto_mode: bool = False
|
||||||
|
billing_tier: str = Field(
|
||||||
|
default="free",
|
||||||
|
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||||
|
)
|
||||||
|
quota_reserve_tokens: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Optional override for the per-call reservation in *tokens* — "
|
||||||
|
"converted to micro-USD via the model's input/output prices at "
|
||||||
|
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
input_cost_per_token: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Optional input price in USD/token. Used by pricing_registration to "
|
||||||
|
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
output_cost_per_token: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional output price in USD/token. Pair with input_cost_per_token.",
|
||||||
|
)
|
||||||
|
|
|
||||||
430
surfsense_backend/app/services/billable_calls.py
Normal file
430
surfsense_backend/app/services/billable_calls.py
Normal file
|
|
@ -0,0 +1,430 @@
|
||||||
|
"""
|
||||||
|
Per-call billable wrapper for image generation, vision LLM extraction, and
|
||||||
|
any other short-lived premium operation that must charge against the user's
|
||||||
|
shared premium credit pool.
|
||||||
|
|
||||||
|
The ``billable_call`` async context manager encapsulates the standard
|
||||||
|
"reserve → execute → finalize / release → record audit row" lifecycle in a
|
||||||
|
single primitive so callers (the image-generation REST route and the
|
||||||
|
vision-LLM wrapper used during indexing) don't have to re-implement it.
|
||||||
|
|
||||||
|
KEY DESIGN POINTS (issue A, B):
|
||||||
|
|
||||||
|
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
|
||||||
|
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
|
||||||
|
insert each run inside their own ``shielded_async_session()``. This
|
||||||
|
guarantees that a quota commit/rollback can never accidentally flush or
|
||||||
|
roll back rows the caller has staged in the request's main session
|
||||||
|
(e.g. a freshly-created ``ImageGeneration`` row).
|
||||||
|
|
||||||
|
2. **ContextVar safety.** The accumulator is scoped via
|
||||||
|
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
||||||
|
nested ``billable_call`` inside an outer chat turn cannot corrupt the
|
||||||
|
chat turn's accumulator.
|
||||||
|
|
||||||
|
3. **Free configs are still audited.** Free calls bypass the reserve /
|
||||||
|
finalize dance entirely but still record a ``TokenUsage`` audit row with
|
||||||
|
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
|
||||||
|
pipeline complete for analytics even when nothing is debited.
|
||||||
|
|
||||||
|
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
|
||||||
|
responsible for translating that into HTTP 402. We *do not* catch the
|
||||||
|
denial inside ``billable_call`` — letting it propagate also prevents
|
||||||
|
the image-generation route from creating an ``ImageGeneration`` row
|
||||||
|
for a request that never actually ran.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import shielded_async_session
|
||||||
|
from app.services.token_quota_service import (
|
||||||
|
TokenQuotaService,
|
||||||
|
estimate_call_reserve_micros,
|
||||||
|
)
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
TurnTokenAccumulator,
|
||||||
|
record_token_usage,
|
||||||
|
scoped_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaInsufficientError(Exception):
|
||||||
|
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
|
||||||
|
call because the user has exhausted their premium credit pool.
|
||||||
|
|
||||||
|
The route handler should catch this and return HTTP 402 Payment
|
||||||
|
Required (or the equivalent for the surface area). Outside of the HTTP
|
||||||
|
layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
|
||||||
|
callers may catch this and degrade gracefully — e.g. fall back to OCR
|
||||||
|
when vision is unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
usage_type: str,
|
||||||
|
used_micros: int,
|
||||||
|
limit_micros: int,
|
||||||
|
remaining_micros: int,
|
||||||
|
) -> None:
|
||||||
|
self.usage_type = usage_type
|
||||||
|
self.used_micros = used_micros
|
||||||
|
self.limit_micros = limit_micros
|
||||||
|
self.remaining_micros = remaining_micros
|
||||||
|
super().__init__(
|
||||||
|
f"Premium credit exhausted for {usage_type}: "
|
||||||
|
f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def billable_call(
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
search_space_id: int,
|
||||||
|
billing_tier: str,
|
||||||
|
base_model: str,
|
||||||
|
quota_reserve_tokens: int | None = None,
|
||||||
|
quota_reserve_micros_override: int | None = None,
|
||||||
|
usage_type: str,
|
||||||
|
thread_id: int | None = None,
|
||||||
|
message_id: int | None = None,
|
||||||
|
call_details: dict[str, Any] | None = None,
|
||||||
|
) -> AsyncIterator[TurnTokenAccumulator]:
|
||||||
|
"""Wrap a single billable LLM/image call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Owner of the credit pool to debit. For vision-LLM during
|
||||||
|
indexing this is the *search-space owner* (issue M), not the
|
||||||
|
triggering user.
|
||||||
|
search_space_id: Required — recorded on the ``TokenUsage`` audit row.
|
||||||
|
billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
|
||||||
|
the reserve/finalize dance but still records an audit row with
|
||||||
|
the captured cost.
|
||||||
|
base_model: Used by :func:`estimate_call_reserve_micros` to compute
|
||||||
|
a worst-case reservation from LiteLLM's pricing table.
|
||||||
|
quota_reserve_tokens: Optional per-config override for the chat-style
|
||||||
|
reserve estimator (vision LLM uses this).
|
||||||
|
quota_reserve_micros_override: Optional flat micro-USD reservation
|
||||||
|
(image generation uses this — its cost shape is per-image, not
|
||||||
|
per-token).
|
||||||
|
usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
|
||||||
|
Recorded on the ``TokenUsage`` row.
|
||||||
|
thread_id, message_id: Optional FK columns on ``TokenUsage``.
|
||||||
|
call_details: Optional per-call metadata (model name, parameters)
|
||||||
|
forwarded to ``record_token_usage``.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
|
||||||
|
the underlying LLM/image API while inside the ``async with``; the
|
||||||
|
``TokenTrackingCallback`` populates the accumulator automatically.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
|
||||||
|
"""
|
||||||
|
is_premium = billing_tier == "premium"
|
||||||
|
|
||||||
|
async with scoped_turn() as acc:
|
||||||
|
# ---------- Free path: just audit -------------------------------
|
||||||
|
if not is_premium:
|
||||||
|
try:
|
||||||
|
yield acc
|
||||||
|
finally:
|
||||||
|
# Always audit, even on exception, so we capture cost when
|
||||||
|
# provider returns successfully but the caller raises later.
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as audit_session:
|
||||||
|
await record_token_usage(
|
||||||
|
audit_session,
|
||||||
|
usage_type=usage_type,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
prompt_tokens=acc.total_prompt_tokens,
|
||||||
|
completion_tokens=acc.total_completion_tokens,
|
||||||
|
total_tokens=acc.grand_total,
|
||||||
|
cost_micros=acc.total_cost_micros,
|
||||||
|
model_breakdown=acc.per_message_summary(),
|
||||||
|
call_details=call_details,
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
await audit_session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"[billable_call] free-path audit insert failed for "
|
||||||
|
"usage_type=%s user_id=%s",
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ---------- Premium path: reserve → execute → finalize ----------
|
||||||
|
if quota_reserve_micros_override is not None:
|
||||||
|
reserve_micros = max(1, int(quota_reserve_micros_override))
|
||||||
|
else:
|
||||||
|
reserve_micros = estimate_call_reserve_micros(
|
||||||
|
base_model=base_model or "",
|
||||||
|
quota_reserve_tokens=quota_reserve_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_id = str(uuid4())
|
||||||
|
|
||||||
|
async with shielded_async_session() as quota_session:
|
||||||
|
reserve_result = await TokenQuotaService.premium_reserve(
|
||||||
|
db_session=quota_session,
|
||||||
|
user_id=user_id,
|
||||||
|
request_id=request_id,
|
||||||
|
reserve_micros=reserve_micros,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not reserve_result.allowed:
|
||||||
|
logger.info(
|
||||||
|
"[billable_call] reserve DENIED user=%s usage_type=%s "
|
||||||
|
"reserve=%d used=%d limit=%d remaining=%d",
|
||||||
|
user_id,
|
||||||
|
usage_type,
|
||||||
|
reserve_micros,
|
||||||
|
reserve_result.used,
|
||||||
|
reserve_result.limit,
|
||||||
|
reserve_result.remaining,
|
||||||
|
)
|
||||||
|
raise QuotaInsufficientError(
|
||||||
|
usage_type=usage_type,
|
||||||
|
used_micros=reserve_result.used,
|
||||||
|
limit_micros=reserve_result.limit,
|
||||||
|
remaining_micros=reserve_result.remaining,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
|
||||||
|
"(remaining=%d)",
|
||||||
|
user_id,
|
||||||
|
usage_type,
|
||||||
|
reserve_micros,
|
||||||
|
reserve_result.remaining,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield acc
|
||||||
|
except BaseException:
|
||||||
|
# Release on any failure (including QuotaInsufficientError raised
|
||||||
|
# from a downstream call, asyncio cancellation, etc.). We use
|
||||||
|
# BaseException so cancellation also releases.
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as quota_session:
|
||||||
|
await TokenQuotaService.premium_release(
|
||||||
|
db_session=quota_session,
|
||||||
|
user_id=user_id,
|
||||||
|
reserved_micros=reserve_micros,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"[billable_call] premium_release failed for user=%s "
|
||||||
|
"reserve_micros=%d (reservation will be GC'd by quota "
|
||||||
|
"reconciliation if/when implemented)",
|
||||||
|
user_id,
|
||||||
|
reserve_micros,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ---------- Success: finalize + audit ----------------------------
|
||||||
|
actual_micros = acc.total_cost_micros
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as quota_session:
|
||||||
|
final_result = await TokenQuotaService.premium_finalize(
|
||||||
|
db_session=quota_session,
|
||||||
|
user_id=user_id,
|
||||||
|
request_id=request_id,
|
||||||
|
actual_micros=actual_micros,
|
||||||
|
reserved_micros=reserve_micros,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[billable_call] finalize user=%s usage_type=%s actual=%d "
|
||||||
|
"reserved=%d → used=%d/%d (remaining=%d)",
|
||||||
|
user_id,
|
||||||
|
usage_type,
|
||||||
|
actual_micros,
|
||||||
|
reserve_micros,
|
||||||
|
final_result.used,
|
||||||
|
final_result.limit,
|
||||||
|
final_result.remaining,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Last-ditch: if finalize itself fails, we must at least release
|
||||||
|
# so the reservation doesn't leak.
|
||||||
|
logger.exception(
|
||||||
|
"[billable_call] premium_finalize failed for user=%s; "
|
||||||
|
"attempting release",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as quota_session:
|
||||||
|
await TokenQuotaService.premium_release(
|
||||||
|
db_session=quota_session,
|
||||||
|
user_id=user_id,
|
||||||
|
reserved_micros=reserve_micros,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"[billable_call] release after finalize failure ALSO failed "
|
||||||
|
"for user=%s",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as audit_session:
|
||||||
|
await record_token_usage(
|
||||||
|
audit_session,
|
||||||
|
usage_type=usage_type,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
prompt_tokens=acc.total_prompt_tokens,
|
||||||
|
completion_tokens=acc.total_completion_tokens,
|
||||||
|
total_tokens=acc.grand_total,
|
||||||
|
cost_micros=actual_micros,
|
||||||
|
model_breakdown=acc.per_message_summary(),
|
||||||
|
call_details=call_details,
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
await audit_session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"[billable_call] premium-path audit insert failed for "
|
||||||
|
"usage_type=%s user_id=%s (debit was applied)",
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_agent_billing_for_search_space(
|
||||||
|
session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
*,
|
||||||
|
thread_id: int | None = None,
|
||||||
|
) -> tuple[UUID, str, str]:
|
||||||
|
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
|
||||||
|
agent LLM.
|
||||||
|
|
||||||
|
Used by Celery tasks (podcast generation, video presentation) to bill the
|
||||||
|
search-space owner's premium credit pool when the agent LLM is premium.
|
||||||
|
|
||||||
|
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
|
||||||
|
|
||||||
|
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
|
||||||
|
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
|
||||||
|
* ``thread_id`` is set: delegate to
|
||||||
|
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
|
||||||
|
recurse into the resolved id. Reuses chat's existing pin if present
|
||||||
|
so the same model bills for chat + downstream podcast/video. If the
|
||||||
|
user is not premium-eligible, the pin service auto-restricts to free
|
||||||
|
deployments — denial only happens later in
|
||||||
|
``billable_call.premium_reserve`` if the pin really is premium and
|
||||||
|
credit ran out mid-flow.
|
||||||
|
* ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
|
||||||
|
for any future direct-API path; today both Celery tasks always pass
|
||||||
|
``thread_id``.
|
||||||
|
- **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
|
||||||
|
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
|
||||||
|
``base_model = litellm_params.get("base_model") or model_name`` —
|
||||||
|
NOT provider-prefixed, matching chat's cost-map lookup convention.
|
||||||
|
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
|
||||||
|
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
|
||||||
|
``base_model`` from ``litellm_params`` or ``model_name``.
|
||||||
|
|
||||||
|
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
|
||||||
|
``llm_router_service`` are imported lazily inside the function body to
|
||||||
|
avoid hoisting litellm side-effects (``litellm.callbacks =
|
||||||
|
[token_tracker]``, ``litellm.drop_params``, etc.) into
|
||||||
|
``billable_calls.py``'s module load path.
|
||||||
|
"""
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.db import NewLLMConfig, SearchSpace
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||||
|
)
|
||||||
|
search_space = result.scalars().first()
|
||||||
|
if search_space is None:
|
||||||
|
raise ValueError(f"Search space {search_space_id} not found")
|
||||||
|
|
||||||
|
agent_llm_id = search_space.agent_llm_id
|
||||||
|
if agent_llm_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Search space {search_space_id} has no agent_llm_id configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
owner_user_id: UUID = search_space.user_id
|
||||||
|
|
||||||
|
from app.services.auto_model_pin_service import (
|
||||||
|
AUTO_FASTEST_ID,
|
||||||
|
resolve_or_get_pinned_llm_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_llm_id == AUTO_FASTEST_ID:
|
||||||
|
if thread_id is None:
|
||||||
|
return owner_user_id, "free", "auto"
|
||||||
|
try:
|
||||||
|
resolution = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=str(owner_user_id),
|
||||||
|
selected_llm_config_id=AUTO_FASTEST_ID,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
"[agent_billing] Auto-mode pin resolution failed for "
|
||||||
|
"search_space=%s thread=%s; falling back to free",
|
||||||
|
search_space_id,
|
||||||
|
thread_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return owner_user_id, "free", "auto"
|
||||||
|
agent_llm_id = resolution.resolved_llm_config_id
|
||||||
|
|
||||||
|
if agent_llm_id < 0:
|
||||||
|
from app.services.llm_service import get_global_llm_config
|
||||||
|
|
||||||
|
cfg = get_global_llm_config(agent_llm_id) or {}
|
||||||
|
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||||
|
litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
|
||||||
|
return owner_user_id, billing_tier, base_model
|
||||||
|
|
||||||
|
nlc_result = await session.execute(
|
||||||
|
select(NewLLMConfig).where(
|
||||||
|
NewLLMConfig.id == agent_llm_id,
|
||||||
|
NewLLMConfig.search_space_id == search_space_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
nlc = nlc_result.scalars().first()
|
||||||
|
base_model = ""
|
||||||
|
if nlc is not None:
|
||||||
|
litellm_params = nlc.litellm_params or {}
|
||||||
|
base_model = litellm_params.get("base_model") or nlc.model_name or ""
|
||||||
|
return owner_user_id, "free", base_model
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"QuotaInsufficientError",
|
||||||
|
"_resolve_agent_billing_for_search_space",
|
||||||
|
"billable_call",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Re-export the config knob so callers don't have to import config just for
|
||||||
|
# the default image reserve.
|
||||||
|
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
@ -134,42 +134,16 @@ PROVIDER_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
|
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
|
||||||
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
|
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
|
||||||
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
|
# call sites can share the exact same defense (OpenRouter / Groq / etc.
|
||||||
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
|
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
||||||
# request to an Azure endpoint, which then 404s with ``Resource not found``.
|
# backward compatibility with any external import.
|
||||||
# Only providers with a well-known, stable public base URL are listed here —
|
from app.services.provider_api_base import ( # noqa: E402
|
||||||
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
PROVIDER_DEFAULT_API_BASE,
|
||||||
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
PROVIDER_KEY_DEFAULT_API_BASE,
|
||||||
# so their existing config-driven behaviour is preserved.
|
resolve_api_base,
|
||||||
PROVIDER_DEFAULT_API_BASE = {
|
)
|
||||||
"openrouter": "https://openrouter.ai/api/v1",
|
|
||||||
"groq": "https://api.groq.com/openai/v1",
|
|
||||||
"mistral": "https://api.mistral.ai/v1",
|
|
||||||
"perplexity": "https://api.perplexity.ai",
|
|
||||||
"xai": "https://api.x.ai/v1",
|
|
||||||
"cerebras": "https://api.cerebras.ai/v1",
|
|
||||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
|
||||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
|
||||||
"together_ai": "https://api.together.xyz/v1",
|
|
||||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
|
||||||
"cometapi": "https://api.cometapi.com/v1",
|
|
||||||
"sambanova": "https://api.sambanova.ai/v1",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Canonical provider → base URL when a config uses a generic ``openai``-style
|
|
||||||
# prefix but the ``provider`` field tells us which API it really is
|
|
||||||
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
|
|
||||||
# each has its own base URL).
|
|
||||||
PROVIDER_KEY_DEFAULT_API_BASE = {
|
|
||||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
|
||||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
|
||||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
|
||||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
|
||||||
"MINIMAX": "https://api.minimax.io/v1",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRouterService:
|
class LLMRouterService:
|
||||||
|
|
@ -466,14 +440,14 @@ class LLMRouterService:
|
||||||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||||
# provider-aware default so the deployment does not silently
|
# provider-aware default so the deployment does not silently
|
||||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||||
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
|
# requests to the wrong endpoint. See ``provider_api_base``
|
||||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||||
# against an Azure endpoint).
|
# against an Azure endpoint).
|
||||||
api_base = config.get("api_base")
|
api_base = resolve_api_base(
|
||||||
if not api_base:
|
provider=provider,
|
||||||
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
|
provider_prefix=provider_prefix,
|
||||||
if not api_base:
|
config_api_base=config.get("api_base"),
|
||||||
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
|
)
|
||||||
if api_base:
|
if api_base:
|
||||||
litellm_params["api_base"] = api_base
|
litellm_params["api_base"] = api_base
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -496,8 +496,14 @@ async def get_vision_llm(
|
||||||
- Auto mode (ID 0): VisionLLMRouterService
|
- Auto mode (ID 0): VisionLLMRouterService
|
||||||
- Global (negative ID): YAML configs
|
- Global (negative ID): YAML configs
|
||||||
- DB (positive ID): VisionLLMConfig table
|
- DB (positive ID): VisionLLMConfig table
|
||||||
|
|
||||||
|
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
|
||||||
|
so each ``ainvoke`` debits the search-space owner's premium credit
|
||||||
|
pool. User-owned BYOK configs and free global configs are returned
|
||||||
|
unwrapped — they don't consume premium credit (issue M).
|
||||||
"""
|
"""
|
||||||
from app.db import VisionLLMConfig
|
from app.db import VisionLLMConfig
|
||||||
|
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||||
from app.services.vision_llm_router_service import (
|
from app.services.vision_llm_router_service import (
|
||||||
VISION_PROVIDER_MAP,
|
VISION_PROVIDER_MAP,
|
||||||
VisionLLMRouterService,
|
VisionLLMRouterService,
|
||||||
|
|
@ -519,6 +525,8 @@ async def get_vision_llm(
|
||||||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
owner_user_id = search_space.user_id
|
||||||
|
|
||||||
if is_vision_auto_mode(config_id):
|
if is_vision_auto_mode(config_id):
|
||||||
if not VisionLLMRouterService.is_initialized():
|
if not VisionLLMRouterService.is_initialized():
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -526,6 +534,13 @@ async def get_vision_llm(
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
|
# Auto mode is currently treated as free at the wrapper
|
||||||
|
# level — the underlying router can dispatch to either
|
||||||
|
# premium or free YAML configs but routing decisions are
|
||||||
|
# opaque. If/when we want to bill Auto-routed vision
|
||||||
|
# calls we'd need to thread the resolved deployment's
|
||||||
|
# billing_tier back from the router. For now we keep
|
||||||
|
# parity with chat Auto, which also doesn't pre-classify.
|
||||||
return ChatLiteLLMRouter(
|
return ChatLiteLLMRouter(
|
||||||
router=VisionLLMRouterService.get_router(),
|
router=VisionLLMRouterService.get_router(),
|
||||||
streaming=True,
|
streaming=True,
|
||||||
|
|
@ -562,8 +577,21 @@ async def get_vision_llm(
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
|
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
|
||||||
|
if billing_tier == "premium":
|
||||||
|
return QuotaCheckedVisionLLM(
|
||||||
|
inner_llm,
|
||||||
|
user_id=owner_user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
billing_tier=billing_tier,
|
||||||
|
base_model=model_string,
|
||||||
|
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
|
||||||
|
)
|
||||||
|
return inner_llm
|
||||||
|
|
||||||
|
# User-owned (positive ID) BYOK configs — always free.
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(VisionLLMConfig).where(
|
select(VisionLLMConfig).where(
|
||||||
VisionLLMConfig.id == config_id,
|
VisionLLMConfig.id == config_id,
|
||||||
|
|
|
||||||
|
|
@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
|
||||||
return output_mods == ["text"]
|
return output_mods == ["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_image_output_model(model: dict) -> bool:
|
||||||
|
"""Return True if the model can produce image output.
|
||||||
|
|
||||||
|
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
|
||||||
|
``["image"]`` for pure image generators, ``["text", "image"]`` for
|
||||||
|
multi-modal generators that also emit captions). We accept any model
|
||||||
|
that can output images; the call site decides whether to use the
|
||||||
|
image-generation API or chat completion.
|
||||||
|
"""
|
||||||
|
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
|
||||||
|
return "image" in output_mods
|
||||||
|
|
||||||
|
|
||||||
|
def _is_vision_input_model(model: dict) -> bool:
|
||||||
|
"""Return True if the model can ingest an image AND emit text.
|
||||||
|
|
||||||
|
OpenRouter's ``architecture.input_modalities`` lists what the model
|
||||||
|
accepts; ``output_modalities`` lists what it produces. A vision LLM
|
||||||
|
is a model that takes images in and produces text out — i.e. it can
|
||||||
|
answer questions about a screenshot or extract content from an
|
||||||
|
image. Pure image-to-image models (e.g. style transfer) and
|
||||||
|
text-only models are excluded.
|
||||||
|
"""
|
||||||
|
arch = model.get("architecture", {}) or {}
|
||||||
|
input_mods = arch.get("input_modalities", []) or []
|
||||||
|
output_mods = arch.get("output_modalities", []) or []
|
||||||
|
return "image" in input_mods and "text" in output_mods
|
||||||
|
|
||||||
|
|
||||||
def _supports_tool_calling(model: dict) -> bool:
|
def _supports_tool_calling(model: dict) -> bool:
|
||||||
"""Return True if the model supports function/tool calling."""
|
"""Return True if the model supports function/tool calling."""
|
||||||
supported = model.get("supported_parameters") or []
|
supported = model.get("supported_parameters") or []
|
||||||
|
|
@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
|
||||||
|
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
|
||||||
|
|
||||||
|
Pricing values are kept as the raw OpenRouter strings (e.g.
|
||||||
|
``"0.000003"``); ``pricing_registration`` converts them to floats
|
||||||
|
when registering with LiteLLM. Models with missing or malformed
|
||||||
|
pricing are simply omitted — operator-side risk if any of those are
|
||||||
|
premium.
|
||||||
|
"""
|
||||||
|
pricing: dict[str, dict[str, str]] = {}
|
||||||
|
for model in raw_models:
|
||||||
|
model_id = str(model.get("id") or "").strip()
|
||||||
|
if not model_id:
|
||||||
|
continue
|
||||||
|
p = model.get("pricing") or {}
|
||||||
|
prompt = p.get("prompt")
|
||||||
|
completion = p.get("completion")
|
||||||
|
if prompt is None and completion is None:
|
||||||
|
continue
|
||||||
|
pricing[model_id] = {
|
||||||
|
"prompt": str(prompt) if prompt is not None else "",
|
||||||
|
"completion": str(completion) if completion is not None else "",
|
||||||
|
}
|
||||||
|
return pricing
|
||||||
|
|
||||||
|
|
||||||
def _generate_configs(
|
def _generate_configs(
|
||||||
raw_models: list[dict],
|
raw_models: list[dict],
|
||||||
settings: dict[str, Any],
|
settings: dict[str, Any],
|
||||||
|
|
@ -282,6 +337,162 @@ def _generate_configs(
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
# ID-offset bands used to keep dynamic OpenRouter configs in their own
|
||||||
|
# namespace per surface. Image / vision get separate bands so a single
|
||||||
|
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
|
||||||
|
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
|
||||||
|
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_image_gen_configs(
|
||||||
|
raw_models: list[dict], settings: dict[str, Any]
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Convert OpenRouter image-generation models into global image-gen
|
||||||
|
config dicts (matches the YAML shape consumed by ``image_generation_routes``).
|
||||||
|
|
||||||
|
Filter:
|
||||||
|
- architecture.output_modalities contains "image"
|
||||||
|
- compatible provider (excluded slugs blocked)
|
||||||
|
- allowed model id (excluded list blocked)
|
||||||
|
|
||||||
|
Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
|
||||||
|
``_has_sufficient_context``) because tool calls and context windows are
|
||||||
|
irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
|
||||||
|
derived per model the same way as chat (``_openrouter_tier``).
|
||||||
|
|
||||||
|
Cost is intentionally *not* registered with LiteLLM at startup
|
||||||
|
(``pricing_registration`` skips image gen): OpenRouter image-gen
|
||||||
|
models are not in LiteLLM's native cost map and OpenRouter populates
|
||||||
|
``response_cost`` directly from the response header. A defensive
|
||||||
|
branch in ``_extract_cost_usd`` handles the rare case where
|
||||||
|
``usage.cost`` is missing — see ``token_tracking_service``.
|
||||||
|
"""
|
||||||
|
id_offset: int = int(
|
||||||
|
settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
|
||||||
|
)
|
||||||
|
api_key: str = settings.get("api_key", "")
|
||||||
|
rpm: int = settings.get("rpm", 200)
|
||||||
|
free_rpm: int = settings.get("free_rpm", 20)
|
||||||
|
litellm_params: dict = settings.get("litellm_params") or {}
|
||||||
|
|
||||||
|
image_models = [
|
||||||
|
m
|
||||||
|
for m in raw_models
|
||||||
|
if _is_image_output_model(m)
|
||||||
|
and _is_compatible_provider(m)
|
||||||
|
and _is_allowed_model(m)
|
||||||
|
and "/" in m.get("id", "")
|
||||||
|
]
|
||||||
|
|
||||||
|
configs: list[dict] = []
|
||||||
|
taken: set[int] = set()
|
||||||
|
for model in image_models:
|
||||||
|
model_id: str = model["id"]
|
||||||
|
name: str = model.get("name", model_id)
|
||||||
|
tier = _openrouter_tier(model)
|
||||||
|
|
||||||
|
cfg: dict[str, Any] = {
|
||||||
|
"id": _stable_config_id(model_id, id_offset, taken),
|
||||||
|
"name": name,
|
||||||
|
"description": f"{name} via OpenRouter (image generation)",
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": model_id,
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_base": "",
|
||||||
|
"api_version": None,
|
||||||
|
"rpm": free_rpm if tier == "free" else rpm,
|
||||||
|
"litellm_params": dict(litellm_params),
|
||||||
|
"billing_tier": tier,
|
||||||
|
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||||
|
}
|
||||||
|
configs.append(cfg)
|
||||||
|
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_vision_llm_configs(
|
||||||
|
raw_models: list[dict], settings: dict[str, Any]
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
|
||||||
|
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
|
||||||
|
|
||||||
|
Filter:
|
||||||
|
- architecture.input_modalities contains "image"
|
||||||
|
- architecture.output_modalities contains "text"
|
||||||
|
- compatible provider (excluded slugs blocked)
|
||||||
|
- allowed model id (excluded list blocked)
|
||||||
|
|
||||||
|
Vision-LLM is invoked from the indexer (image extraction during
|
||||||
|
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
|
||||||
|
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
|
||||||
|
filters do not apply: a small-context vision model that doesn't
|
||||||
|
advertise tool-calling is still perfectly viable for "describe this
|
||||||
|
image" prompts.
|
||||||
|
"""
|
||||||
|
id_offset: int = int(
|
||||||
|
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
|
||||||
|
)
|
||||||
|
api_key: str = settings.get("api_key", "")
|
||||||
|
rpm: int = settings.get("rpm", 200)
|
||||||
|
tpm: int = settings.get("tpm", 1_000_000)
|
||||||
|
free_rpm: int = settings.get("free_rpm", 20)
|
||||||
|
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||||
|
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||||
|
litellm_params: dict = settings.get("litellm_params") or {}
|
||||||
|
|
||||||
|
vision_models = [
|
||||||
|
m
|
||||||
|
for m in raw_models
|
||||||
|
if _is_vision_input_model(m)
|
||||||
|
and _is_compatible_provider(m)
|
||||||
|
and _is_allowed_model(m)
|
||||||
|
and "/" in m.get("id", "")
|
||||||
|
]
|
||||||
|
|
||||||
|
configs: list[dict] = []
|
||||||
|
taken: set[int] = set()
|
||||||
|
for model in vision_models:
|
||||||
|
model_id: str = model["id"]
|
||||||
|
name: str = model.get("name", model_id)
|
||||||
|
tier = _openrouter_tier(model)
|
||||||
|
pricing = model.get("pricing") or {}
|
||||||
|
|
||||||
|
# Capture per-token prices so ``pricing_registration`` can
|
||||||
|
# register them with LiteLLM at startup (and so the cost
|
||||||
|
# estimator in ``estimate_call_reserve_micros`` can resolve
|
||||||
|
# them at reserve time).
|
||||||
|
try:
|
||||||
|
input_cost = float(pricing.get("prompt", 0) or 0)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
input_cost = 0.0
|
||||||
|
try:
|
||||||
|
output_cost = float(pricing.get("completion", 0) or 0)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
output_cost = 0.0
|
||||||
|
|
||||||
|
cfg: dict[str, Any] = {
|
||||||
|
"id": _stable_config_id(model_id, id_offset, taken),
|
||||||
|
"name": name,
|
||||||
|
"description": f"{name} via OpenRouter (vision)",
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": model_id,
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_base": "",
|
||||||
|
"api_version": None,
|
||||||
|
"rpm": free_rpm if tier == "free" else rpm,
|
||||||
|
"tpm": free_tpm if tier == "free" else tpm,
|
||||||
|
"litellm_params": dict(litellm_params),
|
||||||
|
"billing_tier": tier,
|
||||||
|
"quota_reserve_tokens": quota_reserve_tokens,
|
||||||
|
"input_cost_per_token": input_cost or None,
|
||||||
|
"output_cost_per_token": output_cost or None,
|
||||||
|
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||||
|
}
|
||||||
|
configs.append(cfg)
|
||||||
|
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterIntegrationService:
|
class OpenRouterIntegrationService:
|
||||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||||
|
|
||||||
|
|
@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
|
||||||
# Shape: {model_name: {"gated": bool, "score": float | None}}
|
# Shape: {model_name: {"gated": bool, "score": float | None}}
|
||||||
self._health_cache: dict[str, dict[str, Any]] = {}
|
self._health_cache: dict[str, dict[str, Any]] = {}
|
||||||
self._enrich_task: asyncio.Task | None = None
|
self._enrich_task: asyncio.Task | None = None
|
||||||
|
# Raw OpenRouter pricing per model_id, captured at the same time
|
||||||
|
# we generate configs. Consumed by ``pricing_registration`` to
|
||||||
|
# teach LiteLLM the per-token cost of every dynamic deployment so
|
||||||
|
# the success-callback can populate ``response_cost`` correctly.
|
||||||
|
self._raw_pricing: dict[str, dict[str, str]] = {}
|
||||||
|
# Cached raw catalogue from the most recent fetch. Image / vision
|
||||||
|
# emitters reuse this to avoid a second network call per surface.
|
||||||
|
self._raw_models: list[dict] = []
|
||||||
|
# Image / vision config caches (only populated when the matching
|
||||||
|
# opt-in flag is true on initialize). Refreshed in lockstep with
|
||||||
|
# the chat catalogue.
|
||||||
|
self._image_configs: list[dict] = []
|
||||||
|
self._vision_configs: list[dict] = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||||
|
|
@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
self._raw_models = raw_models
|
||||||
self._configs = _generate_configs(raw_models, settings)
|
self._configs = _generate_configs(raw_models, settings)
|
||||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||||
|
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||||
|
|
||||||
|
# Populate image / vision caches when their opt-in flag is set.
|
||||||
|
# Empty otherwise so the accessors return [] without re-running
|
||||||
|
# filters every refresh.
|
||||||
|
if settings.get("image_generation_enabled"):
|
||||||
|
self._image_configs = _generate_image_gen_configs(raw_models, settings)
|
||||||
|
logger.info(
|
||||||
|
"OpenRouter integration: image-gen emission ON (%d models)",
|
||||||
|
len(self._image_configs),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._image_configs = []
|
||||||
|
|
||||||
|
if settings.get("vision_enabled"):
|
||||||
|
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
|
||||||
|
logger.info(
|
||||||
|
"OpenRouter integration: vision LLM emission ON (%d models)",
|
||||||
|
len(self._vision_configs),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._vision_configs = []
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
tier_counts = self._tier_counts(self._configs)
|
tier_counts = self._tier_counts(self._configs)
|
||||||
|
|
@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
|
||||||
|
|
||||||
new_configs = _generate_configs(raw_models, self._settings)
|
new_configs = _generate_configs(raw_models, self._settings)
|
||||||
new_by_id = {c["id"]: c for c in new_configs}
|
new_by_id = {c["id"]: c for c in new_configs}
|
||||||
|
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||||
|
self._raw_models = raw_models
|
||||||
|
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
|
|
||||||
|
|
@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
|
||||||
self._configs = new_configs
|
self._configs = new_configs
|
||||||
self._configs_by_id = new_by_id
|
self._configs_by_id = new_by_id
|
||||||
|
|
||||||
|
# Image / vision lists are atomic-swapped the same way: filter out
|
||||||
|
# the previous dynamic entries from the live config list and append
|
||||||
|
# the freshly generated ones. No-ops when the opt-in flag is off.
|
||||||
|
if self._settings.get("image_generation_enabled"):
|
||||||
|
new_image = _generate_image_gen_configs(raw_models, self._settings)
|
||||||
|
static_image = [
|
||||||
|
c
|
||||||
|
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||||
|
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||||
|
]
|
||||||
|
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
|
||||||
|
self._image_configs = new_image
|
||||||
|
|
||||||
|
if self._settings.get("vision_enabled"):
|
||||||
|
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
|
||||||
|
static_vision = [
|
||||||
|
c
|
||||||
|
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||||
|
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||||
|
]
|
||||||
|
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
|
||||||
|
self._vision_configs = new_vision
|
||||||
|
|
||||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||||
# earned by the previous turn's preflight. Drop the whole table so
|
# earned by the previous turn's preflight. Drop the whole table so
|
||||||
# the next turn re-probes against the freshly loaded configs.
|
# the next turn re-probes against the freshly loaded configs.
|
||||||
|
|
@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
|
||||||
# so a hand-picked dead OR model is gated like a dynamic one.
|
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||||
|
|
||||||
|
# Re-register LiteLLM pricing for the freshly fetched catalogue
|
||||||
|
# so newly added OR models bill correctly on their first call.
|
||||||
|
# Runs before the router rebuild because the router may issue
|
||||||
|
# cost-table lookups during deployment registration.
|
||||||
|
try:
|
||||||
|
from app.services.pricing_registration import (
|
||||||
|
register_pricing_from_global_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
|
||||||
|
)
|
||||||
|
|
||||||
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
||||||
# (dynamic OR premium entries now opt into the pool, free ones stay
|
# (dynamic OR premium entries now opt into the pool, free ones stay
|
||||||
# out; a refresh also needs to pick up any static-config edits and
|
# out; a refresh also needs to pick up any static-config edits and
|
||||||
|
|
@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
|
||||||
|
|
||||||
def get_config_by_id(self, config_id: int) -> dict | None:
|
def get_config_by_id(self, config_id: int) -> dict | None:
|
||||||
return self._configs_by_id.get(config_id)
|
return self._configs_by_id.get(config_id)
|
||||||
|
|
||||||
|
def get_image_generation_configs(self) -> list[dict]:
|
||||||
|
"""Return the dynamic OpenRouter image-generation configs (empty
|
||||||
|
list when the ``image_generation_enabled`` flag is off).
|
||||||
|
|
||||||
|
Each entry already has ``billing_tier`` derived per-model from
|
||||||
|
OpenRouter's signals and is shaped to drop directly into
|
||||||
|
``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
|
||||||
|
"""
|
||||||
|
return list(self._image_configs)
|
||||||
|
|
||||||
|
def get_vision_llm_configs(self) -> list[dict]:
|
||||||
|
"""Return the dynamic OpenRouter vision-LLM configs (empty list
|
||||||
|
when the ``vision_enabled`` flag is off).
|
||||||
|
|
||||||
|
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
|
||||||
|
so ``pricing_registration`` can teach LiteLLM the cost of these
|
||||||
|
models the same way it does for chat — which keeps the billable
|
||||||
|
wrapper able to debit accurate micro-USD on a vision call.
|
||||||
|
"""
|
||||||
|
return list(self._vision_configs)
|
||||||
|
|
||||||
|
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||||
|
"""Return the cached raw OpenRouter pricing map.
|
||||||
|
|
||||||
|
Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
|
||||||
|
values are the strings OpenRouter publishes (USD per token),
|
||||||
|
never converted to floats here so the caller can decide how to
|
||||||
|
handle malformed or unset entries.
|
||||||
|
"""
|
||||||
|
return dict(self._raw_pricing)
|
||||||
|
|
|
||||||
274
surfsense_backend/app/services/pricing_registration.py
Normal file
274
surfsense_backend/app/services/pricing_registration.py
Normal file
|
|
@ -0,0 +1,274 @@
|
||||||
|
"""
|
||||||
|
Pricing registration with LiteLLM.
|
||||||
|
|
||||||
|
Many models reach our LiteLLM callback without LiteLLM knowing their
|
||||||
|
per-token cost — namely:
|
||||||
|
|
||||||
|
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
|
||||||
|
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
|
||||||
|
pricing table).
|
||||||
|
* Static YAML deployments whose ``base_model`` name is operator-defined
|
||||||
|
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
|
||||||
|
not in LiteLLM's table either.
|
||||||
|
|
||||||
|
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
|
||||||
|
and the user gets billed nothing — a fail-safe but wrong answer for a
|
||||||
|
cost-based credit system. This module runs once at startup, after the
|
||||||
|
OpenRouter integration has fetched its catalogue, and registers each
|
||||||
|
known model's pricing with ``litellm.register_model()`` under multiple
|
||||||
|
plausible alias keys (LiteLLM's cost lookup may use any of them
|
||||||
|
depending on whether the call went through the Router, ChatLiteLLM,
|
||||||
|
or a direct ``acompletion``).
|
||||||
|
|
||||||
|
Operators who run a custom Azure deployment whose ``base_model`` name
|
||||||
|
isn't in LiteLLM's table can declare per-token pricing inline in
|
||||||
|
``global_llm_config.yaml`` via ``input_cost_per_token`` and
|
||||||
|
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
|
||||||
|
that declaration the model's calls debit 0 — never overbilled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_float(value: Any) -> float:
|
||||||
|
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
|
||||||
|
if value is None:
|
||||||
|
return 0.0
|
||||||
|
try:
|
||||||
|
f = float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 0.0
|
||||||
|
return f if f > 0 else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _alias_set_for_openrouter(model_id: str) -> list[str]:
|
||||||
|
"""Return the alias keys to register an OpenRouter model under.
|
||||||
|
|
||||||
|
LiteLLM's cost-callback lookup key varies by call path:
|
||||||
|
- Router with ``model="openrouter/X"`` → kwargs["model"] is
|
||||||
|
typically ``openrouter/X``.
|
||||||
|
- LiteLLM's own provider routing may strip the prefix and pass the
|
||||||
|
bare ``X`` to the cost-table lookup.
|
||||||
|
Registering under both keeps the lookup hermetic regardless of
|
||||||
|
which path the call took.
|
||||||
|
"""
|
||||||
|
aliases = [f"openrouter/{model_id}", model_id]
|
||||||
|
return list(dict.fromkeys(a for a in aliases if a))
|
||||||
|
|
||||||
|
|
||||||
|
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
|
||||||
|
"""Return the alias keys to register a static YAML deployment under.
|
||||||
|
|
||||||
|
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
|
||||||
|
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
|
||||||
|
bare ``model_name`` because callbacks sometimes see whichever was
|
||||||
|
configured first.
|
||||||
|
"""
|
||||||
|
provider_lower = (provider or "").lower()
|
||||||
|
aliases: list[str] = []
|
||||||
|
if base_model:
|
||||||
|
aliases.append(base_model)
|
||||||
|
if provider_lower and base_model:
|
||||||
|
aliases.append(f"{provider_lower}/{base_model}")
|
||||||
|
if model_name and model_name != base_model:
|
||||||
|
aliases.append(model_name)
|
||||||
|
if provider_lower and model_name and model_name != base_model:
|
||||||
|
aliases.append(f"{provider_lower}/{model_name}")
|
||||||
|
# Azure deployments often surface as "azure/<name>"; normalise the
|
||||||
|
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
|
||||||
|
if provider_lower == "azure_openai":
|
||||||
|
if base_model:
|
||||||
|
aliases.append(f"azure/{base_model}")
|
||||||
|
if model_name and model_name != base_model:
|
||||||
|
aliases.append(f"azure/{model_name}")
|
||||||
|
return list(dict.fromkeys(a for a in aliases if a))
|
||||||
|
|
||||||
|
|
||||||
|
def _register(
|
||||||
|
aliases: list[str],
|
||||||
|
*,
|
||||||
|
input_cost: float,
|
||||||
|
output_cost: float,
|
||||||
|
provider: str,
|
||||||
|
mode: str = "chat",
|
||||||
|
) -> int:
|
||||||
|
"""Register a single pricing entry under every alias in ``aliases``.
|
||||||
|
|
||||||
|
Returns the count of aliases successfully registered.
|
||||||
|
"""
|
||||||
|
payload: dict[str, dict[str, Any]] = {}
|
||||||
|
for alias in aliases:
|
||||||
|
payload[alias] = {
|
||||||
|
"input_cost_per_token": input_cost,
|
||||||
|
"output_cost_per_token": output_cost,
|
||||||
|
"litellm_provider": provider,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
if not payload:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
litellm.register_model(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"[PricingRegistration] register_model failed for aliases=%s: %s",
|
||||||
|
aliases,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
return len(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_chat_shape_configs(
|
||||||
|
configs: list[dict],
|
||||||
|
*,
|
||||||
|
or_pricing: dict[str, dict[str, str]],
|
||||||
|
label: str,
|
||||||
|
) -> tuple[int, int, int, list[str]]:
|
||||||
|
"""Common loop that registers per-token pricing for a list of "chat-shape"
|
||||||
|
configs (chat or vision LLM — both use ``input_cost_per_token`` /
|
||||||
|
``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
|
||||||
|
|
||||||
|
Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
|
||||||
|
"""
|
||||||
|
registered_models = 0
|
||||||
|
registered_aliases = 0
|
||||||
|
skipped_no_pricing = 0
|
||||||
|
sample_keys: list[str] = []
|
||||||
|
|
||||||
|
for cfg in configs:
|
||||||
|
provider = str(cfg.get("provider") or "").upper()
|
||||||
|
model_name = str(cfg.get("model_name") or "").strip()
|
||||||
|
litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
base_model = str(litellm_params.get("base_model") or model_name).strip()
|
||||||
|
|
||||||
|
if provider == "OPENROUTER":
|
||||||
|
entry = or_pricing.get(model_name)
|
||||||
|
if entry:
|
||||||
|
input_cost = _safe_float(entry.get("prompt"))
|
||||||
|
output_cost = _safe_float(entry.get("completion"))
|
||||||
|
else:
|
||||||
|
# Vision configs from ``_generate_vision_llm_configs``
|
||||||
|
# carry their pricing inline because the OpenRouter
|
||||||
|
# raw-pricing cache is keyed by chat-catalogue model_id;
|
||||||
|
# vision flows pick up the inline values here.
|
||||||
|
input_cost = _safe_float(cfg.get("input_cost_per_token"))
|
||||||
|
output_cost = _safe_float(cfg.get("output_cost_per_token"))
|
||||||
|
if input_cost == 0.0 and output_cost == 0.0:
|
||||||
|
skipped_no_pricing += 1
|
||||||
|
continue
|
||||||
|
aliases = _alias_set_for_openrouter(model_name)
|
||||||
|
count = _register(
|
||||||
|
aliases,
|
||||||
|
input_cost=input_cost,
|
||||||
|
output_cost=output_cost,
|
||||||
|
provider="openrouter",
|
||||||
|
)
|
||||||
|
if count > 0:
|
||||||
|
registered_models += 1
|
||||||
|
registered_aliases += count
|
||||||
|
if len(sample_keys) < 6:
|
||||||
|
sample_keys.extend(aliases[:2])
|
||||||
|
continue
|
||||||
|
|
||||||
|
input_cost = _safe_float(
|
||||||
|
cfg.get("input_cost_per_token")
|
||||||
|
or litellm_params.get("input_cost_per_token")
|
||||||
|
)
|
||||||
|
output_cost = _safe_float(
|
||||||
|
cfg.get("output_cost_per_token")
|
||||||
|
or litellm_params.get("output_cost_per_token")
|
||||||
|
)
|
||||||
|
if input_cost == 0.0 and output_cost == 0.0:
|
||||||
|
skipped_no_pricing += 1
|
||||||
|
continue
|
||||||
|
aliases = _alias_set_for_yaml(provider, model_name, base_model)
|
||||||
|
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
|
||||||
|
count = _register(
|
||||||
|
aliases,
|
||||||
|
input_cost=input_cost,
|
||||||
|
output_cost=output_cost,
|
||||||
|
provider=provider_slug,
|
||||||
|
)
|
||||||
|
if count > 0:
|
||||||
|
registered_models += 1
|
||||||
|
registered_aliases += count
|
||||||
|
if len(sample_keys) < 6:
|
||||||
|
sample_keys.extend(aliases[:2])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
|
||||||
|
"%d configs had no pricing data; sample registered keys=%s",
|
||||||
|
label,
|
||||||
|
registered_models,
|
||||||
|
registered_aliases,
|
||||||
|
skipped_no_pricing,
|
||||||
|
sample_keys,
|
||||||
|
)
|
||||||
|
return registered_models, registered_aliases, skipped_no_pricing, sample_keys
|
||||||
|
|
||||||
|
|
||||||
|
def register_pricing_from_global_configs() -> None:
|
||||||
|
"""Register pricing for every known LLM deployment with LiteLLM.
|
||||||
|
|
||||||
|
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
|
||||||
|
so vision calls (during indexing) can resolve cost the same way chat
|
||||||
|
calls do — namely:
|
||||||
|
|
||||||
|
1. ``OPENROUTER``: pulls the cached raw pricing from
|
||||||
|
``OpenRouterIntegrationService`` (populated during its own
|
||||||
|
startup fetch) and converts the per-token strings to floats. For
|
||||||
|
vision configs that carry pricing inline (``input_cost_per_token`` /
|
||||||
|
``output_cost_per_token`` set on the cfg itself) we fall back to
|
||||||
|
those values when the OR cache misses the model.
|
||||||
|
2. Anything else: looks for operator-declared
|
||||||
|
``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
|
||||||
|
config block (top-level or nested under ``litellm_params``).
|
||||||
|
|
||||||
|
**Image generation is intentionally NOT registered here.** The cost
|
||||||
|
shape for image-gen is per-image (``output_cost_per_image``), not
|
||||||
|
per-token, and LiteLLM's ``register_model`` doesn't accept those
|
||||||
|
keys via the chat-cost path. OpenRouter image-gen models populate
|
||||||
|
``response_cost`` directly from their response header instead, and
|
||||||
|
Azure-native image-gen models are already in LiteLLM's cost map.
|
||||||
|
|
||||||
|
Calls without a resolved pair of costs are skipped, not registered
|
||||||
|
with zeros — operators who forget pricing get a "$0 debit" warning
|
||||||
|
in ``TokenTrackingCallback`` rather than silently overwriting any
|
||||||
|
pricing LiteLLM might know natively.
|
||||||
|
"""
|
||||||
|
from app.config import config as app_config
|
||||||
|
|
||||||
|
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
|
||||||
|
vision_configs: list[dict] = list(
|
||||||
|
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
|
||||||
|
)
|
||||||
|
if not chat_configs and not vision_configs:
|
||||||
|
logger.info("[PricingRegistration] no global configs to register")
|
||||||
|
return
|
||||||
|
|
||||||
|
or_pricing: dict[str, dict[str, str]] = {}
|
||||||
|
try:
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
OpenRouterIntegrationService,
|
||||||
|
)
|
||||||
|
|
||||||
|
if OpenRouterIntegrationService.is_initialized():
|
||||||
|
or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"[PricingRegistration] OpenRouter pricing not available yet: %s", exc
|
||||||
|
)
|
||||||
|
|
||||||
|
if chat_configs:
|
||||||
|
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
|
||||||
|
if vision_configs:
|
||||||
|
_register_chat_shape_configs(
|
||||||
|
vision_configs, or_pricing=or_pricing, label="vision"
|
||||||
|
)
|
||||||
107
surfsense_backend/app/services/provider_api_base.py
Normal file
107
surfsense_backend/app/services/provider_api_base.py
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
|
||||||
|
|
||||||
|
LiteLLM falls back to the module-global ``litellm.api_base`` when an
|
||||||
|
individual call doesn't pass one, which silently inherits provider-agnostic
|
||||||
|
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
|
||||||
|
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
|
||||||
|
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
|
||||||
|
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
|
||||||
|
``/chat/completions`` to whatever inherited base it gets, regardless of
|
||||||
|
provider).
|
||||||
|
|
||||||
|
The chat router has had this defense for a while
|
||||||
|
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
|
||||||
|
into a tiny standalone helper so vision and image-gen can share the same
|
||||||
|
source of truth without an inter-service circular import.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||||
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
|
"groq": "https://api.groq.com/openai/v1",
|
||||||
|
"mistral": "https://api.mistral.ai/v1",
|
||||||
|
"perplexity": "https://api.perplexity.ai",
|
||||||
|
"xai": "https://api.x.ai/v1",
|
||||||
|
"cerebras": "https://api.cerebras.ai/v1",
|
||||||
|
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||||
|
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||||
|
"together_ai": "https://api.together.xyz/v1",
|
||||||
|
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||||
|
"cometapi": "https://api.cometapi.com/v1",
|
||||||
|
"sambanova": "https://api.sambanova.ai/v1",
|
||||||
|
}
|
||||||
|
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
|
||||||
|
|
||||||
|
Only providers with a well-known, stable public base URL are listed —
|
||||||
|
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||||
|
huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||||
|
so their existing config-driven behaviour is preserved."""
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
|
||||||
|
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||||
|
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||||
|
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||||
|
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||||
|
"MINIMAX": "https://api.minimax.io/v1",
|
||||||
|
}
|
||||||
|
"""Canonical provider key (uppercase) → base URL.
|
||||||
|
|
||||||
|
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
|
||||||
|
config's ``provider`` field tells us which API it actually is (DeepSeek,
|
||||||
|
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
|
||||||
|
has its own base URL)."""
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_api_base(
|
||||||
|
*,
|
||||||
|
provider: str | None,
|
||||||
|
provider_prefix: str | None,
|
||||||
|
config_api_base: str | None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
|
||||||
|
|
||||||
|
Cascade (first non-empty wins):
|
||||||
|
1. The config's own ``api_base`` (whitespace-only treated as missing).
|
||||||
|
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
|
||||||
|
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
|
||||||
|
4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
|
||||||
|
provider integration apply its own default (e.g. AzureOpenAI's
|
||||||
|
deployment-derived URL, custom provider's per-deployment URL).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
|
||||||
|
``"DEEPSEEK"``). Case-insensitive.
|
||||||
|
provider_prefix: The LiteLLM model-string prefix the same call
|
||||||
|
site builds for the model id (e.g. ``"openrouter"``,
|
||||||
|
``"groq"``). Case-insensitive.
|
||||||
|
config_api_base: ``api_base`` from the global YAML / DB row /
|
||||||
|
OpenRouter dynamic config. Empty / whitespace-only means
|
||||||
|
"missing" — the resolver still applies the cascade.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A URL string, or ``None`` if no default applies for this provider.
|
||||||
|
"""
|
||||||
|
if config_api_base and config_api_base.strip():
|
||||||
|
return config_api_base
|
||||||
|
|
||||||
|
if provider:
|
||||||
|
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
|
||||||
|
if key_default:
|
||||||
|
return key_default
|
||||||
|
|
||||||
|
if provider_prefix:
|
||||||
|
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
|
||||||
|
if prefix_default:
|
||||||
|
return prefix_default
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PROVIDER_DEFAULT_API_BASE",
|
||||||
|
"PROVIDER_KEY_DEFAULT_API_BASE",
|
||||||
|
"resolve_api_base",
|
||||||
|
]
|
||||||
105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""
|
||||||
|
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
|
||||||
|
|
||||||
|
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
|
||||||
|
indexing pipeline (file processors, connector indexers, etl pipeline) can
|
||||||
|
keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
|
||||||
|
— without threading ``user_id`` through every parser. The wrapper looks like
|
||||||
|
a chat model from the outside; on the inside it routes each call through
|
||||||
|
``billable_call`` so the user's premium credit pool is reserved → finalized
|
||||||
|
or released, and a ``TokenUsage`` audit row is written.
|
||||||
|
|
||||||
|
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
|
||||||
|
need quota enforcement) so this class only ever wraps premium configs.
|
||||||
|
|
||||||
|
Why a wrapper instead of plumbing ``user_id`` through every caller:
|
||||||
|
|
||||||
|
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
|
||||||
|
Dropbox, local-folder, file-processor, ETL pipeline) each calling
|
||||||
|
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
|
||||||
|
invasive, error-prone, and easy for a future indexer to forget.
|
||||||
|
* Per the design (issue M), we always debit the *search-space owner*, not
|
||||||
|
the triggering user, so ``user_id`` is fully derivable from the search
|
||||||
|
space the caller is already operating on. The wrapper captures it once
|
||||||
|
at construction time.
|
||||||
|
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
|
||||||
|
call run this coroutine"; subclassing isn't safe across versions because
|
||||||
|
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
|
||||||
|
Composition via attribute proxying (``__getattr__``) is robust to
|
||||||
|
upstream changes — every method other than ``ainvoke`` falls through to
|
||||||
|
the inner LLM unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.services.billable_calls import QuotaInsufficientError, billable_call
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaCheckedVisionLLM:
|
||||||
|
"""Composition wrapper around a langchain chat model that enforces
|
||||||
|
premium credit quota on every ``ainvoke``.
|
||||||
|
|
||||||
|
Anything other than ``ainvoke`` is forwarded to the inner model so
|
||||||
|
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
|
||||||
|
still work — they simply bypass quota enforcement, which is fine
|
||||||
|
because the indexing pipeline only ever calls ``ainvoke`` today.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner_llm: Any,
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
search_space_id: int,
|
||||||
|
billing_tier: str,
|
||||||
|
base_model: str,
|
||||||
|
quota_reserve_tokens: int | None,
|
||||||
|
usage_type: str = "vision_extraction",
|
||||||
|
) -> None:
|
||||||
|
self._inner = inner_llm
|
||||||
|
self._user_id = user_id
|
||||||
|
self._search_space_id = search_space_id
|
||||||
|
self._billing_tier = billing_tier
|
||||||
|
self._base_model = base_model
|
||||||
|
self._quota_reserve_tokens = quota_reserve_tokens
|
||||||
|
self._usage_type = usage_type
|
||||||
|
|
||||||
|
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Proxied async invoke that runs the underlying call inside
|
||||||
|
``billable_call``.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QuotaInsufficientError: when the user has exhausted their
|
||||||
|
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
|
||||||
|
catches this and falls back to the document parser.
|
||||||
|
"""
|
||||||
|
async with billable_call(
|
||||||
|
user_id=self._user_id,
|
||||||
|
search_space_id=self._search_space_id,
|
||||||
|
billing_tier=self._billing_tier,
|
||||||
|
base_model=self._base_model,
|
||||||
|
quota_reserve_tokens=self._quota_reserve_tokens,
|
||||||
|
usage_type=self._usage_type,
|
||||||
|
call_details={"model": self._base_model},
|
||||||
|
):
|
||||||
|
return await self._inner.ainvoke(input, *args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Forward everything else (``invoke``, ``astream``, ``bind``,
|
||||||
|
``with_structured_output``, …) to the inner model.
|
||||||
|
|
||||||
|
``__getattr__`` is only consulted when the attribute is *not*
|
||||||
|
already found on the proxy, which is exactly the contract we
|
||||||
|
want — methods we override stay on the proxy, the rest fall
|
||||||
|
through.
|
||||||
|
"""
|
||||||
|
return getattr(self._inner, name)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
|
||||||
|
|
@ -22,6 +22,71 @@ from app.config import config
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-call reservation estimator (USD micro-units)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
|
||||||
|
# request, and so models without registered pricing reserve at least
|
||||||
|
# something while the call runs (debited 0 at finalize anyway when their
|
||||||
|
# cost can't be resolved).
|
||||||
|
_QUOTA_MIN_RESERVE_MICROS = 100
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_call_reserve_micros(
|
||||||
|
*,
|
||||||
|
base_model: str,
|
||||||
|
quota_reserve_tokens: int | None,
|
||||||
|
) -> int:
|
||||||
|
"""Return the number of micro-USD to reserve for one premium call.
|
||||||
|
|
||||||
|
Computes a worst-case upper bound from LiteLLM's per-token pricing
|
||||||
|
table:
|
||||||
|
|
||||||
|
reserve_usd ≈ reserve_tokens x (input_cost + output_cost)
|
||||||
|
|
||||||
|
so the math scales with model cost — Claude Opus + 4K reserve_tokens
|
||||||
|
naturally reserves ≈ $0.36, while a cheap model reserves only a few
|
||||||
|
cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
|
||||||
|
so a misconfigured "$1000/M" model can't lock the whole balance on
|
||||||
|
one call.
|
||||||
|
|
||||||
|
If ``litellm.get_model_info`` raises (model unknown) we fall back to
|
||||||
|
the floor — 100 micros / $0.0001 — which is enough to gate a sane
|
||||||
|
request without over-reserving for a model whose pricing the
|
||||||
|
operator hasn't declared yet.
|
||||||
|
"""
|
||||||
|
reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
|
||||||
|
if reserve_tokens <= 0:
|
||||||
|
reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
|
||||||
|
|
||||||
|
try:
|
||||||
|
from litellm import get_model_info
|
||||||
|
|
||||||
|
info = get_model_info(base_model) if base_model else {}
|
||||||
|
input_cost = float(info.get("input_cost_per_token") or 0.0)
|
||||||
|
output_cost = float(info.get("output_cost_per_token") or 0.0)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"[quota_reserve] cost lookup failed for base_model=%s: %s",
|
||||||
|
base_model,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
input_cost = 0.0
|
||||||
|
output_cost = 0.0
|
||||||
|
|
||||||
|
if input_cost == 0.0 and output_cost == 0.0:
|
||||||
|
return _QUOTA_MIN_RESERVE_MICROS
|
||||||
|
|
||||||
|
reserve_usd = reserve_tokens * (input_cost + output_cost)
|
||||||
|
reserve_micros = round(reserve_usd * 1_000_000)
|
||||||
|
if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
|
||||||
|
reserve_micros = _QUOTA_MIN_RESERVE_MICROS
|
||||||
|
if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
|
||||||
|
reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
|
||||||
|
return reserve_micros
|
||||||
|
|
||||||
|
|
||||||
class QuotaScope(StrEnum):
|
class QuotaScope(StrEnum):
|
||||||
ANONYMOUS = "anonymous"
|
ANONYMOUS = "anonymous"
|
||||||
PREMIUM = "premium"
|
PREMIUM = "premium"
|
||||||
|
|
@ -444,8 +509,16 @@ class TokenQuotaService:
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
user_id: Any,
|
user_id: Any,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
reserve_tokens: int,
|
reserve_micros: int,
|
||||||
) -> QuotaResult:
|
) -> QuotaResult:
|
||||||
|
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
|
||||||
|
premium credit balance.
|
||||||
|
|
||||||
|
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
|
||||||
|
all in micro-USD on this code path; callers (chat stream,
|
||||||
|
token-status route, FE display) convert to dollars by dividing
|
||||||
|
by 1_000_000.
|
||||||
|
"""
|
||||||
from app.db import User
|
from app.db import User
|
||||||
|
|
||||||
user = (
|
user = (
|
||||||
|
|
@ -465,11 +538,11 @@ class TokenQuotaService:
|
||||||
limit=0,
|
limit=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
limit = user.premium_tokens_limit
|
limit = user.premium_credit_micros_limit
|
||||||
used = user.premium_tokens_used
|
used = user.premium_credit_micros_used
|
||||||
reserved = user.premium_tokens_reserved
|
reserved = user.premium_credit_micros_reserved
|
||||||
|
|
||||||
effective = used + reserved + reserve_tokens
|
effective = used + reserved + reserve_micros
|
||||||
if effective > limit:
|
if effective > limit:
|
||||||
remaining = max(0, limit - used - reserved)
|
remaining = max(0, limit - used - reserved)
|
||||||
await db_session.rollback()
|
await db_session.rollback()
|
||||||
|
|
@ -482,10 +555,10 @@ class TokenQuotaService:
|
||||||
remaining=remaining,
|
remaining=remaining,
|
||||||
)
|
)
|
||||||
|
|
||||||
user.premium_tokens_reserved = reserved + reserve_tokens
|
user.premium_credit_micros_reserved = reserved + reserve_micros
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
new_reserved = reserved + reserve_tokens
|
new_reserved = reserved + reserve_micros
|
||||||
remaining = max(0, limit - used - new_reserved)
|
remaining = max(0, limit - used - new_reserved)
|
||||||
warning_threshold = int(limit * 0.8)
|
warning_threshold = int(limit * 0.8)
|
||||||
|
|
||||||
|
|
@ -510,9 +583,12 @@ class TokenQuotaService:
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
user_id: Any,
|
user_id: Any,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
actual_tokens: int,
|
actual_micros: int,
|
||||||
reserved_tokens: int,
|
reserved_micros: int,
|
||||||
) -> QuotaResult:
|
) -> QuotaResult:
|
||||||
|
"""Settle the reservation: release ``reserved_micros`` and debit
|
||||||
|
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
|
||||||
|
"""
|
||||||
from app.db import User
|
from app.db import User
|
||||||
|
|
||||||
user = (
|
user = (
|
||||||
|
|
@ -529,16 +605,18 @@ class TokenQuotaService:
|
||||||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||||
)
|
)
|
||||||
|
|
||||||
user.premium_tokens_reserved = max(
|
user.premium_credit_micros_reserved = max(
|
||||||
0, user.premium_tokens_reserved - reserved_tokens
|
0, user.premium_credit_micros_reserved - reserved_micros
|
||||||
|
)
|
||||||
|
user.premium_credit_micros_used = (
|
||||||
|
user.premium_credit_micros_used + actual_micros
|
||||||
)
|
)
|
||||||
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
|
|
||||||
|
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
limit = user.premium_tokens_limit
|
limit = user.premium_credit_micros_limit
|
||||||
used = user.premium_tokens_used
|
used = user.premium_credit_micros_used
|
||||||
reserved = user.premium_tokens_reserved
|
reserved = user.premium_credit_micros_reserved
|
||||||
remaining = max(0, limit - used - reserved)
|
remaining = max(0, limit - used - reserved)
|
||||||
|
|
||||||
warning_threshold = int(limit * 0.8)
|
warning_threshold = int(limit * 0.8)
|
||||||
|
|
@ -562,8 +640,13 @@ class TokenQuotaService:
|
||||||
async def premium_release(
|
async def premium_release(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
user_id: Any,
|
user_id: Any,
|
||||||
reserved_tokens: int,
|
reserved_micros: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
|
||||||
|
|
||||||
|
Used when a request fails before finalize (so the reservation
|
||||||
|
doesn't leak credit).
|
||||||
|
"""
|
||||||
from app.db import User
|
from app.db import User
|
||||||
|
|
||||||
user = (
|
user = (
|
||||||
|
|
@ -576,8 +659,8 @@ class TokenQuotaService:
|
||||||
.scalar_one_or_none()
|
.scalar_one_or_none()
|
||||||
)
|
)
|
||||||
if user is not None:
|
if user is not None:
|
||||||
user.premium_tokens_reserved = max(
|
user.premium_credit_micros_reserved = max(
|
||||||
0, user.premium_tokens_reserved - reserved_tokens
|
0, user.premium_credit_micros_reserved - reserved_micros
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
|
|
@ -598,9 +681,9 @@ class TokenQuotaService:
|
||||||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||||
)
|
)
|
||||||
|
|
||||||
limit = user.premium_tokens_limit
|
limit = user.premium_credit_micros_limit
|
||||||
used = user.premium_tokens_used
|
used = user.premium_credit_micros_used
|
||||||
reserved = user.premium_tokens_reserved
|
reserved = user.premium_credit_micros_reserved
|
||||||
remaining = max(0, limit - used - reserved)
|
remaining = max(0, limit - used - reserved)
|
||||||
|
|
||||||
warning_threshold = int(limit * 0.8)
|
warning_threshold = int(limit * 0.8)
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
@ -35,6 +38,8 @@ class TokenCallRecord:
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
|
cost_micros: int = 0
|
||||||
|
call_kind: str = "chat"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -49,6 +54,8 @@ class TurnTokenAccumulator:
|
||||||
prompt_tokens: int,
|
prompt_tokens: int,
|
||||||
completion_tokens: int,
|
completion_tokens: int,
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
|
cost_micros: int = 0,
|
||||||
|
call_kind: str = "chat",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.calls.append(
|
self.calls.append(
|
||||||
TokenCallRecord(
|
TokenCallRecord(
|
||||||
|
|
@ -56,20 +63,28 @@ class TurnTokenAccumulator:
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
|
cost_micros=cost_micros,
|
||||||
|
call_kind=call_kind,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||||
"""Return token counts grouped by model name."""
|
"""Return token counts (and cost) grouped by model name."""
|
||||||
by_model: dict[str, dict[str, int]] = {}
|
by_model: dict[str, dict[str, int]] = {}
|
||||||
for c in self.calls:
|
for c in self.calls:
|
||||||
entry = by_model.setdefault(
|
entry = by_model.setdefault(
|
||||||
c.model,
|
c.model,
|
||||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
{
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
"cost_micros": 0,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
entry["prompt_tokens"] += c.prompt_tokens
|
entry["prompt_tokens"] += c.prompt_tokens
|
||||||
entry["completion_tokens"] += c.completion_tokens
|
entry["completion_tokens"] += c.completion_tokens
|
||||||
entry["total_tokens"] += c.total_tokens
|
entry["total_tokens"] += c.total_tokens
|
||||||
|
entry["cost_micros"] += c.cost_micros
|
||||||
return by_model
|
return by_model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -84,6 +99,21 @@ class TurnTokenAccumulator:
|
||||||
def total_completion_tokens(self) -> int:
|
def total_completion_tokens(self) -> int:
|
||||||
return sum(c.completion_tokens for c in self.calls)
|
return sum(c.completion_tokens for c in self.calls)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_cost_micros(self) -> int:
|
||||||
|
"""Sum of per-call ``cost_micros`` across the entire turn.
|
||||||
|
|
||||||
|
Used by ``stream_new_chat`` to debit a premium turn's actual
|
||||||
|
provider cost (in micro-USD) from the user's premium credit
|
||||||
|
balance. ``cost_micros`` per call is captured by
|
||||||
|
``TokenTrackingCallback.async_log_success_event`` from
|
||||||
|
``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
|
||||||
|
with multiple fallback paths so OpenRouter dynamic models and
|
||||||
|
custom Azure deployments still bill correctly when our
|
||||||
|
``pricing_registration`` ran at startup.
|
||||||
|
"""
|
||||||
|
return sum(c.cost_micros for c in self.calls)
|
||||||
|
|
||||||
def serialized_calls(self) -> list[dict[str, Any]]:
|
def serialized_calls(self) -> list[dict[str, Any]]:
|
||||||
return [dataclasses.asdict(c) for c in self.calls]
|
return [dataclasses.asdict(c) for c in self.calls]
|
||||||
|
|
||||||
|
|
@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
|
||||||
|
|
||||||
|
|
||||||
def start_turn() -> TurnTokenAccumulator:
|
def start_turn() -> TurnTokenAccumulator:
|
||||||
"""Create a fresh accumulator for the current async context and return it."""
|
"""Create a fresh accumulator for the current async context and return it.
|
||||||
|
|
||||||
|
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
|
||||||
|
short-lived per-call billable wrappers (image generation REST endpoint,
|
||||||
|
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
|
||||||
|
ContextVar reset token to restore the *previous* accumulator on exit and
|
||||||
|
avoids leaking call records across reservations (issue B).
|
||||||
|
"""
|
||||||
acc = TurnTokenAccumulator()
|
acc = TurnTokenAccumulator()
|
||||||
_turn_accumulator.set(acc)
|
_turn_accumulator.set(acc)
|
||||||
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
|
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
|
||||||
|
|
@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
|
||||||
return _turn_accumulator.get()
|
return _turn_accumulator.get()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
|
||||||
|
"""Async context manager that scopes a fresh ``TurnTokenAccumulator``
|
||||||
|
for the duration of the ``async with`` block, then *resets* the
|
||||||
|
ContextVar to its previous value on exit.
|
||||||
|
|
||||||
|
This is the safe primitive for per-call billable operations
|
||||||
|
(image generation, vision LLM extraction, podcasts) that may run
|
||||||
|
inside an outer chat turn or be called sequentially from the same
|
||||||
|
background worker. Using ``ContextVar.set`` without ``reset`` (as
|
||||||
|
:func:`start_turn` does) would leak the inner accumulator into the
|
||||||
|
outer scope, causing the outer chat turn to debit cost twice.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
async with scoped_turn() as acc:
|
||||||
|
await llm.ainvoke(...)
|
||||||
|
# acc.total_cost_micros captures cost from the LiteLLM callback
|
||||||
|
# Outer accumulator (if any) is restored here.
|
||||||
|
"""
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
token = _turn_accumulator.set(acc)
|
||||||
|
logger.debug(
|
||||||
|
"[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
|
||||||
|
id(acc),
|
||||||
|
token,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield acc
|
||||||
|
finally:
|
||||||
|
_turn_accumulator.reset(token)
|
||||||
|
logger.debug(
|
||||||
|
"[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
|
||||||
|
id(acc),
|
||||||
|
len(acc.calls),
|
||||||
|
acc.total_cost_micros,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_cost_usd(
|
||||||
|
kwargs: dict[str, Any],
|
||||||
|
response_obj: Any,
|
||||||
|
model: str,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
is_image: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Best-effort USD cost extraction for a single LLM/image call.
|
||||||
|
|
||||||
|
Tries four sources in priority order and returns the first that
|
||||||
|
yields a positive number; returns 0.0 if all four fail (the call
|
||||||
|
will then debit nothing from the user's balance — fail-safe).
|
||||||
|
|
||||||
|
Sources:
|
||||||
|
1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
|
||||||
|
field, populated for ``Router.acompletion`` since PR #12500.
|
||||||
|
2. ``response_obj._hidden_params["response_cost"]`` — same value
|
||||||
|
exposed on the response itself.
|
||||||
|
3. ``litellm.completion_cost(completion_response=response_obj)``
|
||||||
|
— recompute from the response and LiteLLM's pricing table.
|
||||||
|
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
|
||||||
|
— manual fallback for OpenRouter/custom-Azure models that
|
||||||
|
only resolve via aliases registered by
|
||||||
|
``pricing_registration`` at startup. **Skipped for image
|
||||||
|
responses** — ``cost_per_token`` does not support ``ImageResponse``
|
||||||
|
and would raise; the cost map for image-gen lives in different
|
||||||
|
keys (``output_cost_per_image``) handled by ``completion_cost``.
|
||||||
|
"""
|
||||||
|
cost = kwargs.get("response_cost")
|
||||||
|
if cost is not None:
|
||||||
|
try:
|
||||||
|
value = float(cost)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = 0.0
|
||||||
|
if value > 0:
|
||||||
|
return value
|
||||||
|
|
||||||
|
hidden = getattr(response_obj, "_hidden_params", None) or {}
|
||||||
|
if isinstance(hidden, dict):
|
||||||
|
cost = hidden.get("response_cost")
|
||||||
|
if cost is not None:
|
||||||
|
try:
|
||||||
|
value = float(cost)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = 0.0
|
||||||
|
if value > 0:
|
||||||
|
return value
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = float(litellm.completion_cost(completion_response=response_obj))
|
||||||
|
if value > 0:
|
||||||
|
return value
|
||||||
|
except Exception as exc:
|
||||||
|
if is_image:
|
||||||
|
# Image-gen path: OpenRouter's image responses can omit
|
||||||
|
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
|
||||||
|
# then *raises* (no cost map for OpenRouter image models).
|
||||||
|
# Bail out with a warning rather than falling through to
|
||||||
|
# cost_per_token (which is also incompatible with ImageResponse).
|
||||||
|
logger.warning(
|
||||||
|
"[TokenTracking] completion_cost failed for image model=%s "
|
||||||
|
"(provider may have omitted usage.cost). Debiting 0. "
|
||||||
|
"Cause: %s",
|
||||||
|
model,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return 0.0
|
||||||
|
logger.debug(
|
||||||
|
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_image:
|
||||||
|
# Never call cost_per_token for ImageResponse — keys mismatch and
|
||||||
|
# the function is documented chat-only.
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if model and (prompt_tokens > 0 or completion_tokens > 0):
|
||||||
|
try:
|
||||||
|
prompt_cost, completion_cost = litellm.cost_per_token(
|
||||||
|
model=model,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
value = float(prompt_cost) + float(completion_cost)
|
||||||
|
if value > 0:
|
||||||
|
return value
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
class TokenTrackingCallback(CustomLogger):
|
class TokenTrackingCallback(CustomLogger):
|
||||||
"""LiteLLM callback that captures token usage into the turn accumulator."""
|
"""LiteLLM callback that captures token usage into the turn accumulator."""
|
||||||
|
|
||||||
|
|
@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Detect image generation responses — they have a different usage
|
||||||
|
# shape (ImageUsage with input_tokens/output_tokens) and require a
|
||||||
|
# different cost-extraction path. We probe by class name to avoid a
|
||||||
|
# hard import dependency on litellm internals.
|
||||||
|
response_cls = type(response_obj).__name__
|
||||||
|
is_image = response_cls == "ImageResponse"
|
||||||
|
|
||||||
usage = getattr(response_obj, "usage", None)
|
usage = getattr(response_obj, "usage", None)
|
||||||
if not usage:
|
if not usage:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
if is_image:
|
||||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
# ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
|
||||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
# (not prompt_tokens/completion_tokens). Several providers
|
||||||
|
# populate only one or neither (e.g. OpenRouter's gpt-image-1
|
||||||
|
# passes through `input_tokens` from the prompt but no
|
||||||
|
# completion); fall through gracefully to 0.
|
||||||
|
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
|
||||||
|
completion_tokens = getattr(usage, "output_tokens", 0) or 0
|
||||||
|
total_tokens = (
|
||||||
|
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
|
||||||
|
)
|
||||||
|
call_kind = "image_generation"
|
||||||
|
else:
|
||||||
|
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||||
|
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||||
|
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||||
|
call_kind = "chat"
|
||||||
|
|
||||||
model = kwargs.get("model", "unknown")
|
model = kwargs.get("model", "unknown")
|
||||||
|
|
||||||
|
cost_usd = _extract_cost_usd(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj=response_obj,
|
||||||
|
model=model,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
is_image=is_image,
|
||||||
|
)
|
||||||
|
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
|
||||||
|
|
||||||
|
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
|
||||||
|
logger.warning(
|
||||||
|
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
|
||||||
|
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
|
||||||
|
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
|
||||||
|
"for image generation).",
|
||||||
|
model,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
call_kind,
|
||||||
|
)
|
||||||
|
|
||||||
acc.add(
|
acc.add(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
|
cost_micros=cost_micros,
|
||||||
|
call_kind=call_kind,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
|
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
||||||
|
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
|
||||||
model,
|
model,
|
||||||
|
call_kind,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
|
cost_usd,
|
||||||
|
cost_micros,
|
||||||
len(acc.calls),
|
len(acc.calls),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -168,6 +388,7 @@ async def record_token_usage(
|
||||||
prompt_tokens: int = 0,
|
prompt_tokens: int = 0,
|
||||||
completion_tokens: int = 0,
|
completion_tokens: int = 0,
|
||||||
total_tokens: int = 0,
|
total_tokens: int = 0,
|
||||||
|
cost_micros: int = 0,
|
||||||
model_breakdown: dict[str, Any] | None = None,
|
model_breakdown: dict[str, Any] | None = None,
|
||||||
call_details: dict[str, Any] | None = None,
|
call_details: dict[str, Any] | None = None,
|
||||||
thread_id: int | None = None,
|
thread_id: int | None = None,
|
||||||
|
|
@ -185,6 +406,7 @@ async def record_token_usage(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
|
cost_micros=cost_micros,
|
||||||
model_breakdown=model_breakdown,
|
model_breakdown=model_breakdown,
|
||||||
call_details=call_details,
|
call_details=call_details,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
|
@ -194,11 +416,12 @@ async def record_token_usage(
|
||||||
)
|
)
|
||||||
session.add(record)
|
session.add(record)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
|
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
|
||||||
usage_type,
|
usage_type,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
|
cost_micros,
|
||||||
)
|
)
|
||||||
return record
|
return record
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ from typing import Any
|
||||||
|
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
|
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VISION_AUTO_MODE_ID = 0
|
VISION_AUTO_MODE_ID = 0
|
||||||
|
|
@ -108,10 +110,11 @@ class VisionLLMRouterService:
|
||||||
if not config.get("model_name") or not config.get("api_key"):
|
if not config.get("model_name") or not config.get("api_key"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
provider = config.get("provider", "").upper()
|
||||||
if config.get("custom_provider"):
|
if config.get("custom_provider"):
|
||||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
provider_prefix = config["custom_provider"]
|
||||||
|
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||||
else:
|
else:
|
||||||
provider = config.get("provider", "").upper()
|
|
||||||
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
||||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||||
|
|
||||||
|
|
@ -120,8 +123,13 @@ class VisionLLMRouterService:
|
||||||
"api_key": config.get("api_key"),
|
"api_key": config.get("api_key"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.get("api_base"):
|
api_base = resolve_api_base(
|
||||||
litellm_params["api_base"] = config["api_base"]
|
provider=provider,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=config.get("api_base"),
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
litellm_params["api_base"] = api_base
|
||||||
|
|
||||||
if config.get("api_version"):
|
if config.get("api_version"):
|
||||||
litellm_params["api_version"] = config["api_version"]
|
litellm_params["api_version"] = config["api_version"]
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,13 @@ from sqlalchemy import select
|
||||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||||
from app.agents.podcaster.state import State as PodcasterState
|
from app.agents.podcaster.state import State as PodcasterState
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
|
from app.config import config as app_config
|
||||||
from app.db import Podcast, PodcastStatus
|
from app.db import Podcast, PodcastStatus
|
||||||
|
from app.services.billable_calls import (
|
||||||
|
QuotaInsufficientError,
|
||||||
|
_resolve_agent_billing_for_search_space,
|
||||||
|
billable_call,
|
||||||
|
)
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -96,6 +102,31 @@ async def _generate_content_podcast(
|
||||||
podcast.status = PodcastStatus.GENERATING
|
podcast.status = PodcastStatus.GENERATING
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
owner_user_id,
|
||||||
|
billing_tier,
|
||||||
|
base_model,
|
||||||
|
) = await _resolve_agent_billing_for_search_space(
|
||||||
|
session,
|
||||||
|
search_space_id,
|
||||||
|
thread_id=podcast.thread_id,
|
||||||
|
)
|
||||||
|
except ValueError as resolve_err:
|
||||||
|
logger.error(
|
||||||
|
"Podcast %s: cannot resolve billing for search_space=%s: %s",
|
||||||
|
podcast.id,
|
||||||
|
search_space_id,
|
||||||
|
resolve_err,
|
||||||
|
)
|
||||||
|
podcast.status = PodcastStatus.FAILED
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"podcast_id": podcast.id,
|
||||||
|
"reason": "billing_resolution_failed",
|
||||||
|
}
|
||||||
|
|
||||||
graph_config = {
|
graph_config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"podcast_title": podcast.title,
|
"podcast_title": podcast.title,
|
||||||
|
|
@ -109,9 +140,39 @@ async def _generate_content_podcast(
|
||||||
db_session=session,
|
db_session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_result = await podcaster_graph.ainvoke(
|
try:
|
||||||
initial_state, config=graph_config
|
async with billable_call(
|
||||||
)
|
user_id=owner_user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
billing_tier=billing_tier,
|
||||||
|
base_model=base_model,
|
||||||
|
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||||
|
usage_type="podcast_generation",
|
||||||
|
thread_id=podcast.thread_id,
|
||||||
|
call_details={
|
||||||
|
"podcast_id": podcast.id,
|
||||||
|
"title": podcast.title,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
graph_result = await podcaster_graph.ainvoke(
|
||||||
|
initial_state, config=graph_config
|
||||||
|
)
|
||||||
|
except QuotaInsufficientError as exc:
|
||||||
|
logger.info(
|
||||||
|
"Podcast %s denied: out of premium credits "
|
||||||
|
"(used=%d/%d remaining=%d)",
|
||||||
|
podcast.id,
|
||||||
|
exc.used_micros,
|
||||||
|
exc.limit_micros,
|
||||||
|
exc.remaining_micros,
|
||||||
|
)
|
||||||
|
podcast.status = PodcastStatus.FAILED
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"podcast_id": podcast.id,
|
||||||
|
"reason": "premium_quota_exhausted",
|
||||||
|
}
|
||||||
|
|
||||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||||
file_path = graph_result.get("final_podcast_file_path", "")
|
file_path = graph_result.get("final_podcast_file_path", "")
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,13 @@ from sqlalchemy import select
|
||||||
from app.agents.video_presentation.graph import graph as video_presentation_graph
|
from app.agents.video_presentation.graph import graph as video_presentation_graph
|
||||||
from app.agents.video_presentation.state import State as VideoPresentationState
|
from app.agents.video_presentation.state import State as VideoPresentationState
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
|
from app.config import config as app_config
|
||||||
from app.db import VideoPresentation, VideoPresentationStatus
|
from app.db import VideoPresentation, VideoPresentationStatus
|
||||||
|
from app.services.billable_calls import (
|
||||||
|
QuotaInsufficientError,
|
||||||
|
_resolve_agent_billing_for_search_space,
|
||||||
|
billable_call,
|
||||||
|
)
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -97,6 +103,32 @@ async def _generate_video_presentation(
|
||||||
video_pres.status = VideoPresentationStatus.GENERATING
|
video_pres.status = VideoPresentationStatus.GENERATING
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
owner_user_id,
|
||||||
|
billing_tier,
|
||||||
|
base_model,
|
||||||
|
) = await _resolve_agent_billing_for_search_space(
|
||||||
|
session,
|
||||||
|
search_space_id,
|
||||||
|
thread_id=video_pres.thread_id,
|
||||||
|
)
|
||||||
|
except ValueError as resolve_err:
|
||||||
|
logger.error(
|
||||||
|
"VideoPresentation %s: cannot resolve billing for "
|
||||||
|
"search_space=%s: %s",
|
||||||
|
video_pres.id,
|
||||||
|
search_space_id,
|
||||||
|
resolve_err,
|
||||||
|
)
|
||||||
|
video_pres.status = VideoPresentationStatus.FAILED
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"video_presentation_id": video_pres.id,
|
||||||
|
"reason": "billing_resolution_failed",
|
||||||
|
}
|
||||||
|
|
||||||
graph_config = {
|
graph_config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"video_title": video_pres.title,
|
"video_title": video_pres.title,
|
||||||
|
|
@ -110,9 +142,39 @@ async def _generate_video_presentation(
|
||||||
db_session=session,
|
db_session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_result = await video_presentation_graph.ainvoke(
|
try:
|
||||||
initial_state, config=graph_config
|
async with billable_call(
|
||||||
)
|
user_id=owner_user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
billing_tier=billing_tier,
|
||||||
|
base_model=base_model,
|
||||||
|
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
||||||
|
usage_type="video_presentation_generation",
|
||||||
|
thread_id=video_pres.thread_id,
|
||||||
|
call_details={
|
||||||
|
"video_presentation_id": video_pres.id,
|
||||||
|
"title": video_pres.title,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
graph_result = await video_presentation_graph.ainvoke(
|
||||||
|
initial_state, config=graph_config
|
||||||
|
)
|
||||||
|
except QuotaInsufficientError as exc:
|
||||||
|
logger.info(
|
||||||
|
"VideoPresentation %s denied: out of premium credits "
|
||||||
|
"(used=%d/%d remaining=%d)",
|
||||||
|
video_pres.id,
|
||||||
|
exc.used_micros,
|
||||||
|
exc.limit_micros,
|
||||||
|
exc.remaining_micros,
|
||||||
|
)
|
||||||
|
video_pres.status = VideoPresentationStatus.FAILED
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"video_presentation_id": video_pres.id,
|
||||||
|
"reason": "premium_quota_exhausted",
|
||||||
|
}
|
||||||
|
|
||||||
# Serialize slides (parsed content + audio info merged)
|
# Serialize slides (parsed content + audio info merged)
|
||||||
slides_raw = graph_result.get("slides", [])
|
slides_raw = graph_result.get("slides", [])
|
||||||
|
|
|
||||||
|
|
@ -2236,8 +2236,10 @@ async def stream_new_chat(
|
||||||
|
|
||||||
accumulator = start_turn()
|
accumulator = start_turn()
|
||||||
|
|
||||||
# Premium quota tracking state
|
# Premium credit (USD micro-units) tracking state. Stores the
|
||||||
_premium_reserved = 0
|
# amount reserved up front so we can release it on cancellation
|
||||||
|
# and finalize-debit the actual provider cost reported by LiteLLM.
|
||||||
|
_premium_reserved_micros = 0
|
||||||
_premium_request_id: str | None = None
|
_premium_request_id: str | None = None
|
||||||
|
|
||||||
_emit_stream_error = partial(
|
_emit_stream_error = partial(
|
||||||
|
|
@ -2331,23 +2333,28 @@ async def stream_new_chat(
|
||||||
if _needs_premium_quota:
|
if _needs_premium_quota:
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
||||||
from app.config import config as _app_config
|
from app.services.token_quota_service import (
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
TokenQuotaService,
|
||||||
|
estimate_call_reserve_micros,
|
||||||
|
)
|
||||||
|
|
||||||
_premium_request_id = _uuid.uuid4().hex[:16]
|
_premium_request_id = _uuid.uuid4().hex[:16]
|
||||||
reserve_amount = min(
|
_agent_litellm_params = agent_config.litellm_params or {}
|
||||||
agent_config.quota_reserve_tokens
|
_agent_base_model = (
|
||||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
_agent_litellm_params.get("base_model") or agent_config.model_name or ""
|
||||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
)
|
||||||
|
reserve_amount_micros = estimate_call_reserve_micros(
|
||||||
|
base_model=_agent_base_model,
|
||||||
|
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||||
)
|
)
|
||||||
async with shielded_async_session() as quota_session:
|
async with shielded_async_session() as quota_session:
|
||||||
quota_result = await TokenQuotaService.premium_reserve(
|
quota_result = await TokenQuotaService.premium_reserve(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
request_id=_premium_request_id,
|
request_id=_premium_request_id,
|
||||||
reserve_tokens=reserve_amount,
|
reserve_micros=reserve_amount_micros,
|
||||||
)
|
)
|
||||||
_premium_reserved = reserve_amount
|
_premium_reserved_micros = reserve_amount_micros
|
||||||
if not quota_result.allowed:
|
if not quota_result.allowed:
|
||||||
if requested_llm_config_id == 0:
|
if requested_llm_config_id == 0:
|
||||||
try:
|
try:
|
||||||
|
|
@ -2382,7 +2389,7 @@ async def stream_new_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
_premium_request_id = None
|
_premium_request_id = None
|
||||||
_premium_reserved = 0
|
_premium_reserved_micros = 0
|
||||||
_log_chat_stream_error(
|
_log_chat_stream_error(
|
||||||
flow=flow,
|
flow=flow,
|
||||||
error_kind="premium_quota_exhausted",
|
error_kind="premium_quota_exhausted",
|
||||||
|
|
@ -3020,9 +3027,10 @@ async def stream_new_chat(
|
||||||
|
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
"[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||||
len(accumulator.calls),
|
len(accumulator.calls),
|
||||||
accumulator.grand_total,
|
accumulator.grand_total,
|
||||||
|
accumulator.total_cost_micros,
|
||||||
usage_summary,
|
usage_summary,
|
||||||
)
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
|
|
@ -3033,6 +3041,7 @@ async def stream_new_chat(
|
||||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
"completion_tokens": accumulator.total_completion_tokens,
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
"total_tokens": accumulator.grand_total,
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"cost_micros": accumulator.total_cost_micros,
|
||||||
"call_details": accumulator.serialized_calls(),
|
"call_details": accumulator.serialized_calls(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -3060,7 +3069,11 @@ async def stream_new_chat(
|
||||||
chat_id, generated_title
|
chat_id, generated_title
|
||||||
)
|
)
|
||||||
|
|
||||||
# Finalize premium quota with actual tokens.
|
# Finalize premium credit debit with the actual provider cost
|
||||||
|
# reported by LiteLLM, summed across every call in the turn.
|
||||||
|
# Mirrors the pre-cost behaviour of "premium turn → all calls
|
||||||
|
# count" so free sub-agent calls during a premium turn still
|
||||||
|
# contribute to the bill (they're $0 in practice anyway).
|
||||||
if _premium_request_id and user_id:
|
if _premium_request_id and user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
@ -3070,11 +3083,11 @@ async def stream_new_chat(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
request_id=_premium_request_id,
|
request_id=_premium_request_id,
|
||||||
actual_tokens=accumulator.grand_total,
|
actual_micros=accumulator.total_cost_micros,
|
||||||
reserved_tokens=_premium_reserved,
|
reserved_micros=_premium_reserved_micros,
|
||||||
)
|
)
|
||||||
_premium_request_id = None
|
_premium_request_id = None
|
||||||
_premium_reserved = 0
|
_premium_reserved_micros = 0
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning(
|
||||||
"Failed to finalize premium quota for user %s",
|
"Failed to finalize premium quota for user %s",
|
||||||
|
|
@ -3084,9 +3097,10 @@ async def stream_new_chat(
|
||||||
|
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
"[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||||
len(accumulator.calls),
|
len(accumulator.calls),
|
||||||
accumulator.grand_total,
|
accumulator.grand_total,
|
||||||
|
accumulator.total_cost_micros,
|
||||||
usage_summary,
|
usage_summary,
|
||||||
)
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
|
|
@ -3097,6 +3111,7 @@ async def stream_new_chat(
|
||||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
"completion_tokens": accumulator.total_completion_tokens,
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
"total_tokens": accumulator.grand_total,
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"cost_micros": accumulator.total_cost_micros,
|
||||||
"call_details": accumulator.serialized_calls(),
|
"call_details": accumulator.serialized_calls(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -3190,7 +3205,7 @@ async def stream_new_chat(
|
||||||
end_turn(str(chat_id))
|
end_turn(str(chat_id))
|
||||||
|
|
||||||
# Release premium reservation if not finalized
|
# Release premium reservation if not finalized
|
||||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
if _premium_request_id and _premium_reserved_micros > 0 and user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
||||||
|
|
@ -3198,9 +3213,9 @@ async def stream_new_chat(
|
||||||
await TokenQuotaService.premium_release(
|
await TokenQuotaService.premium_release(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
reserved_tokens=_premium_reserved,
|
reserved_micros=_premium_reserved_micros,
|
||||||
)
|
)
|
||||||
_premium_reserved = 0
|
_premium_reserved_micros = 0
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning(
|
||||||
"Failed to release premium quota for user %s", user_id
|
"Failed to release premium quota for user %s", user_id
|
||||||
|
|
@ -3369,8 +3384,8 @@ async def stream_resume_chat(
|
||||||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Premium quota reservation (same logic as stream_new_chat)
|
# Premium credit reservation (same logic as stream_new_chat).
|
||||||
_resume_premium_reserved = 0
|
_resume_premium_reserved_micros = 0
|
||||||
_resume_premium_request_id: str | None = None
|
_resume_premium_request_id: str | None = None
|
||||||
_resume_needs_premium = (
|
_resume_needs_premium = (
|
||||||
agent_config is not None and user_id and agent_config.is_premium
|
agent_config is not None and user_id and agent_config.is_premium
|
||||||
|
|
@ -3378,23 +3393,30 @@ async def stream_resume_chat(
|
||||||
if _resume_needs_premium:
|
if _resume_needs_premium:
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
||||||
from app.config import config as _app_config
|
from app.services.token_quota_service import (
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
TokenQuotaService,
|
||||||
|
estimate_call_reserve_micros,
|
||||||
|
)
|
||||||
|
|
||||||
_resume_premium_request_id = _uuid.uuid4().hex[:16]
|
_resume_premium_request_id = _uuid.uuid4().hex[:16]
|
||||||
reserve_amount = min(
|
_resume_litellm_params = agent_config.litellm_params or {}
|
||||||
agent_config.quota_reserve_tokens
|
_resume_base_model = (
|
||||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
_resume_litellm_params.get("base_model")
|
||||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
or agent_config.model_name
|
||||||
|
or ""
|
||||||
|
)
|
||||||
|
reserve_amount_micros = estimate_call_reserve_micros(
|
||||||
|
base_model=_resume_base_model,
|
||||||
|
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||||
)
|
)
|
||||||
async with shielded_async_session() as quota_session:
|
async with shielded_async_session() as quota_session:
|
||||||
quota_result = await TokenQuotaService.premium_reserve(
|
quota_result = await TokenQuotaService.premium_reserve(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
request_id=_resume_premium_request_id,
|
request_id=_resume_premium_request_id,
|
||||||
reserve_tokens=reserve_amount,
|
reserve_micros=reserve_amount_micros,
|
||||||
)
|
)
|
||||||
_resume_premium_reserved = reserve_amount
|
_resume_premium_reserved_micros = reserve_amount_micros
|
||||||
if not quota_result.allowed:
|
if not quota_result.allowed:
|
||||||
if requested_llm_config_id == 0:
|
if requested_llm_config_id == 0:
|
||||||
try:
|
try:
|
||||||
|
|
@ -3429,7 +3451,7 @@ async def stream_resume_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
_resume_premium_request_id = None
|
_resume_premium_request_id = None
|
||||||
_resume_premium_reserved = 0
|
_resume_premium_reserved_micros = 0
|
||||||
_log_chat_stream_error(
|
_log_chat_stream_error(
|
||||||
flow="resume",
|
flow="resume",
|
||||||
error_kind="premium_quota_exhausted",
|
error_kind="premium_quota_exhausted",
|
||||||
|
|
@ -3746,9 +3768,10 @@ async def stream_resume_chat(
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
|
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||||
len(accumulator.calls),
|
len(accumulator.calls),
|
||||||
accumulator.grand_total,
|
accumulator.grand_total,
|
||||||
|
accumulator.total_cost_micros,
|
||||||
usage_summary,
|
usage_summary,
|
||||||
)
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
|
|
@ -3759,6 +3782,7 @@ async def stream_resume_chat(
|
||||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
"completion_tokens": accumulator.total_completion_tokens,
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
"total_tokens": accumulator.grand_total,
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"cost_micros": accumulator.total_cost_micros,
|
||||||
"call_details": accumulator.serialized_calls(),
|
"call_details": accumulator.serialized_calls(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -3768,7 +3792,9 @@ async def stream_resume_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Finalize premium quota for resume path
|
# Finalize premium credit debit for resume path with the actual
|
||||||
|
# provider cost reported by LiteLLM (sum of cost across all
|
||||||
|
# calls in the turn).
|
||||||
if _resume_premium_request_id and user_id:
|
if _resume_premium_request_id and user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
@ -3778,11 +3804,11 @@ async def stream_resume_chat(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
request_id=_resume_premium_request_id,
|
request_id=_resume_premium_request_id,
|
||||||
actual_tokens=accumulator.grand_total,
|
actual_micros=accumulator.total_cost_micros,
|
||||||
reserved_tokens=_resume_premium_reserved,
|
reserved_micros=_resume_premium_reserved_micros,
|
||||||
)
|
)
|
||||||
_resume_premium_request_id = None
|
_resume_premium_request_id = None
|
||||||
_resume_premium_reserved = 0
|
_resume_premium_reserved_micros = 0
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning(
|
||||||
"Failed to finalize premium quota for user %s (resume)",
|
"Failed to finalize premium quota for user %s (resume)",
|
||||||
|
|
@ -3792,9 +3818,10 @@ async def stream_resume_chat(
|
||||||
|
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
"[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||||
len(accumulator.calls),
|
len(accumulator.calls),
|
||||||
accumulator.grand_total,
|
accumulator.grand_total,
|
||||||
|
accumulator.total_cost_micros,
|
||||||
usage_summary,
|
usage_summary,
|
||||||
)
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
|
|
@ -3805,6 +3832,7 @@ async def stream_resume_chat(
|
||||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
"completion_tokens": accumulator.total_completion_tokens,
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
"total_tokens": accumulator.grand_total,
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"cost_micros": accumulator.total_cost_micros,
|
||||||
"call_details": accumulator.serialized_calls(),
|
"call_details": accumulator.serialized_calls(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -3855,7 +3883,11 @@ async def stream_resume_chat(
|
||||||
end_turn(str(chat_id))
|
end_turn(str(chat_id))
|
||||||
|
|
||||||
# Release premium reservation if not finalized
|
# Release premium reservation if not finalized
|
||||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
if (
|
||||||
|
_resume_premium_request_id
|
||||||
|
and _resume_premium_reserved_micros > 0
|
||||||
|
and user_id
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
||||||
|
|
@ -3863,9 +3895,9 @@ async def stream_resume_chat(
|
||||||
await TokenQuotaService.premium_release(
|
await TokenQuotaService.premium_release(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
reserved_tokens=_resume_premium_reserved,
|
reserved_micros=_resume_premium_reserved_micros,
|
||||||
)
|
)
|
||||||
_resume_premium_reserved = 0
|
_resume_premium_reserved_micros = 0
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning(
|
||||||
"Failed to release premium quota for user %s (resume)", user_id
|
"Failed to release premium quota for user %s (resume)", user_id
|
||||||
|
|
|
||||||
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""Unit tests for the image-generation route's billing-resolution helper.
|
||||||
|
|
||||||
|
End-to-end "POST /image-generations returns 402" coverage requires the
|
||||||
|
integration harness (real DB, real auth) and lives in
|
||||||
|
``tests/integration/document_upload/`` alongside the other quota tests.
|
||||||
|
This unit test focuses on the new ``_resolve_billing_for_image_gen``
|
||||||
|
helper which:
|
||||||
|
|
||||||
|
* Returns ``free`` for Auto mode, even when premium configs exist
|
||||||
|
(Auto-mode billing-tier surfacing is a follow-up).
|
||||||
|
* Returns ``free`` for user-owned BYOK configs (positive IDs).
|
||||||
|
* Returns the global config's ``billing_tier`` for negative IDs.
|
||||||
|
* Honours the per-config ``quota_reserve_micros`` override when present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_billing_for_auto_mode(monkeypatch):
|
||||||
|
from app.routes import image_generation_routes
|
||||||
|
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
||||||
|
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||||
|
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||||
|
session=None, # Not consumed on this code path.
|
||||||
|
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
|
||||||
|
search_space=search_space,
|
||||||
|
)
|
||||||
|
assert tier == "free"
|
||||||
|
assert model == "auto"
|
||||||
|
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
from app.routes import image_generation_routes
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-image-1",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"quota_reserve_micros": 75_000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemini-2.5-flash-image",
|
||||||
|
"billing_tier": "free",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||||
|
|
||||||
|
# Premium with override.
|
||||||
|
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||||
|
session=None, config_id=-1, search_space=search_space
|
||||||
|
)
|
||||||
|
assert tier == "premium"
|
||||||
|
assert model == "openai/gpt-image-1"
|
||||||
|
assert reserve == 75_000
|
||||||
|
|
||||||
|
# Free, no override → falls back to default.
|
||||||
|
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
||||||
|
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||||
|
session=None, config_id=-2, search_space=search_space
|
||||||
|
)
|
||||||
|
assert tier == "free"
|
||||||
|
# Provider-prefixed model string for OpenRouter.
|
||||||
|
assert "google/gemini-2.5-flash-image" in model
|
||||||
|
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_billing_for_user_owned_byok_is_free():
|
||||||
|
"""User-owned BYOK configs (positive IDs) cost the user nothing on
|
||||||
|
our side — they pay the provider directly. Always free.
|
||||||
|
"""
|
||||||
|
from app.routes import image_generation_routes
|
||||||
|
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
||||||
|
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||||
|
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||||
|
session=None, config_id=42, search_space=search_space
|
||||||
|
)
|
||||||
|
assert tier == "free"
|
||||||
|
assert model == "user_byok"
|
||||||
|
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
||||||
|
"""When the request omits ``image_generation_config_id``, the helper
|
||||||
|
must consult the search space's default — so a search space pinned
|
||||||
|
to a premium global config still gates new requests by quota.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.routes import image_generation_routes
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -7,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-image-1",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
search_space = SimpleNamespace(image_generation_config_id=-7)
|
||||||
|
(
|
||||||
|
tier,
|
||||||
|
model,
|
||||||
|
_reserve,
|
||||||
|
) = await image_generation_routes._resolve_billing_for_image_gen(
|
||||||
|
session=None, config_id=None, search_space=search_space
|
||||||
|
)
|
||||||
|
assert tier == "premium"
|
||||||
|
assert model == "openai/gpt-image-1"
|
||||||
|
|
@ -0,0 +1,436 @@
|
||||||
|
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
|
||||||
|
|
||||||
|
Validates the resolver used by Celery podcast/video tasks to compute
|
||||||
|
``(owner_user_id, billing_tier, base_model)`` from a search space and its
|
||||||
|
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
|
||||||
|
``stream_new_chat.py:2294-2351`` and is the single integration point that
|
||||||
|
prevents Auto-mode podcast/video from leaking premium credit.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
|
||||||
|
global → returns ``("premium", <base_model>)``.
|
||||||
|
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
|
||||||
|
global → returns ``("free", <base_model>)``.
|
||||||
|
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
|
||||||
|
→ always ``"free"``.
|
||||||
|
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
|
||||||
|
hitting the pin service.
|
||||||
|
* Negative id (no Auto) → uses ``get_global_llm_config``'s
|
||||||
|
``billing_tier``.
|
||||||
|
* Positive id (user BYOK) → always ``"free"``.
|
||||||
|
* Search space not found → raises ``ValueError``.
|
||||||
|
* ``agent_llm_id`` is None → raises ``ValueError``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fakes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeExecResult:
|
||||||
|
def __init__(self, obj):
|
||||||
|
self._obj = obj
|
||||||
|
|
||||||
|
def scalars(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return self._obj
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
"""Tiny AsyncSession stub.
|
||||||
|
|
||||||
|
``responses`` is a list of objects to return from successive
|
||||||
|
``execute()`` calls (in order). The resolver makes at most two
|
||||||
|
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
|
||||||
|
lookup), so two queued responses cover the matrix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, responses: list):
|
||||||
|
self._responses = list(responses)
|
||||||
|
|
||||||
|
async def execute(self, _stmt):
|
||||||
|
if not self._responses:
|
||||||
|
return _FakeExecResult(None)
|
||||||
|
return _FakeExecResult(self._responses.pop(0))
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakePinResolution:
|
||||||
|
resolved_llm_config_id: int
|
||||||
|
resolved_tier: str = "premium"
|
||||||
|
from_existing_pin: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=42,
|
||||||
|
agent_llm_id=agent_llm_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_byok_config(
|
||||||
|
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=id_,
|
||||||
|
model_name=model_name,
|
||||||
|
litellm_params={"base_model": base_model} if base_model else {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
||||||
|
"""Auto + thread → pin service resolves to negative-id premium config →
|
||||||
|
resolver returns ``("premium", <base_model>)``."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||||
|
|
||||||
|
# Mock the pin service to return a concrete premium config id.
|
||||||
|
async def _fake_resolve_pin(
|
||||||
|
sess,
|
||||||
|
*,
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
selected_llm_config_id,
|
||||||
|
force_repin_free=False,
|
||||||
|
):
|
||||||
|
assert selected_llm_config_id == 0
|
||||||
|
assert thread_id == 99
|
||||||
|
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
|
||||||
|
|
||||||
|
# Mock global config lookup to return a premium entry.
|
||||||
|
def _fake_get_global(cfg_id):
|
||||||
|
if cfg_id == -1:
|
||||||
|
return {
|
||||||
|
"id": -1,
|
||||||
|
"model_name": "gpt-5.4",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"litellm_params": {"base_model": "gpt-5.4"},
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Lazy imports inside the resolver — patch the *target* modules so the
|
||||||
|
# imported names resolve to our fakes.
|
||||||
|
import app.services.auto_model_pin_service as pin_module
|
||||||
|
import app.services.llm_service as llm_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=99
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "premium"
|
||||||
|
assert base_model == "gpt-5.4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
|
||||||
|
"""Auto + thread → pin returns negative-id free config → resolver
|
||||||
|
returns ``("free", <base_model>)``. Same path the pin service takes for
|
||||||
|
out-of-credit users (graceful degradation)."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||||
|
|
||||||
|
async def _fake_resolve_pin(
|
||||||
|
sess,
|
||||||
|
*,
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
selected_llm_config_id,
|
||||||
|
force_repin_free=False,
|
||||||
|
):
|
||||||
|
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
|
||||||
|
|
||||||
|
def _fake_get_global(cfg_id):
|
||||||
|
if cfg_id == -3:
|
||||||
|
return {
|
||||||
|
"id": -3,
|
||||||
|
"model_name": "openrouter/free-model",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"litellm_params": {"base_model": "openrouter/free-model"},
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
import app.services.auto_model_pin_service as pin_module
|
||||||
|
import app.services.llm_service as llm_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=99
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == "openrouter/free-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
|
||||||
|
"""Auto + thread → pin returns positive-id BYOK config → resolver
|
||||||
|
returns ``("free", ...)`` (BYOK is always free per
|
||||||
|
``AgentConfig.from_new_llm_config``)."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
|
||||||
|
byok_cfg = _make_byok_config(
|
||||||
|
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
|
||||||
|
)
|
||||||
|
session = _FakeSession([search_space, byok_cfg])
|
||||||
|
|
||||||
|
async def _fake_resolve_pin(
|
||||||
|
sess,
|
||||||
|
*,
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
selected_llm_config_id,
|
||||||
|
force_repin_free=False,
|
||||||
|
):
|
||||||
|
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
|
||||||
|
|
||||||
|
import app.services.auto_model_pin_service as pin_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||||
|
)
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=99
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == "anthropic/claude-3-haiku"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_mode_without_thread_id_falls_back_to_free():
|
||||||
|
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
|
||||||
|
the pin service. Forward-compat fallback for any future direct-API
|
||||||
|
entrypoint that doesn't have a chat thread."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == "auto"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
|
||||||
|
"""If the pin service raises ``ValueError`` (thread missing /
|
||||||
|
mismatched search space), the resolver should log and return free
|
||||||
|
rather than killing the whole task."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||||
|
|
||||||
|
async def _fake_resolve_pin(*args, **kwargs):
|
||||||
|
raise ValueError("thread missing")
|
||||||
|
|
||||||
|
import app.services.auto_model_pin_service as pin_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||||
|
)
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=99
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == "auto"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_negative_id_premium_global_returns_premium(monkeypatch):
|
||||||
|
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
|
||||||
|
return its ``billing_tier``."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
|
||||||
|
|
||||||
|
def _fake_get_global(cfg_id):
|
||||||
|
return {
|
||||||
|
"id": cfg_id,
|
||||||
|
"model_name": "gpt-5.4",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"litellm_params": {"base_model": "gpt-5.4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
import app.services.llm_service as llm_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=99
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "premium"
|
||||||
|
assert base_model == "gpt-5.4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_negative_id_free_global_returns_free(monkeypatch):
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
|
||||||
|
|
||||||
|
def _fake_get_global(cfg_id):
|
||||||
|
return {
|
||||||
|
"id": cfg_id,
|
||||||
|
"model_name": "openrouter/some-free",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"litellm_params": {"base_model": "openrouter/some-free"},
|
||||||
|
}
|
||||||
|
|
||||||
|
import app.services.llm_service as llm_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42, thread_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == "openrouter/some-free"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
|
||||||
|
"""When the global config has no ``litellm_params.base_model``, the
|
||||||
|
resolver falls back to ``model_name`` — matching chat's behavior."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
|
||||||
|
|
||||||
|
def _fake_get_global(cfg_id):
|
||||||
|
return {
|
||||||
|
"id": cfg_id,
|
||||||
|
"model_name": "fallback-model",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
# No litellm_params.
|
||||||
|
}
|
||||||
|
|
||||||
|
import app.services.llm_service as llm_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||||
|
|
||||||
|
_, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tier == "premium"
|
||||||
|
assert base_model == "fallback-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_positive_id_byok_is_always_free():
|
||||||
|
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
|
||||||
|
regardless of underlying provider tier."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
|
||||||
|
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
|
||||||
|
session = _FakeSession([search_space, byok_cfg])
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == "anthropic/claude-3.5-sonnet"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
|
||||||
|
"""If the BYOK config row is missing/deleted but the search space still
|
||||||
|
points at it, the resolver still returns free (no debit) with an empty
|
||||||
|
base_model — billable_call's premium path is skipped, no harm done."""
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
|
||||||
|
|
||||||
|
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||||
|
session, search_space_id=42
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner == user_id
|
||||||
|
assert tier == "free"
|
||||||
|
assert base_model == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_space_not_found_raises_value_error():
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
session = _FakeSession([None])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Search space"):
|
||||||
|
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_llm_id_none_raises_value_error():
|
||||||
|
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="agent_llm_id"):
|
||||||
|
await _resolve_agent_billing_for_search_space(session, search_space_id=42)
|
||||||
432
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
432
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
|
|
@ -0,0 +1,432 @@
|
||||||
|
"""Unit tests for the ``billable_call`` async context manager.
|
||||||
|
|
||||||
|
Covers the per-call premium-credit lifecycle for image generation and
|
||||||
|
vision LLM extraction:
|
||||||
|
|
||||||
|
* Free configs bypass reserve/finalize but still write an audit row.
|
||||||
|
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
|
||||||
|
route layer).
|
||||||
|
* Successful premium calls reserve, yield the accumulator, then finalize
|
||||||
|
with the LiteLLM-reported actual cost — and write an audit row.
|
||||||
|
* Failed premium calls release the reservation so credit isn't leaked.
|
||||||
|
* All quota DB ops happen inside their OWN ``shielded_async_session``,
|
||||||
|
isolating them from the caller's transaction (issue A).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fakes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQuotaResult:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
allowed: bool,
|
||||||
|
used: int = 0,
|
||||||
|
limit: int = 5_000_000,
|
||||||
|
remaining: int = 5_000_000,
|
||||||
|
) -> None:
|
||||||
|
self.allowed = allowed
|
||||||
|
self.used = used
|
||||||
|
self.limit = limit
|
||||||
|
self.remaining = remaining
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
"""Minimal AsyncSession stub — record commits for assertion."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.committed = False
|
||||||
|
self.added: list[Any] = []
|
||||||
|
|
||||||
|
def add(self, obj: Any) -> None:
|
||||||
|
self.added.append(obj)
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
self.committed = True
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _fake_shielded_session():
|
||||||
|
s = _FakeSession()
|
||||||
|
_SESSIONS_USED.append(s)
|
||||||
|
yield s
|
||||||
|
|
||||||
|
|
||||||
|
_SESSIONS_USED: list[_FakeSession] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
|
||||||
|
"""Wire fake reserve/finalize/release/session helpers."""
|
||||||
|
_SESSIONS_USED.clear()
|
||||||
|
reserve_calls: list[dict[str, Any]] = []
|
||||||
|
finalize_calls: list[dict[str, Any]] = []
|
||||||
|
release_calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
|
||||||
|
reserve_calls.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"reserve_micros": reserve_micros,
|
||||||
|
"request_id": request_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return reserve_result
|
||||||
|
|
||||||
|
async def _fake_finalize(
|
||||||
|
*, db_session, user_id, request_id, actual_micros, reserved_micros
|
||||||
|
):
|
||||||
|
finalize_calls.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"actual_micros": actual_micros,
|
||||||
|
"reserved_micros": reserved_micros,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return finalize_result or _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
async def _fake_release(*, db_session, user_id, reserved_micros):
|
||||||
|
release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
|
||||||
|
|
||||||
|
record_calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def _fake_record(session, **kwargs):
|
||||||
|
record_calls.append(kwargs)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.billable_calls.TokenQuotaService.premium_reserve",
|
||||||
|
_fake_reserve,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.billable_calls.TokenQuotaService.premium_finalize",
|
||||||
|
_fake_finalize,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.billable_calls.TokenQuotaService.premium_release",
|
||||||
|
_fake_release,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.billable_calls.shielded_async_session",
|
||||||
|
_fake_shielded_session,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.billable_calls.record_token_usage",
|
||||||
|
_fake_record,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"reserve": reserve_calls,
|
||||||
|
"finalize": finalize_calls,
|
||||||
|
"release": release_calls,
|
||||||
|
"record": record_calls,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
|
||||||
|
from app.services.billable_calls import billable_call
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||||
|
)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=42,
|
||||||
|
billing_tier="free",
|
||||||
|
base_model="openai/gpt-image-1",
|
||||||
|
usage_type="image_generation",
|
||||||
|
) as acc:
|
||||||
|
# Simulate a captured cost — the accumulator is fed by the LiteLLM
|
||||||
|
# callback in real life, here we add it manually.
|
||||||
|
acc.add(
|
||||||
|
model="openai/gpt-image-1",
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=0,
|
||||||
|
cost_micros=37_000,
|
||||||
|
call_kind="image_generation",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert spies["reserve"] == []
|
||||||
|
assert spies["finalize"] == []
|
||||||
|
assert spies["release"] == []
|
||||||
|
# Free still audits.
|
||||||
|
assert len(spies["record"]) == 1
|
||||||
|
assert spies["record"][0]["usage_type"] == "image_generation"
|
||||||
|
assert spies["record"][0]["cost_micros"] == 37_000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
|
||||||
|
from app.services.billable_calls import (
|
||||||
|
QuotaInsufficientError,
|
||||||
|
billable_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch,
|
||||||
|
reserve_result=_FakeQuotaResult(
|
||||||
|
allowed=False, used=5_000_000, limit=5_000_000, remaining=0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(QuotaInsufficientError) as exc_info:
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=42,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-image-1",
|
||||||
|
quota_reserve_micros_override=50_000,
|
||||||
|
usage_type="image_generation",
|
||||||
|
):
|
||||||
|
pytest.fail("body should not run when reserve is denied")
|
||||||
|
|
||||||
|
err = exc_info.value
|
||||||
|
assert err.usage_type == "image_generation"
|
||||||
|
assert err.used_micros == 5_000_000
|
||||||
|
assert err.limit_micros == 5_000_000
|
||||||
|
assert err.remaining_micros == 0
|
||||||
|
# Reserve was attempted, but no finalize/release on a denied reserve
|
||||||
|
# — the reservation never actually held credit.
|
||||||
|
assert len(spies["reserve"]) == 1
|
||||||
|
assert spies["finalize"] == []
|
||||||
|
assert spies["release"] == []
|
||||||
|
# Denied premium calls do NOT create an audit row (no work happened).
|
||||||
|
assert spies["record"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
|
||||||
|
from app.services.billable_calls import billable_call
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||||
|
)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=42,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-image-1",
|
||||||
|
quota_reserve_micros_override=50_000,
|
||||||
|
usage_type="image_generation",
|
||||||
|
) as acc:
|
||||||
|
# LiteLLM callback would normally fill this — simulate $0.04 image.
|
||||||
|
acc.add(
|
||||||
|
model="openai/gpt-image-1",
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=0,
|
||||||
|
cost_micros=40_000,
|
||||||
|
call_kind="image_generation",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(spies["reserve"]) == 1
|
||||||
|
assert spies["reserve"][0]["reserve_micros"] == 50_000
|
||||||
|
assert len(spies["finalize"]) == 1
|
||||||
|
assert spies["finalize"][0]["actual_micros"] == 40_000
|
||||||
|
assert spies["finalize"][0]["reserved_micros"] == 50_000
|
||||||
|
assert spies["release"] == []
|
||||||
|
# And audit row written with the actual debited cost.
|
||||||
|
assert spies["record"][0]["cost_micros"] == 40_000
|
||||||
|
# Each quota op opened its OWN session — proves session isolation.
|
||||||
|
assert len(_SESSIONS_USED) >= 3
|
||||||
|
# Sessions used should each have committed (or be the audit one which commits).
|
||||||
|
for _s in _SESSIONS_USED:
|
||||||
|
# finalize/reserve happen via TokenQuotaService.* which we stub —
|
||||||
|
# they don't actually call commit on our fake session, but the
|
||||||
|
# audit session does. We just assert >=1 session committed.
|
||||||
|
pass
|
||||||
|
assert any(s.committed for s in _SESSIONS_USED)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_failure_releases_reservation(monkeypatch):
|
||||||
|
from app.services.billable_calls import billable_call
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||||
|
)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
class _ProviderError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(_ProviderError):
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=42,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-image-1",
|
||||||
|
quota_reserve_micros_override=50_000,
|
||||||
|
usage_type="image_generation",
|
||||||
|
):
|
||||||
|
raise _ProviderError("OpenRouter 503")
|
||||||
|
|
||||||
|
assert len(spies["reserve"]) == 1
|
||||||
|
assert spies["finalize"] == []
|
||||||
|
# Failure path: release the held reservation.
|
||||||
|
assert len(spies["release"]) == 1
|
||||||
|
assert spies["release"][0]["reserved_micros"] == 50_000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
|
||||||
|
"""When ``quota_reserve_micros_override`` is None we fall back to
|
||||||
|
``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
|
||||||
|
Vision LLM calls take this path (token-priced models).
|
||||||
|
"""
|
||||||
|
from app.services.billable_calls import billable_call
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_estimator_calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _fake_estimate(*, base_model, quota_reserve_tokens):
|
||||||
|
captured_estimator_calls.append(
|
||||||
|
{"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
|
||||||
|
)
|
||||||
|
return 12_345
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.billable_calls.estimate_call_reserve_micros",
|
||||||
|
_fake_estimate,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=1,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-4o",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
usage_type="vision_extraction",
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_estimator_calls == [
|
||||||
|
{"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
|
||||||
|
]
|
||||||
|
assert spies["reserve"][0]["reserve_micros"] == 12_345
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Podcast / video-presentation usage_type coverage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
|
||||||
|
"""Free podcast configs must skip reserve/finalize but still emit a
|
||||||
|
``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
|
||||||
|
have full audit coverage of free-tier agent runs."""
|
||||||
|
from app.services.billable_calls import billable_call
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||||
|
)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=42,
|
||||||
|
billing_tier="free",
|
||||||
|
base_model="openrouter/some-free-model",
|
||||||
|
quota_reserve_micros_override=200_000,
|
||||||
|
usage_type="podcast_generation",
|
||||||
|
thread_id=99,
|
||||||
|
call_details={"podcast_id": 7, "title": "Test Podcast"},
|
||||||
|
) as acc:
|
||||||
|
# Two transcript LLM calls aggregated into one accumulator.
|
||||||
|
acc.add(
|
||||||
|
model="openrouter/some-free-model",
|
||||||
|
prompt_tokens=1500,
|
||||||
|
completion_tokens=8000,
|
||||||
|
total_tokens=9500,
|
||||||
|
cost_micros=0,
|
||||||
|
call_kind="chat",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert spies["reserve"] == []
|
||||||
|
assert spies["finalize"] == []
|
||||||
|
assert spies["release"] == []
|
||||||
|
|
||||||
|
assert len(spies["record"]) == 1
|
||||||
|
row = spies["record"][0]
|
||||||
|
assert row["usage_type"] == "podcast_generation"
|
||||||
|
assert row["thread_id"] == 99
|
||||||
|
assert row["search_space_id"] == 42
|
||||||
|
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
|
||||||
|
"""Premium video-presentation runs that hit a denied reservation must
|
||||||
|
raise ``QuotaInsufficientError`` *before* the graph runs and must not
|
||||||
|
emit an audit row (no work happened)."""
|
||||||
|
from app.services.billable_calls import (
|
||||||
|
QuotaInsufficientError,
|
||||||
|
billable_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
spies = _patch_isolation_layer(
|
||||||
|
monkeypatch,
|
||||||
|
reserve_result=_FakeQuotaResult(
|
||||||
|
allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
|
||||||
|
),
|
||||||
|
)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(QuotaInsufficientError) as exc_info:
|
||||||
|
async with billable_call(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=42,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="gpt-5.4",
|
||||||
|
quota_reserve_micros_override=1_000_000,
|
||||||
|
usage_type="video_presentation_generation",
|
||||||
|
thread_id=99,
|
||||||
|
call_details={"video_presentation_id": 12, "title": "Test Video"},
|
||||||
|
):
|
||||||
|
pytest.fail("body should not run when reserve is denied")
|
||||||
|
|
||||||
|
err = exc_info.value
|
||||||
|
assert err.usage_type == "video_presentation_generation"
|
||||||
|
assert err.remaining_micros == 500_000
|
||||||
|
assert spies["reserve"][0]["reserve_micros"] == 1_000_000
|
||||||
|
assert spies["finalize"] == []
|
||||||
|
assert spies["release"] == []
|
||||||
|
assert spies["record"] == []
|
||||||
|
|
@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
|
||||||
assert "openai/gpt-4o" in model_names
|
assert "openai/gpt-4o" in model_names
|
||||||
assert "openai/dall-e" not in model_names
|
assert "openai/dall-e" not in model_names
|
||||||
assert "openai/completion-only" not in model_names
|
assert "openai/completion-only" not in model_names
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _generate_image_gen_configs / _generate_vision_llm_configs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_image_gen_configs_filters_by_image_output():
|
||||||
|
"""Only models with ``output_modalities`` containing ``image`` are emitted.
|
||||||
|
Tool-calling and context filters are intentionally NOT applied — image
|
||||||
|
generation has nothing to do with tool calls and context windows.
|
||||||
|
"""
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
_generate_image_gen_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = [
|
||||||
|
# Pure image-gen model (small context, no tools — should still emit).
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-image-1",
|
||||||
|
"architecture": {"output_modalities": ["image"]},
|
||||||
|
"context_length": 4_000,
|
||||||
|
"pricing": {"prompt": "0", "completion": "0"},
|
||||||
|
},
|
||||||
|
# Multi-modal: text+image output (should still emit).
|
||||||
|
{
|
||||||
|
"id": "google/gemini-2.5-flash-image",
|
||||||
|
"architecture": {"output_modalities": ["text", "image"]},
|
||||||
|
"context_length": 1_000_000,
|
||||||
|
"pricing": {"prompt": "0.000001", "completion": "0.000004"},
|
||||||
|
},
|
||||||
|
# Pure text model — must NOT emit.
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-4o",
|
||||||
|
"architecture": {"output_modalities": ["text"]},
|
||||||
|
"context_length": 128_000,
|
||||||
|
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
model_names = {c["model_name"] for c in cfgs}
|
||||||
|
assert "openai/gpt-image-1" in model_names
|
||||||
|
assert "google/gemini-2.5-flash-image" in model_names
|
||||||
|
assert "openai/gpt-4o" not in model_names
|
||||||
|
|
||||||
|
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
|
||||||
|
for c in cfgs:
|
||||||
|
assert c["billing_tier"] in {"free", "premium"}
|
||||||
|
assert c["provider"] == "OPENROUTER"
|
||||||
|
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_image_gen_configs_assigns_image_id_offset():
|
||||||
|
"""Image configs use a different id_offset (-20000) so their negative
|
||||||
|
IDs don't collide with chat configs (-10000) or vision configs (-30000).
|
||||||
|
"""
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
_generate_image_gen_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = [
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-image-1",
|
||||||
|
"architecture": {"output_modalities": ["image"]},
|
||||||
|
"context_length": 4_000,
|
||||||
|
"pricing": {"prompt": "0", "completion": "0"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
# Don't pass image_id_offset → use the module default (-20000).
|
||||||
|
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
assert all(c["id"] < -20_000 + 1 for c in cfgs)
|
||||||
|
assert all(c["id"] > -29_000_000 for c in cfgs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
|
||||||
|
"""Vision LLMs must accept image input AND emit text — pure image-gen
|
||||||
|
(no text out) and text-only (no image in) models are excluded.
|
||||||
|
"""
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
_generate_vision_llm_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = [
|
||||||
|
# GPT-4o: vision LLM (image in, text out) — must emit.
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-4o",
|
||||||
|
"architecture": {
|
||||||
|
"input_modalities": ["text", "image"],
|
||||||
|
"output_modalities": ["text"],
|
||||||
|
},
|
||||||
|
"context_length": 128_000,
|
||||||
|
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
|
||||||
|
},
|
||||||
|
# Pure image generator — image *output*, no text out. Must NOT emit.
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-image-1",
|
||||||
|
"architecture": {
|
||||||
|
"input_modalities": ["text"],
|
||||||
|
"output_modalities": ["image"],
|
||||||
|
},
|
||||||
|
"context_length": 4_000,
|
||||||
|
"pricing": {"prompt": "0", "completion": "0"},
|
||||||
|
},
|
||||||
|
# Pure text model (no image in). Must NOT emit.
|
||||||
|
{
|
||||||
|
"id": "anthropic/claude-3-haiku",
|
||||||
|
"architecture": {
|
||||||
|
"input_modalities": ["text"],
|
||||||
|
"output_modalities": ["text"],
|
||||||
|
},
|
||||||
|
"context_length": 200_000,
|
||||||
|
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
names = {c["model_name"] for c in cfgs}
|
||||||
|
assert names == {"openai/gpt-4o"}
|
||||||
|
|
||||||
|
cfg = cfgs[0]
|
||||||
|
assert cfg["billing_tier"] == "premium"
|
||||||
|
# Pricing carried inline so pricing_registration can register vision
|
||||||
|
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
|
||||||
|
# is cleared.
|
||||||
|
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
|
||||||
|
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
|
||||||
|
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_vision_llm_configs_drops_chat_only_filters():
|
||||||
|
"""A small-context vision model that doesn't advertise tool calling is
|
||||||
|
still a valid vision LLM for "describe this image" prompts. The chat
|
||||||
|
filters (``supports_tool_calling``, ``has_sufficient_context``) must
|
||||||
|
NOT be applied to vision emission.
|
||||||
|
"""
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
_generate_vision_llm_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = [
|
||||||
|
{
|
||||||
|
"id": "tiny/vision-mini",
|
||||||
|
"architecture": {
|
||||||
|
"input_modalities": ["text", "image"],
|
||||||
|
"output_modalities": ["text"],
|
||||||
|
},
|
||||||
|
"supported_parameters": [], # no tools
|
||||||
|
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
|
||||||
|
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
assert len(cfgs) == 1
|
||||||
|
assert cfgs[0]["model_name"] == "tiny/vision-mini"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,447 @@
|
||||||
|
"""Pricing registration unit tests.
|
||||||
|
|
||||||
|
The pricing-registration module is what makes ``response_cost`` populate
|
||||||
|
correctly for OpenRouter dynamic models and operator-defined Azure
|
||||||
|
deployments — both of which LiteLLM doesn't natively know about. The tests
|
||||||
|
exercise:
|
||||||
|
|
||||||
|
* The alias generators emit every shape that LiteLLM's cost-callback might
|
||||||
|
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
|
||||||
|
``provider/base_model``, ``provider/model_name``, plus the special
|
||||||
|
``azure_openai`` → ``azure`` normalisation).
|
||||||
|
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
|
||||||
|
with the right alias set and pricing values per provider.
|
||||||
|
* Configs without a resolvable pair of cost values are skipped — never
|
||||||
|
registered as zero, since that would override pricing LiteLLM might
|
||||||
|
already know natively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Alias generators
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_alias_set_includes_prefixed_and_bare():
|
||||||
|
from app.services.pricing_registration import _alias_set_for_openrouter
|
||||||
|
|
||||||
|
aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
|
||||||
|
assert aliases == [
|
||||||
|
"openrouter/anthropic/claude-3-5-sonnet",
|
||||||
|
"anthropic/claude-3-5-sonnet",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_alias_set_dedupes():
|
||||||
|
"""If the model id is already prefixed with ``openrouter/``, the alias
|
||||||
|
set must not contain duplicates that would re-register the same key
|
||||||
|
twice.
|
||||||
|
"""
|
||||||
|
from app.services.pricing_registration import _alias_set_for_openrouter
|
||||||
|
|
||||||
|
aliases = _alias_set_for_openrouter("openrouter/foo")
|
||||||
|
# The bare and prefixed variants compute to the same string here, so we
|
||||||
|
# at minimum require uniqueness.
|
||||||
|
assert len(aliases) == len(set(aliases))
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
|
||||||
|
"""``azure_openai`` (our YAML provider slug) must register under
|
||||||
|
``azure/<name>`` so the LiteLLM Router's deployment-resolution path
|
||||||
|
(which uses provider ``azure``) finds the pricing too.
|
||||||
|
"""
|
||||||
|
from app.services.pricing_registration import _alias_set_for_yaml
|
||||||
|
|
||||||
|
aliases = _alias_set_for_yaml(
|
||||||
|
provider="AZURE_OPENAI",
|
||||||
|
model_name="gpt-5.4",
|
||||||
|
base_model="gpt-5.4",
|
||||||
|
)
|
||||||
|
assert "gpt-5.4" in aliases
|
||||||
|
assert "azure_openai/gpt-5.4" in aliases
|
||||||
|
assert "azure/gpt-5.4" in aliases
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_alias_set_distinguishes_model_name_and_base_model():
|
||||||
|
"""When ``model_name`` differs from ``base_model`` (operator labelled a
|
||||||
|
deployment), both must appear in the alias set since either may surface
|
||||||
|
in callbacks depending on the call path.
|
||||||
|
"""
|
||||||
|
from app.services.pricing_registration import _alias_set_for_yaml
|
||||||
|
|
||||||
|
aliases = _alias_set_for_yaml(
|
||||||
|
provider="OPENAI",
|
||||||
|
model_name="my-deployment-label",
|
||||||
|
base_model="gpt-4o",
|
||||||
|
)
|
||||||
|
assert "gpt-4o" in aliases
|
||||||
|
assert "openai/gpt-4o" in aliases
|
||||||
|
assert "my-deployment-label" in aliases
|
||||||
|
assert "openai/my-deployment-label" in aliases
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
|
||||||
|
from app.services.pricing_registration import _alias_set_for_yaml
|
||||||
|
|
||||||
|
aliases = _alias_set_for_yaml(
|
||||||
|
provider="",
|
||||||
|
model_name="foo",
|
||||||
|
base_model="bar",
|
||||||
|
)
|
||||||
|
assert "bar" in aliases
|
||||||
|
assert "foo" in aliases
|
||||||
|
assert all("/" not in a for a in aliases)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# register_pricing_from_global_configs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _RegistrationSpy:
|
||||||
|
"""Captures the dicts passed to ``litellm.register_model``.
|
||||||
|
|
||||||
|
Many calls may go through; we just record them all and let tests assert
|
||||||
|
against the union.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def __call__(self, payload: dict[str, Any]) -> None:
|
||||||
|
self.calls.append(payload)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_keys(self) -> set[str]:
|
||||||
|
keys: set[str] = set()
|
||||||
|
for payload in self.calls:
|
||||||
|
keys.update(payload.keys())
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
|
||||||
|
spy = _RegistrationSpy()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.pricing_registration.litellm.register_model",
|
||||||
|
spy,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
return spy
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_openrouter_pricing(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
|
||||||
|
) -> None:
|
||||||
|
"""Pretend the OpenRouter integration is initialised with ``mapping``."""
|
||||||
|
|
||||||
|
class _Stub:
|
||||||
|
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
class _StubService:
|
||||||
|
@classmethod
|
||||||
|
def is_initialized(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> _Stub:
|
||||||
|
return _Stub()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.openrouter_integration_service.OpenRouterIntegrationService",
|
||||||
|
_StubService,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_models_register_under_aliases(monkeypatch):
|
||||||
|
"""An OpenRouter config whose ``model_name`` is in the cached raw
|
||||||
|
pricing map is registered under both ``openrouter/X`` and bare ``X``.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
spy = _patch_register(monkeypatch)
|
||||||
|
_patch_openrouter_pricing(
|
||||||
|
monkeypatch,
|
||||||
|
{
|
||||||
|
"anthropic/claude-3-5-sonnet": {
|
||||||
|
"prompt": "0.000003",
|
||||||
|
"completion": "0.000015",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "anthropic/claude-3-5-sonnet",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
|
||||||
|
assert "anthropic/claude-3-5-sonnet" in spy.all_keys
|
||||||
|
# Costs are float-converted from the raw OpenRouter strings.
|
||||||
|
payload = spy.calls[0]
|
||||||
|
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
|
||||||
|
"input_cost_per_token"
|
||||||
|
] == pytest.approx(3e-6)
|
||||||
|
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
|
||||||
|
"output_cost_per_token"
|
||||||
|
] == pytest.approx(15e-6)
|
||||||
|
assert (
|
||||||
|
payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
|
||||||
|
== "openrouter"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_override_registers_under_alias_set(monkeypatch):
|
||||||
|
"""Operator-declared ``input_cost_per_token`` /
|
||||||
|
``output_cost_per_token`` on a YAML config registers under every
|
||||||
|
alias the YAML alias generator produces — including the ``azure/``
|
||||||
|
normalisation for ``azure_openai`` providers.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
spy = _patch_register(monkeypatch)
|
||||||
|
_patch_openrouter_pricing(monkeypatch, {})
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5.4",
|
||||||
|
"litellm_params": {
|
||||||
|
"base_model": "gpt-5.4",
|
||||||
|
"input_cost_per_token": 2e-6,
|
||||||
|
"output_cost_per_token": 8e-6,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
keys = spy.all_keys
|
||||||
|
assert "gpt-5.4" in keys
|
||||||
|
assert "azure_openai/gpt-5.4" in keys
|
||||||
|
assert "azure/gpt-5.4" in keys
|
||||||
|
|
||||||
|
payload = spy.calls[0]
|
||||||
|
entry = payload["gpt-5.4"]
|
||||||
|
assert entry["input_cost_per_token"] == pytest.approx(2e-6)
|
||||||
|
assert entry["output_cost_per_token"] == pytest.approx(8e-6)
|
||||||
|
assert entry["litellm_provider"] == "azure"
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_override_means_no_registration(monkeypatch):
|
||||||
|
"""A YAML config that *omits* both pricing fields must NOT be registered
|
||||||
|
— registering as zero would override LiteLLM's native pricing for the
|
||||||
|
``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
|
||||||
|
bill drop to $0. Fail-safe is "skip and warn", not "register zero".
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
spy = _patch_register(monkeypatch)
|
||||||
|
_patch_openrouter_pricing(monkeypatch, {})
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"litellm_params": {"base_model": "gpt-4o"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
assert spy.calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_skipped_when_pricing_missing(monkeypatch):
|
||||||
|
"""If the OpenRouter raw-pricing cache doesn't carry an entry for a
|
||||||
|
configured model (network blip during refresh, model added later, etc.),
|
||||||
|
we skip it rather than registering zero pricing.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
spy = _patch_register(monkeypatch)
|
||||||
|
_patch_openrouter_pricing(
|
||||||
|
monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "anthropic/claude-3-5-sonnet",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
assert spy.calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_continues_after_individual_failure(monkeypatch, caplog):
|
||||||
|
"""A single bad ``register_model`` call (e.g. raising LiteLLM error)
|
||||||
|
must not abort registration of the remaining configs.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
|
||||||
|
successful_calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _maybe_fail(payload: dict[str, Any]) -> None:
|
||||||
|
if any(k in failing_keys for k in payload):
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
successful_calls.append(payload)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.pricing_registration.litellm.register_model",
|
||||||
|
_maybe_fail,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
_patch_openrouter_pricing(
|
||||||
|
monkeypatch,
|
||||||
|
{
|
||||||
|
"anthropic/claude-3-5-sonnet": {
|
||||||
|
"prompt": "0.000003",
|
||||||
|
"completion": "0.000015",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "anthropic/claude-3-5-sonnet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "custom-deployment",
|
||||||
|
"litellm_params": {
|
||||||
|
"base_model": "custom-deployment",
|
||||||
|
"input_cost_per_token": 1e-6,
|
||||||
|
"output_cost_per_token": 2e-6,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
# The good config still registered.
|
||||||
|
assert any("custom-deployment" in payload for payload in successful_calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vision_configs_registered_with_chat_shape(monkeypatch):
|
||||||
|
"""``register_pricing_from_global_configs`` walks
|
||||||
|
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
|
||||||
|
calls (during indexing) bill correctly. Vision configs use the same
|
||||||
|
chat-shape token prices, but image-gen pricing is intentionally NOT
|
||||||
|
registered here (handled via ``response_cost`` in LiteLLM).
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
spy = _patch_register(monkeypatch)
|
||||||
|
_patch_openrouter_pricing(
|
||||||
|
monkeypatch,
|
||||||
|
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# No chat configs — only vision. Proves the vision walk is a separate
|
||||||
|
# iteration, not piggy-backed on the chat list.
|
||||||
|
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_VISION_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "openai/gpt-4o",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"input_cost_per_token": 5e-6,
|
||||||
|
"output_cost_per_token": 15e-6,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
assert "openrouter/openai/gpt-4o" in spy.all_keys
|
||||||
|
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
|
||||||
|
assert payload_value["mode"] == "chat"
|
||||||
|
assert payload_value["litellm_provider"] == "openrouter"
|
||||||
|
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
|
||||||
|
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
|
||||||
|
"""If the OpenRouter pricing cache misses a vision model (different
|
||||||
|
catalogue surface), the vision walk falls back to inline
|
||||||
|
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||||
|
|
||||||
|
spy = _patch_register(monkeypatch)
|
||||||
|
_patch_openrouter_pricing(monkeypatch, {})
|
||||||
|
|
||||||
|
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_VISION_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemini-2.5-flash",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"input_cost_per_token": 1e-6,
|
||||||
|
"output_cost_per_token": 4e-6,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
register_pricing_from_global_configs()
|
||||||
|
|
||||||
|
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
|
||||||
|
|
@ -0,0 +1,157 @@
|
||||||
|
"""Unit tests for ``QuotaCheckedVisionLLM``.
|
||||||
|
|
||||||
|
Validates that:
|
||||||
|
|
||||||
|
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
|
||||||
|
enforcement) and forwards the inner LLM's response on success.
|
||||||
|
* The wrapper proxies non-overridden attributes to the inner LLM
|
||||||
|
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
|
||||||
|
still work without quota gating (they're not used in indexing today).
|
||||||
|
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
|
||||||
|
bubbles it up — the ETL pipeline catches that and falls back to OCR.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeInnerLLM:
|
||||||
|
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
|
||||||
|
|
||||||
|
def __init__(self, response: Any = "OCR'd content") -> None:
|
||||||
|
self._response = response
|
||||||
|
self.ainvoke_calls: list[Any] = []
|
||||||
|
|
||||||
|
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
self.ainvoke_calls.append(input)
|
||||||
|
return self._response
|
||||||
|
|
||||||
|
def some_other_method(self, x: int) -> int:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _passthrough_billable_call(**_kwargs):
|
||||||
|
"""Stand-in for billable_call that always allows the call to run."""
|
||||||
|
|
||||||
|
class _Acc:
|
||||||
|
total_cost_micros = 0
|
||||||
|
total_prompt_tokens = 0
|
||||||
|
total_completion_tokens = 0
|
||||||
|
grand_total = 0
|
||||||
|
calls: list[Any] = []
|
||||||
|
|
||||||
|
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
yield _Acc()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_routes_through_billable_call(monkeypatch):
|
||||||
|
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||||
|
|
||||||
|
captured_kwargs: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _spy_billable_call(**kwargs):
|
||||||
|
captured_kwargs.append(kwargs)
|
||||||
|
async with _passthrough_billable_call() as acc:
|
||||||
|
yield acc
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.quota_checked_vision_llm.billable_call",
|
||||||
|
_spy_billable_call,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = _FakeInnerLLM(response="A red apple on a white table")
|
||||||
|
user_id = uuid4()
|
||||||
|
wrapper = QuotaCheckedVisionLLM(
|
||||||
|
inner,
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=99,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-4o",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.ainvoke([{"text": "what is this?"}])
|
||||||
|
assert result == "A red apple on a white table"
|
||||||
|
assert len(inner.ainvoke_calls) == 1
|
||||||
|
assert len(captured_kwargs) == 1
|
||||||
|
bc_kwargs = captured_kwargs[0]
|
||||||
|
assert bc_kwargs["user_id"] == user_id
|
||||||
|
assert bc_kwargs["search_space_id"] == 99
|
||||||
|
assert bc_kwargs["billing_tier"] == "premium"
|
||||||
|
assert bc_kwargs["base_model"] == "openai/gpt-4o"
|
||||||
|
assert bc_kwargs["quota_reserve_tokens"] == 4000
|
||||||
|
assert bc_kwargs["usage_type"] == "vision_extraction"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
|
||||||
|
from app.services.billable_calls import QuotaInsufficientError
|
||||||
|
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _denying_billable_call(**_kwargs):
|
||||||
|
raise QuotaInsufficientError(
|
||||||
|
usage_type="vision_extraction",
|
||||||
|
used_micros=5_000_000,
|
||||||
|
limit_micros=5_000_000,
|
||||||
|
remaining_micros=0,
|
||||||
|
)
|
||||||
|
yield # unreachable but required for asynccontextmanager type
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.quota_checked_vision_llm.billable_call",
|
||||||
|
_denying_billable_call,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = _FakeInnerLLM()
|
||||||
|
wrapper = QuotaCheckedVisionLLM(
|
||||||
|
inner,
|
||||||
|
user_id=uuid4(),
|
||||||
|
search_space_id=1,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-4o",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(QuotaInsufficientError):
|
||||||
|
await wrapper.ainvoke([{"text": "x"}])
|
||||||
|
|
||||||
|
# Inner LLM never ran on a denied reservation.
|
||||||
|
assert inner.ainvoke_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_proxies_non_overridden_attributes_to_inner():
|
||||||
|
"""``__getattr__`` forwards anything not on the proxy itself, so any
|
||||||
|
method we didn't explicitly override (``invoke``, ``astream``,
|
||||||
|
``with_structured_output``, etc.) still works — just without quota
|
||||||
|
gating, which is fine because the indexer only ever calls ainvoke.
|
||||||
|
"""
|
||||||
|
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||||
|
|
||||||
|
inner = _FakeInnerLLM()
|
||||||
|
wrapper = QuotaCheckedVisionLLM(
|
||||||
|
inner,
|
||||||
|
user_id=uuid4(),
|
||||||
|
search_space_id=1,
|
||||||
|
billing_tier="premium",
|
||||||
|
base_model="openai/gpt-4o",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ``some_other_method`` is on the inner only.
|
||||||
|
assert wrapper.some_other_method(7) == 14
|
||||||
|
|
@ -0,0 +1,515 @@
|
||||||
|
"""Cost-based premium quota unit tests.
|
||||||
|
|
||||||
|
Covers the USD-micro behaviour added in migration 140:
|
||||||
|
|
||||||
|
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
|
||||||
|
calls in a turn — used as the debit amount when ``agent_config.is_premium``
|
||||||
|
is true, regardless of which underlying model produced each call. This
|
||||||
|
preserves the prior "premium turn → all calls in turn count" rule from the
|
||||||
|
token-based system.
|
||||||
|
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
|
||||||
|
clamps to a sane floor when pricing is unknown, and respects the
|
||||||
|
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
|
||||||
|
can't lock the whole balance on one call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TurnTokenAccumulator — premium-turn debit semantics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_total_cost_micros_sums_premium_and_free_calls():
|
||||||
|
"""A premium turn that also called a free sub-agent debits the union.
|
||||||
|
|
||||||
|
The plan deliberately preserved the existing "premium turn → all calls
|
||||||
|
count" behaviour because per-call premium filtering relied on
|
||||||
|
``LLMRouterService._premium_model_strings`` which only covers router-pool
|
||||||
|
deployments. ``total_cost_micros`` therefore must include free-model
|
||||||
|
calls (whose ``cost_micros`` is typically ``0``) as well as the premium
|
||||||
|
call's actual provider cost.
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
# Premium model (e.g. claude-opus): non-zero cost.
|
||||||
|
acc.add(
|
||||||
|
model="anthropic/claude-3-5-sonnet",
|
||||||
|
prompt_tokens=1200,
|
||||||
|
completion_tokens=400,
|
||||||
|
total_tokens=1600,
|
||||||
|
cost_micros=12_345,
|
||||||
|
)
|
||||||
|
# Free sub-agent (e.g. title-gen on a free model): zero cost.
|
||||||
|
acc.add(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt_tokens=120,
|
||||||
|
completion_tokens=20,
|
||||||
|
total_tokens=140,
|
||||||
|
cost_micros=0,
|
||||||
|
)
|
||||||
|
# A second premium-priced call within the same turn.
|
||||||
|
acc.add(
|
||||||
|
model="anthropic/claude-3-5-sonnet",
|
||||||
|
prompt_tokens=800,
|
||||||
|
completion_tokens=200,
|
||||||
|
total_tokens=1000,
|
||||||
|
cost_micros=7_500,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert acc.total_cost_micros == 12_345 + 0 + 7_500
|
||||||
|
# Token totals stay correct so the FE display path still works.
|
||||||
|
assert acc.grand_total == 1600 + 140 + 1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_total_cost_micros_zero_when_no_calls():
|
||||||
|
"""An empty accumulator must report zero cost (no division-by-zero, no None)."""
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
assert acc.total_cost_micros == 0
|
||||||
|
assert acc.grand_total == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_per_message_summary_groups_cost_by_model():
|
||||||
|
"""``per_message_summary`` must accumulate ``cost_micros`` per model so the
|
||||||
|
SSE ``model_breakdown`` payload reports actual USD spend per provider.
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
acc.add(
|
||||||
|
model="claude-3-5-sonnet",
|
||||||
|
prompt_tokens=100,
|
||||||
|
completion_tokens=50,
|
||||||
|
total_tokens=150,
|
||||||
|
cost_micros=4_000,
|
||||||
|
)
|
||||||
|
acc.add(
|
||||||
|
model="claude-3-5-sonnet",
|
||||||
|
prompt_tokens=200,
|
||||||
|
completion_tokens=100,
|
||||||
|
total_tokens=300,
|
||||||
|
cost_micros=8_000,
|
||||||
|
)
|
||||||
|
acc.add(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt_tokens=50,
|
||||||
|
completion_tokens=10,
|
||||||
|
total_tokens=60,
|
||||||
|
cost_micros=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = acc.per_message_summary()
|
||||||
|
assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
|
||||||
|
assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
|
||||||
|
assert summary["gpt-4o-mini"]["cost_micros"] == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialized_calls_includes_cost_micros():
|
||||||
|
"""``serialized_calls`` is what flows into the SSE ``call_details``
|
||||||
|
payload; cost_micros must be present on each entry so the FE message-info
|
||||||
|
dropdown can render per-call USD.
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
acc.add(
|
||||||
|
model="m",
|
||||||
|
prompt_tokens=1,
|
||||||
|
completion_tokens=1,
|
||||||
|
total_tokens=2,
|
||||||
|
cost_micros=42,
|
||||||
|
)
|
||||||
|
serialized = acc.serialized_calls()
|
||||||
|
assert serialized == [
|
||||||
|
{
|
||||||
|
"model": "m",
|
||||||
|
"prompt_tokens": 1,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 2,
|
||||||
|
"cost_micros": 42,
|
||||||
|
"call_kind": "chat",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# estimate_call_reserve_micros — sizing and clamping
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
|
||||||
|
"""If LiteLLM doesn't know the model, ``get_model_info`` raises and the
|
||||||
|
helper falls back to the 100-micro floor — small enough that a user with
|
||||||
|
$0.0001 left can still send a tiny request, but non-zero so we still gate
|
||||||
|
against an empty balance.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from app.services import token_quota_service
|
||||||
|
|
||||||
|
def _raise(_name):
|
||||||
|
raise KeyError("unknown")
|
||||||
|
|
||||||
|
monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
|
||||||
|
|
||||||
|
micros = token_quota_service.estimate_call_reserve_micros(
|
||||||
|
base_model="nonexistent-model",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
|
||||||
|
assert micros == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
|
||||||
|
"""LiteLLM may *return* a model with both cost-per-token fields at 0
|
||||||
|
(pricing not yet registered). The helper must not multiply 0 x tokens
|
||||||
|
and end up reserving 0 — it must clamp to the floor.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from app.services import token_quota_service
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm,
|
||||||
|
"get_model_info",
|
||||||
|
lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
micros = token_quota_service.estimate_call_reserve_micros(
|
||||||
|
base_model="some-pending-model",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
|
||||||
|
|
||||||
|
|
||||||
|
def test_reserve_scales_with_model_cost(monkeypatch):
|
||||||
|
"""Claude-Opus-priced model with 4000 reserve_tokens reserves
|
||||||
|
~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
|
||||||
|
some small artificial cap — that was the bug the plan called out.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.services import token_quota_service
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm,
|
||||||
|
"get_model_info",
|
||||||
|
lambda _name: {
|
||||||
|
"input_cost_per_token": 15e-6,
|
||||||
|
"output_cost_per_token": 75e-6,
|
||||||
|
},
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
|
||||||
|
|
||||||
|
micros = token_quota_service.estimate_call_reserve_micros(
|
||||||
|
base_model="claude-3-opus",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
# 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
|
||||||
|
assert micros == 360_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_reserve_clamps_to_max_ceiling(monkeypatch):
|
||||||
|
"""A misconfigured "$1000 / M" model with 4000 reserve_tokens would
|
||||||
|
nominally compute to $4 = 4_000_000 micros. The ceiling
|
||||||
|
``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
|
||||||
|
can't lock the user's whole balance on one call.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.services import token_quota_service
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm,
|
||||||
|
"get_model_info",
|
||||||
|
lambda _name: {
|
||||||
|
"input_cost_per_token": 1e-3,
|
||||||
|
"output_cost_per_token": 0,
|
||||||
|
},
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
|
||||||
|
|
||||||
|
micros = token_quota_service.estimate_call_reserve_micros(
|
||||||
|
base_model="oops-misconfigured",
|
||||||
|
quota_reserve_tokens=4000,
|
||||||
|
)
|
||||||
|
assert micros == 1_000_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
|
||||||
|
"""Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
|
||||||
|
zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
|
||||||
|
so anonymous-style configs still reserve the operator-tunable default.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.services import token_quota_service
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm,
|
||||||
|
"get_model_info",
|
||||||
|
lambda _name: {
|
||||||
|
"input_cost_per_token": 1e-6,
|
||||||
|
"output_cost_per_token": 1e-6,
|
||||||
|
},
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
|
||||||
|
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
|
||||||
|
|
||||||
|
# 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
|
||||||
|
assert (
|
||||||
|
token_quota_service.estimate_call_reserve_micros(
|
||||||
|
base_model="cheap", quota_reserve_tokens=None
|
||||||
|
)
|
||||||
|
== 4000
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
token_quota_service.estimate_call_reserve_micros(
|
||||||
|
base_model="cheap", quota_reserve_tokens=0
|
||||||
|
)
|
||||||
|
== 4000
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TokenTrackingCallback — image vs chat usage shape
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeImageUsage:
|
||||||
|
"""Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_tokens: int = 0,
|
||||||
|
output_tokens: int = 0,
|
||||||
|
total_tokens: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.input_tokens = input_tokens
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
if total_tokens is not None:
|
||||||
|
self.total_tokens = total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeImageResponse:
|
||||||
|
"""Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
|
||||||
|
``type(...).__name__`` probe routes to the image branch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
|
||||||
|
self.usage = usage
|
||||||
|
if response_cost is not None:
|
||||||
|
self._hidden_params = {"response_cost": response_cost}
|
||||||
|
|
||||||
|
|
||||||
|
# Re-tag the helper class as ``ImageResponse`` for the type-name probe in
|
||||||
|
# the callback. We can't simply name the class ``ImageResponse`` because
|
||||||
|
# the test runner sometimes imports test modules in surprising ways and
|
||||||
|
# we want to be explicit.
|
||||||
|
_FakeImageResponse.__name__ = "ImageResponse"
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeChatUsage:
|
||||||
|
def __init__(self, prompt: int, completion: int):
|
||||||
|
self.prompt_tokens = prompt
|
||||||
|
self.completion_tokens = completion
|
||||||
|
self.total_tokens = prompt + completion
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeChatResponse:
|
||||||
|
def __init__(self, usage: _FakeChatUsage):
|
||||||
|
self.usage = usage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_reads_image_usage_input_output_tokens():
|
||||||
|
"""``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
|
||||||
|
for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
|
||||||
|
prompt_tokens/completion_tokens which is the chat shape.
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
TokenTrackingCallback,
|
||||||
|
scoped_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
cb = TokenTrackingCallback()
|
||||||
|
response = _FakeImageResponse(
|
||||||
|
usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
|
||||||
|
response_cost=0.04, # $0.04 per image
|
||||||
|
)
|
||||||
|
|
||||||
|
async with scoped_turn() as acc:
|
||||||
|
await cb.async_log_success_event(
|
||||||
|
kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
|
||||||
|
response_obj=response,
|
||||||
|
start_time=None,
|
||||||
|
end_time=None,
|
||||||
|
)
|
||||||
|
assert len(acc.calls) == 1
|
||||||
|
call = acc.calls[0]
|
||||||
|
assert call.prompt_tokens == 42
|
||||||
|
assert call.completion_tokens == 8
|
||||||
|
assert call.total_tokens == 50
|
||||||
|
# 0.04 USD = 40_000 micros
|
||||||
|
assert call.cost_micros == 40_000
|
||||||
|
assert call.call_kind == "image_generation"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_chat_path_unchanged():
|
||||||
|
"""Chat responses must still read prompt_tokens/completion_tokens."""
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
TokenTrackingCallback,
|
||||||
|
scoped_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
cb = TokenTrackingCallback()
|
||||||
|
response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
|
||||||
|
|
||||||
|
async with scoped_turn() as acc:
|
||||||
|
await cb.async_log_success_event(
|
||||||
|
kwargs={
|
||||||
|
"model": "openrouter/anthropic/claude-3-5-sonnet",
|
||||||
|
"response_cost": 0.0036,
|
||||||
|
},
|
||||||
|
response_obj=response,
|
||||||
|
start_time=None,
|
||||||
|
end_time=None,
|
||||||
|
)
|
||||||
|
assert len(acc.calls) == 1
|
||||||
|
call = acc.calls[0]
|
||||||
|
assert call.prompt_tokens == 120
|
||||||
|
assert call.completion_tokens == 30
|
||||||
|
assert call.total_tokens == 150
|
||||||
|
assert call.cost_micros == 3_600
|
||||||
|
assert call.call_kind == "chat"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
|
||||||
|
"""When OpenRouter omits ``usage.cost`` LiteLLM's
|
||||||
|
``default_image_cost_calculator`` raises. The defensive image branch in
|
||||||
|
``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
|
||||||
|
chat-shaped and would raise too) — it returns 0 with a WARNING log.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
TokenTrackingCallback,
|
||||||
|
scoped_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force completion_cost to raise the same way OpenRouter image-gen fails.
|
||||||
|
def _boom(*_args, **_kwargs):
|
||||||
|
raise ValueError("model_cost: missing entry for openrouter image model")
|
||||||
|
|
||||||
|
monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
|
||||||
|
|
||||||
|
# And make sure cost_per_token is NEVER called for the image path —
|
||||||
|
# if it were, our ``is_image=True`` branch is broken.
|
||||||
|
cost_per_token_calls: list = []
|
||||||
|
|
||||||
|
def _record_cost_per_token(**kwargs):
|
||||||
|
cost_per_token_calls.append(kwargs)
|
||||||
|
return (0.0, 0.0)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm, "cost_per_token", _record_cost_per_token, raising=False
|
||||||
|
)
|
||||||
|
|
||||||
|
cb = TokenTrackingCallback()
|
||||||
|
response = _FakeImageResponse(
|
||||||
|
usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with scoped_turn() as acc:
|
||||||
|
await cb.async_log_success_event(
|
||||||
|
kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
|
||||||
|
response_obj=response,
|
||||||
|
start_time=None,
|
||||||
|
end_time=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(acc.calls) == 1
|
||||||
|
assert acc.calls[0].cost_micros == 0
|
||||||
|
assert acc.calls[0].call_kind == "image_generation"
|
||||||
|
# The image branch must short-circuit before cost_per_token.
|
||||||
|
assert cost_per_token_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# scoped_turn — ContextVar reset semantics (issue B)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scoped_turn_restores_outer_accumulator():
|
||||||
|
"""``scoped_turn`` must restore the previous ContextVar value on exit
|
||||||
|
so a per-call wrapper inside an outer chat turn doesn't leak its
|
||||||
|
accumulator outward (which would cause double-debit at chat-turn exit).
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
get_current_accumulator,
|
||||||
|
scoped_turn,
|
||||||
|
start_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
outer = start_turn()
|
||||||
|
assert get_current_accumulator() is outer
|
||||||
|
|
||||||
|
async with scoped_turn() as inner:
|
||||||
|
assert get_current_accumulator() is inner
|
||||||
|
assert inner is not outer
|
||||||
|
inner.add(
|
||||||
|
model="x",
|
||||||
|
prompt_tokens=1,
|
||||||
|
completion_tokens=1,
|
||||||
|
total_tokens=2,
|
||||||
|
cost_micros=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# After exit the outer accumulator is restored unchanged.
|
||||||
|
assert get_current_accumulator() is outer
|
||||||
|
assert outer.total_cost_micros == 0
|
||||||
|
assert len(outer.calls) == 0
|
||||||
|
# The inner accumulator captured the call but didn't bleed into outer.
|
||||||
|
assert inner.total_cost_micros == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scoped_turn_resets_to_none_when_no_outer():
|
||||||
|
"""Running ``scoped_turn`` outside any chat turn (e.g. a background
|
||||||
|
indexing job) must leave the ContextVar at ``None`` on exit so the
|
||||||
|
next *unrelated* request starts clean.
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
_turn_accumulator,
|
||||||
|
get_current_accumulator,
|
||||||
|
scoped_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ContextVar default is None for a fresh test isolated context. We
|
||||||
|
# simulate "no outer" explicitly to be robust against test order.
|
||||||
|
token = _turn_accumulator.set(None)
|
||||||
|
try:
|
||||||
|
assert get_current_accumulator() is None
|
||||||
|
async with scoped_turn() as acc:
|
||||||
|
assert get_current_accumulator() is acc
|
||||||
|
assert get_current_accumulator() is None
|
||||||
|
finally:
|
||||||
|
_turn_accumulator.reset(token)
|
||||||
325
surfsense_backend/tests/unit/tasks/test_podcast_billing.py
Normal file
325
surfsense_backend/tests/unit/tasks/test_podcast_billing.py
Normal file
|
|
@ -0,0 +1,325 @@
|
||||||
|
"""Unit tests for podcast Celery task billing integration.
|
||||||
|
|
||||||
|
Validates ``_generate_content_podcast`` correctly wraps
|
||||||
|
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
|
||||||
|
search-space owner's billing decision, and degrades cleanly when the
|
||||||
|
resolver fails or premium credit is exhausted.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
* Happy-path free config: resolver → ``billable_call`` enters with
|
||||||
|
``usage_type='podcast_generation'`` and the configured reserve override,
|
||||||
|
graph runs, podcast row flips to ``READY``.
|
||||||
|
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
|
||||||
|
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
|
||||||
|
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
|
||||||
|
carries ``reason='premium_quota_exhausted'``.
|
||||||
|
* Resolver failure: ``ValueError`` from the resolver → podcast row flips
|
||||||
|
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fakes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeExecResult:
|
||||||
|
def __init__(self, obj):
|
||||||
|
self._obj = obj
|
||||||
|
|
||||||
|
def scalars(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return self._obj
|
||||||
|
|
||||||
|
def filter(self, *_args, **_kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, podcast):
|
||||||
|
self._podcast = podcast
|
||||||
|
self.commit_count = 0
|
||||||
|
|
||||||
|
async def execute(self, _stmt):
|
||||||
|
return _FakeExecResult(self._podcast)
|
||||||
|
|
||||||
|
async def commit(self):
|
||||||
|
self.commit_count += 1
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSessionMaker:
|
||||||
|
def __init__(self, session: _FakeSession):
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
|
||||||
|
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
|
||||||
|
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
|
||||||
|
inside helpers keeps this fixture cheap."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=podcast_id,
|
||||||
|
title="Test Podcast",
|
||||||
|
thread_id=thread_id,
|
||||||
|
status=None,
|
||||||
|
podcast_transcript=None,
|
||||||
|
file_location=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _ok_billable_call(**kwargs):
|
||||||
|
"""Stand-in for ``billable_call`` that records its kwargs and yields a
|
||||||
|
no-op accumulator-shaped object."""
|
||||||
|
_CALL_LOG.append(kwargs)
|
||||||
|
yield SimpleNamespace()
|
||||||
|
|
||||||
|
|
||||||
|
_CALL_LOG: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _denying_billable_call(**kwargs):
|
||||||
|
from app.services.billable_calls import QuotaInsufficientError
|
||||||
|
|
||||||
|
_CALL_LOG.append(kwargs)
|
||||||
|
raise QuotaInsufficientError(
|
||||||
|
usage_type=kwargs.get("usage_type", "?"),
|
||||||
|
used_micros=5_000_000,
|
||||||
|
limit_micros=5_000_000,
|
||||||
|
remaining_micros=0,
|
||||||
|
)
|
||||||
|
yield SimpleNamespace() # pragma: no cover — for grammar only
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_call_log():
|
||||||
|
_CALL_LOG.clear()
|
||||||
|
yield
|
||||||
|
_CALL_LOG.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
|
||||||
|
"""Happy path: free billing tier still wraps the graph call so the
|
||||||
|
audit row is recorded. Verifies kwargs threading."""
|
||||||
|
from app.config import config as app_config
|
||||||
|
from app.db import PodcastStatus
|
||||||
|
from app.tasks.celery_tasks import podcast_tasks
|
||||||
|
|
||||||
|
podcast = _make_podcast(podcast_id=7, thread_id=99)
|
||||||
|
session = _FakeSession(podcast)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
assert search_space_id == 555
|
||||||
|
assert thread_id == 99
|
||||||
|
return user_id, "free", "openrouter/some-free-model"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
return {
|
||||||
|
"podcast_transcript": [
|
||||||
|
SimpleNamespace(speaker_id=0, dialog="Hi"),
|
||||||
|
SimpleNamespace(speaker_id=1, dialog="Hello"),
|
||||||
|
],
|
||||||
|
"final_podcast_file_path": "/tmp/podcast.wav",
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||||
|
|
||||||
|
result = await podcast_tasks._generate_content_podcast(
|
||||||
|
podcast_id=7,
|
||||||
|
source_content="hello world",
|
||||||
|
search_space_id=555,
|
||||||
|
user_prompt="make it short",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["status"] == "ready"
|
||||||
|
assert result["podcast_id"] == 7
|
||||||
|
assert podcast.status == PodcastStatus.READY
|
||||||
|
assert podcast.file_location == "/tmp/podcast.wav"
|
||||||
|
|
||||||
|
assert len(_CALL_LOG) == 1
|
||||||
|
call = _CALL_LOG[0]
|
||||||
|
assert call["user_id"] == user_id
|
||||||
|
assert call["search_space_id"] == 555
|
||||||
|
assert call["billing_tier"] == "free"
|
||||||
|
assert call["base_model"] == "openrouter/some-free-model"
|
||||||
|
assert call["usage_type"] == "podcast_generation"
|
||||||
|
assert (
|
||||||
|
call["quota_reserve_micros_override"]
|
||||||
|
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
|
||||||
|
)
|
||||||
|
assert call["thread_id"] == 99
|
||||||
|
assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
|
||||||
|
"""Premium resolution flows through to ``billable_call`` so the
|
||||||
|
reserve/finalize path triggers."""
|
||||||
|
from app.tasks.celery_tasks import podcast_tasks
|
||||||
|
|
||||||
|
podcast = _make_podcast()
|
||||||
|
session = _FakeSession(podcast)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
return user_id, "premium", "gpt-5.4"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||||
|
|
||||||
|
await podcast_tasks._generate_content_podcast(
|
||||||
|
podcast_id=7,
|
||||||
|
source_content="hi",
|
||||||
|
search_space_id=555,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _CALL_LOG[0]["billing_tier"] == "premium"
|
||||||
|
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
|
||||||
|
"""When ``billable_call`` denies the reservation, the graph never
|
||||||
|
runs and the podcast row flips to FAILED with the documented reason
|
||||||
|
code."""
|
||||||
|
from app.db import PodcastStatus
|
||||||
|
from app.tasks.celery_tasks import podcast_tasks
|
||||||
|
|
||||||
|
podcast = _make_podcast(podcast_id=8)
|
||||||
|
session = _FakeSession(podcast)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
return uuid4(), "premium", "gpt-5.4"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
|
||||||
|
|
||||||
|
graph_invoked = []
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
graph_invoked.append(True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||||
|
|
||||||
|
result = await podcast_tasks._generate_content_podcast(
|
||||||
|
podcast_id=8,
|
||||||
|
source_content="hi",
|
||||||
|
search_space_id=555,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"status": "failed",
|
||||||
|
"podcast_id": 8,
|
||||||
|
"reason": "premium_quota_exhausted",
|
||||||
|
}
|
||||||
|
assert podcast.status == PodcastStatus.FAILED
|
||||||
|
assert graph_invoked == [] # Graph never ran on denied reservation.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
|
||||||
|
"""If the resolver raises (e.g. search-space deleted), the task fails
|
||||||
|
cleanly without invoking the graph."""
|
||||||
|
from app.db import PodcastStatus
|
||||||
|
from app.tasks.celery_tasks import podcast_tasks
|
||||||
|
|
||||||
|
podcast = _make_podcast(podcast_id=9)
|
||||||
|
session = _FakeSession(podcast)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
raise ValueError("Search space 555 not found")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_invoked = []
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
graph_invoked.append(True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||||
|
|
||||||
|
result = await podcast_tasks._generate_content_podcast(
|
||||||
|
podcast_id=9,
|
||||||
|
source_content="hi",
|
||||||
|
search_space_id=555,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"status": "failed",
|
||||||
|
"podcast_id": 9,
|
||||||
|
"reason": "billing_resolution_failed",
|
||||||
|
}
|
||||||
|
assert podcast.status == PodcastStatus.FAILED
|
||||||
|
assert graph_invoked == []
|
||||||
|
|
@ -0,0 +1,330 @@
|
||||||
|
"""Unit tests for video-presentation Celery task billing integration.
|
||||||
|
|
||||||
|
Mirrors ``test_podcast_billing.py`` for the video-presentation task.
|
||||||
|
Validates the same wrap-graph-in-billable_call pattern and ensures the
|
||||||
|
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
|
||||||
|
threaded through.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
* Free config: graph runs, ``billable_call`` invoked with the video
|
||||||
|
reserve override.
|
||||||
|
* Premium config: same wiring with ``billing_tier='premium'``.
|
||||||
|
* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
|
||||||
|
* Resolver failure: row → FAILED with ``billing_resolution_failed``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fakes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeExecResult:
|
||||||
|
def __init__(self, obj):
|
||||||
|
self._obj = obj
|
||||||
|
|
||||||
|
def scalars(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return self._obj
|
||||||
|
|
||||||
|
def filter(self, *_args, **_kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, video):
|
||||||
|
self._video = video
|
||||||
|
self.commit_count = 0
|
||||||
|
|
||||||
|
async def execute(self, _stmt):
|
||||||
|
return _FakeExecResult(self._video)
|
||||||
|
|
||||||
|
async def commit(self):
|
||||||
|
self.commit_count += 1
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSessionMaker:
|
||||||
|
def __init__(self, session: _FakeSession):
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
|
||||||
|
def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=video_id,
|
||||||
|
title="Test Presentation",
|
||||||
|
thread_id=thread_id,
|
||||||
|
status=None,
|
||||||
|
slides=None,
|
||||||
|
scene_codes=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_CALL_LOG: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _ok_billable_call(**kwargs):
|
||||||
|
_CALL_LOG.append(kwargs)
|
||||||
|
yield SimpleNamespace()
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _denying_billable_call(**kwargs):
|
||||||
|
from app.services.billable_calls import QuotaInsufficientError
|
||||||
|
|
||||||
|
_CALL_LOG.append(kwargs)
|
||||||
|
raise QuotaInsufficientError(
|
||||||
|
usage_type=kwargs.get("usage_type", "?"),
|
||||||
|
used_micros=5_000_000,
|
||||||
|
limit_micros=5_000_000,
|
||||||
|
remaining_micros=0,
|
||||||
|
)
|
||||||
|
yield SimpleNamespace() # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_call_log():
|
||||||
|
_CALL_LOG.clear()
|
||||||
|
yield
|
||||||
|
_CALL_LOG.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
|
||||||
|
from app.config import config as app_config
|
||||||
|
from app.db import VideoPresentationStatus
|
||||||
|
from app.tasks.celery_tasks import video_presentation_tasks
|
||||||
|
|
||||||
|
video = _make_video(video_id=11, thread_id=99)
|
||||||
|
session = _FakeSession(video)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
assert search_space_id == 777
|
||||||
|
assert thread_id == 99
|
||||||
|
return user_id, "free", "openrouter/some-free-model"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"_resolve_agent_billing_for_search_space",
|
||||||
|
_fake_resolver,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks.video_presentation_graph,
|
||||||
|
"ainvoke",
|
||||||
|
_fake_graph_invoke,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await video_presentation_tasks._generate_video_presentation(
|
||||||
|
video_presentation_id=11,
|
||||||
|
source_content="content",
|
||||||
|
search_space_id=777,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["status"] == "ready"
|
||||||
|
assert result["video_presentation_id"] == 11
|
||||||
|
assert video.status == VideoPresentationStatus.READY
|
||||||
|
|
||||||
|
assert len(_CALL_LOG) == 1
|
||||||
|
call = _CALL_LOG[0]
|
||||||
|
assert call["user_id"] == user_id
|
||||||
|
assert call["search_space_id"] == 777
|
||||||
|
assert call["billing_tier"] == "free"
|
||||||
|
assert call["base_model"] == "openrouter/some-free-model"
|
||||||
|
assert call["usage_type"] == "video_presentation_generation"
|
||||||
|
assert (
|
||||||
|
call["quota_reserve_micros_override"]
|
||||||
|
== app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
|
||||||
|
)
|
||||||
|
assert call["thread_id"] == 99
|
||||||
|
assert call["call_details"] == {
|
||||||
|
"video_presentation_id": 11,
|
||||||
|
"title": "Test Presentation",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
|
||||||
|
from app.tasks.celery_tasks import video_presentation_tasks
|
||||||
|
|
||||||
|
video = _make_video()
|
||||||
|
session = _FakeSession(video)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
return user_id, "premium", "gpt-5.4"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"_resolve_agent_billing_for_search_space",
|
||||||
|
_fake_resolver,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks.video_presentation_graph,
|
||||||
|
"ainvoke",
|
||||||
|
_fake_graph_invoke,
|
||||||
|
)
|
||||||
|
|
||||||
|
await video_presentation_tasks._generate_video_presentation(
|
||||||
|
video_presentation_id=11,
|
||||||
|
source_content="content",
|
||||||
|
search_space_id=777,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _CALL_LOG[0]["billing_tier"] == "premium"
|
||||||
|
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
|
||||||
|
from app.db import VideoPresentationStatus
|
||||||
|
from app.tasks.celery_tasks import video_presentation_tasks
|
||||||
|
|
||||||
|
video = _make_video(video_id=12)
|
||||||
|
session = _FakeSession(video)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
return uuid4(), "premium", "gpt-5.4"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"_resolve_agent_billing_for_search_space",
|
||||||
|
_fake_resolver,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks, "billable_call", _denying_billable_call
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_invoked = []
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
graph_invoked.append(True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks.video_presentation_graph,
|
||||||
|
"ainvoke",
|
||||||
|
_fake_graph_invoke,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await video_presentation_tasks._generate_video_presentation(
|
||||||
|
video_presentation_id=12,
|
||||||
|
source_content="content",
|
||||||
|
search_space_id=777,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"status": "failed",
|
||||||
|
"video_presentation_id": 12,
|
||||||
|
"reason": "premium_quota_exhausted",
|
||||||
|
}
|
||||||
|
assert video.status == VideoPresentationStatus.FAILED
|
||||||
|
assert graph_invoked == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolver_failure_marks_video_failed(monkeypatch):
|
||||||
|
from app.db import VideoPresentationStatus
|
||||||
|
from app.tasks.celery_tasks import video_presentation_tasks
|
||||||
|
|
||||||
|
video = _make_video(video_id=13)
|
||||||
|
session = _FakeSession(video)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"get_celery_session_maker",
|
||||||
|
lambda: _FakeSessionMaker(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
|
||||||
|
raise ValueError("Search space 777 not found")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks,
|
||||||
|
"_resolve_agent_billing_for_search_space",
|
||||||
|
_failing_resolver,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_invoked = []
|
||||||
|
|
||||||
|
async def _fake_graph_invoke(state, config):
|
||||||
|
graph_invoked.append(True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
video_presentation_tasks.video_presentation_graph,
|
||||||
|
"ainvoke",
|
||||||
|
_fake_graph_invoke,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await video_presentation_tasks._generate_video_presentation(
|
||||||
|
video_presentation_id=13,
|
||||||
|
source_content="content",
|
||||||
|
search_space_id=777,
|
||||||
|
user_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"status": "failed",
|
||||||
|
"video_presentation_id": 13,
|
||||||
|
"reason": "billing_resolution_failed",
|
||||||
|
}
|
||||||
|
assert video.status == VideoPresentationStatus.FAILED
|
||||||
|
assert graph_invoked == []
|
||||||
|
|
@ -127,7 +127,7 @@ const FAQ_ITEMS = [
|
||||||
{
|
{
|
||||||
question: "What happens after I use my free tokens?",
|
question: "What happens after I use my free tokens?",
|
||||||
answer:
|
answer:
|
||||||
"After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.",
|
"After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
question: "Is Claude AI available without login?",
|
question: "Is Claude AI available without login?",
|
||||||
|
|
@ -329,7 +329,7 @@ export default async function FreeHubPage() {
|
||||||
<section className="max-w-3xl mx-auto text-center">
|
<section className="max-w-3xl mx-auto text-center">
|
||||||
<h2 className="text-2xl font-bold mb-3">Want More Features?</h2>
|
<h2 className="text-2xl font-bold mb-3">Want More Features?</h2>
|
||||||
<p className="text-muted-foreground mb-6 leading-relaxed">
|
<p className="text-muted-foreground mb-6 leading-relaxed">
|
||||||
Create a free SurfSense account to unlock 3 million tokens, document uploads with
|
Create a free SurfSense account to unlock $5 of premium credit, document uploads with
|
||||||
citations, team collaboration, and integrations with Slack, Google Drive, Notion, and
|
citations, team collaboration, and integrations with Slack, Google Drive, Notion, and
|
||||||
30+ more tools.
|
30+ more tools.
|
||||||
</p>
|
</p>
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav";
|
||||||
export const metadata: Metadata = {
|
export const metadata: Metadata = {
|
||||||
title: "Pricing | SurfSense - Free AI Search Plans",
|
title: "Pricing | SurfSense - Free AI Search Plans",
|
||||||
description:
|
description:
|
||||||
"Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.",
|
"Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.",
|
||||||
alternates: {
|
alternates: {
|
||||||
canonical: "https://surfsense.com/pricing",
|
canonical: "https://surfsense.com/pricing",
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
const TABS = [
|
const TABS = [
|
||||||
{ id: "pages", label: "Pages" },
|
{ id: "pages", label: "Pages" },
|
||||||
{ id: "tokens", label: "Premium Tokens" },
|
{ id: "tokens", label: "Premium Credit" },
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
type TabId = (typeof TABS)[number]["id"];
|
type TabId = (typeof TABS)[number]["id"];
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,12 @@ type UnifiedPurchase = {
|
||||||
kind: PurchaseKind;
|
kind: PurchaseKind;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
status: PagePurchaseStatus;
|
status: PagePurchaseStatus;
|
||||||
|
/**
|
||||||
|
* Granted units. Interpretation depends on ``kind``:
|
||||||
|
* - ``"pages"`` — integer number of indexed pages.
|
||||||
|
* - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00).
|
||||||
|
* The ``Granted`` column formats accordingly.
|
||||||
|
*/
|
||||||
granted: number;
|
granted: number;
|
||||||
amount_total: number | null;
|
amount_total: number | null;
|
||||||
currency: string | null;
|
currency: string | null;
|
||||||
|
|
@ -58,7 +64,7 @@ const KIND_META: Record<
|
||||||
iconClass: "text-sky-500",
|
iconClass: "text-sky-500",
|
||||||
},
|
},
|
||||||
tokens: {
|
tokens: {
|
||||||
label: "Premium Tokens",
|
label: "Premium Credit",
|
||||||
icon: Coins,
|
icon: Coins,
|
||||||
iconClass: "text-amber-500",
|
iconClass: "text-amber-500",
|
||||||
},
|
},
|
||||||
|
|
@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase {
|
||||||
kind: "tokens",
|
kind: "tokens",
|
||||||
created_at: p.created_at,
|
created_at: p.created_at,
|
||||||
status: p.status,
|
status: p.status,
|
||||||
granted: p.tokens_granted,
|
granted: p.credit_micros_granted,
|
||||||
amount_total: p.amount_total,
|
amount_total: p.amount_total,
|
||||||
currency: p.currency,
|
currency: p.currency,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function formatGranted(p: UnifiedPurchase): string {
|
||||||
|
if (p.kind === "tokens") {
|
||||||
|
const dollars = p.granted / 1_000_000;
|
||||||
|
// Premium credit packs are always whole dollars at the moment, but
|
||||||
|
// future fractional grants (refunds, partial top-ups) shouldn't
|
||||||
|
// silently round to "$0".
|
||||||
|
if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`;
|
||||||
|
if (dollars > 0) return `$${dollars.toFixed(3)} of credit`;
|
||||||
|
return "$0 of credit";
|
||||||
|
}
|
||||||
|
return p.granted.toLocaleString();
|
||||||
|
}
|
||||||
|
|
||||||
export function PurchaseHistoryContent() {
|
export function PurchaseHistoryContent() {
|
||||||
const results = useQueries({
|
const results = useQueries({
|
||||||
queries: [
|
queries: [
|
||||||
|
|
@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {
|
||||||
<ReceiptText className="h-8 w-8 text-muted-foreground" />
|
<ReceiptText className="h-8 w-8 text-muted-foreground" />
|
||||||
<p className="text-sm font-medium">No purchases yet</p>
|
<p className="text-sm font-medium">No purchases yet</p>
|
||||||
<p className="text-xs text-muted-foreground">
|
<p className="text-xs text-muted-foreground">
|
||||||
Your page and premium token purchases will appear here after checkout.
|
Your page and premium credit purchases will appear here after checkout.
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
@ -177,7 +196,7 @@ export function PurchaseHistoryContent() {
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell className="text-right tabular-nums text-sm">
|
<TableCell className="text-right tabular-nums text-sm">
|
||||||
{p.granted.toLocaleString()}
|
{formatGranted(p)}
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell className="text-right tabular-nums text-sm">
|
<TableCell className="text-right tabular-nums text-sm">
|
||||||
{formatAmount(p.amount_total, p.currency)}
|
{formatAmount(p.amount_total, p.currency)}
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe();
|
||||||
export const currentUserAtom = atomWithQuery(() => {
|
export const currentUserAtom = atomWithQuery(() => {
|
||||||
return {
|
return {
|
||||||
queryKey: USER_QUERY_KEY,
|
queryKey: USER_QUERY_KEY,
|
||||||
// Live-changing numeric fields (pages_*, premium_tokens_*) are now
|
// Live-changing numeric fields (pages_*, premium_credit_micros_*)
|
||||||
// pushed via Zero (queries.user.me()), so /users/me only needs to
|
// are now pushed via Zero (queries.user.me()), so /users/me only
|
||||||
// fire once per session for the static profile fields.
|
// needs to fire once per session for the static profile fields.
|
||||||
staleTime: Infinity,
|
staleTime: Infinity,
|
||||||
enabled: !!getBearerToken(),
|
enabled: !!getBearerToken(),
|
||||||
queryFn: userQueryFn,
|
queryFn: userQueryFn,
|
||||||
|
|
|
||||||
|
|
@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format provider USD cost (in micro-USD) for inline display next to a
|
||||||
|
* token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent
|
||||||
|
* costs so a real-but-tiny figure doesn't render as ``$0.000``.
|
||||||
|
*/
|
||||||
|
function formatTurnCost(micros: number): string {
|
||||||
|
const dollars = micros / 1_000_000;
|
||||||
|
if (dollars >= 1) return `$${dollars.toFixed(2)}`;
|
||||||
|
if (dollars >= 0.01) return `$${dollars.toFixed(3)}`;
|
||||||
|
if (dollars > 0) return "<$0.001";
|
||||||
|
return "$0";
|
||||||
|
}
|
||||||
|
|
||||||
const MessageInfoDropdown: FC = () => {
|
const MessageInfoDropdown: FC = () => {
|
||||||
const messageId = useAuiState(({ message }) => message?.id);
|
const messageId = useAuiState(({ message }) => message?.id);
|
||||||
const createdAt = useAuiState(({ message }) => message?.createdAt);
|
const createdAt = useAuiState(({ message }) => message?.createdAt);
|
||||||
|
|
@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => {
|
||||||
{models.length > 0 ? (
|
{models.length > 0 ? (
|
||||||
models.map(([model, counts]) => {
|
models.map(([model, counts]) => {
|
||||||
const { name, icon } = resolveModel(model);
|
const { name, icon } = resolveModel(model);
|
||||||
|
const costMicros = counts.cost_micros;
|
||||||
return (
|
return (
|
||||||
<ActionBarMorePrimitive.Item
|
<ActionBarMorePrimitive.Item
|
||||||
key={model}
|
key={model}
|
||||||
|
|
@ -463,6 +477,9 @@ const MessageInfoDropdown: FC = () => {
|
||||||
</span>
|
</span>
|
||||||
<span className="text-xs text-muted-foreground">
|
<span className="text-xs text-muted-foreground">
|
||||||
{counts.total_tokens.toLocaleString()} tokens
|
{counts.total_tokens.toLocaleString()} tokens
|
||||||
|
{costMicros && costMicros > 0
|
||||||
|
? ` · ${formatTurnCost(costMicros)}`
|
||||||
|
: ""}
|
||||||
</span>
|
</span>
|
||||||
</ActionBarMorePrimitive.Item>
|
</ActionBarMorePrimitive.Item>
|
||||||
);
|
);
|
||||||
|
|
@ -474,6 +491,9 @@ const MessageInfoDropdown: FC = () => {
|
||||||
>
|
>
|
||||||
<span className="text-xs text-muted-foreground">
|
<span className="text-xs text-muted-foreground">
|
||||||
{usage.total_tokens.toLocaleString()} tokens
|
{usage.total_tokens.toLocaleString()} tokens
|
||||||
|
{usage.cost_micros && usage.cost_micros > 0
|
||||||
|
? ` · ${formatTurnCost(usage.cost_micros)}`
|
||||||
|
: ""}
|
||||||
</span>
|
</span>
|
||||||
</ActionBarMorePrimitive.Item>
|
</ActionBarMorePrimitive.Item>
|
||||||
)}
|
)}
|
||||||
|
|
|
||||||
|
|
@ -13,13 +13,30 @@ export interface TokenUsageData {
|
||||||
prompt_tokens: number;
|
prompt_tokens: number;
|
||||||
completion_tokens: number;
|
completion_tokens: number;
|
||||||
total_tokens: number;
|
total_tokens: number;
|
||||||
|
/**
|
||||||
|
* Total provider USD cost for this assistant turn, in micro-USD
|
||||||
|
* (1_000_000 = $1.00). Populated from LiteLLM's response_cost on
|
||||||
|
* the backend. Optional because pre-cost-credits messages persisted
|
||||||
|
* before the migration won't have it.
|
||||||
|
*/
|
||||||
|
cost_micros?: number;
|
||||||
usage?: Record<
|
usage?: Record<
|
||||||
string,
|
string,
|
||||||
{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
|
{
|
||||||
|
prompt_tokens: number;
|
||||||
|
completion_tokens: number;
|
||||||
|
total_tokens: number;
|
||||||
|
cost_micros?: number;
|
||||||
|
}
|
||||||
>;
|
>;
|
||||||
model_breakdown?: Record<
|
model_breakdown?: Record<
|
||||||
string,
|
string,
|
||||||
{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
|
{
|
||||||
|
prompt_tokens: number;
|
||||||
|
completion_tokens: number;
|
||||||
|
total_tokens: number;
|
||||||
|
cost_micros?: number;
|
||||||
|
}
|
||||||
>;
|
>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ export function QuotaWarningBanner({
|
||||||
</p>
|
</p>
|
||||||
<p className="text-xs text-red-600 dark:text-red-300">
|
<p className="text-xs text-red-600 dark:text-red-300">
|
||||||
You've used all {limit.toLocaleString()} free tokens. Create a free account to
|
You've used all {limit.toLocaleString()} free tokens. Create a free account to
|
||||||
get 3 million tokens and access to all models.
|
get $5 of premium credit and access to all models.
|
||||||
</p>
|
</p>
|
||||||
<Link
|
<Link
|
||||||
href="/register"
|
href="/register"
|
||||||
|
|
@ -69,7 +69,7 @@ export function QuotaWarningBanner({
|
||||||
<Link href="/register" className="font-medium underline hover:no-underline">
|
<Link href="/register" className="font-medium underline hover:no-underline">
|
||||||
Create an account
|
Create an account
|
||||||
</Link>{" "}
|
</Link>{" "}
|
||||||
for 5M free tokens.
|
for $5 of premium credit.
|
||||||
</p>
|
</p>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,14 @@ import { Progress } from "@/components/ui/progress";
|
||||||
import { useIsAnonymous } from "@/contexts/anonymous-mode";
|
import { useIsAnonymous } from "@/contexts/anonymous-mode";
|
||||||
import { queries } from "@/zero/queries";
|
import { queries } from "@/zero/queries";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Premium credit balance shown in the sidebar.
|
||||||
|
*
|
||||||
|
* Values come from Zero (live-replicated from Postgres) and are stored as
|
||||||
|
* integer micro-USD (1_000_000 == $1.00). We render in dollars because
|
||||||
|
* users top up at $1/pack and the credit gets debited at actual provider
|
||||||
|
* cost.
|
||||||
|
*/
|
||||||
export function PremiumTokenUsageDisplay() {
|
export function PremiumTokenUsageDisplay() {
|
||||||
const isAnonymous = useIsAnonymous();
|
const isAnonymous = useIsAnonymous();
|
||||||
const [me] = useQuery(queries.user.me({}));
|
const [me] = useQuery(queries.user.me({}));
|
||||||
|
|
@ -12,21 +20,26 @@ export function PremiumTokenUsageDisplay() {
|
||||||
if (isAnonymous || !me) return null;
|
if (isAnonymous || !me) return null;
|
||||||
|
|
||||||
const usagePercentage = Math.min(
|
const usagePercentage = Math.min(
|
||||||
(me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100,
|
(me.premiumCreditMicrosUsed / Math.max(me.premiumCreditMicrosLimit, 1)) * 100,
|
||||||
100
|
100
|
||||||
);
|
);
|
||||||
|
|
||||||
const formatTokens = (n: number) => {
|
const formatUsd = (micros: number) => {
|
||||||
if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(1)}M`;
|
const dollars = micros / 1_000_000;
|
||||||
if (n >= 1_000) return `${(n / 1_000).toFixed(0)}K`;
|
if (dollars >= 100) return `$${dollars.toFixed(0)}`;
|
||||||
return n.toLocaleString();
|
if (dollars >= 1) return `$${dollars.toFixed(2)}`;
|
||||||
|
// Sub-dollar balances need extra precision so the bar still tells the
|
||||||
|
// user what's left ("$0.04 of credit") instead of rounding to "$0".
|
||||||
|
if (dollars > 0) return `$${dollars.toFixed(3)}`;
|
||||||
|
return "$0";
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-1.5">
|
<div className="space-y-1.5">
|
||||||
<div className="flex justify-between items-center text-xs">
|
<div className="flex justify-between items-center text-xs">
|
||||||
<span className="text-muted-foreground">
|
<span className="text-muted-foreground">
|
||||||
{formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens
|
{formatUsd(me.premiumCreditMicrosUsed)} / {formatUsd(me.premiumCreditMicrosLimit)} of
|
||||||
|
credit
|
||||||
</span>
|
</span>
|
||||||
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
|
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,11 @@ const demoPlans = [
|
||||||
price: "0",
|
price: "0",
|
||||||
yearlyPrice: "0",
|
yearlyPrice: "0",
|
||||||
period: "",
|
period: "",
|
||||||
billingText: "500 pages + 3M premium tokens included",
|
billingText: "500 pages + $5 of premium credit included",
|
||||||
features: [
|
features: [
|
||||||
"Self Hostable",
|
"Self Hostable",
|
||||||
"500 pages included to start",
|
"500 pages included to start",
|
||||||
"3 million premium tokens to start",
|
"$5 of premium credit to start, billed at provider cost",
|
||||||
"Includes access to OpenAI text, audio and image models",
|
"Includes access to OpenAI text, audio and image models",
|
||||||
"Realtime Collaborative Group Chats with teammates",
|
"Realtime Collaborative Group Chats with teammates",
|
||||||
"Community support on Discord",
|
"Community support on Discord",
|
||||||
|
|
@ -35,7 +35,7 @@ const demoPlans = [
|
||||||
features: [
|
features: [
|
||||||
"Everything in Free",
|
"Everything in Free",
|
||||||
"Buy 1,000-page packs at $1 each",
|
"Buy 1,000-page packs at $1 each",
|
||||||
"Buy 1M premium token packs at $1 each",
|
"Top up premium credit at $1 per $1 of credit, billed at provider cost",
|
||||||
"Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter",
|
"Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter",
|
||||||
"Priority support on Discord",
|
"Priority support on Discord",
|
||||||
],
|
],
|
||||||
|
|
@ -129,27 +129,27 @@ const faqData: FAQSection[] = [
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Premium Tokens",
|
title: "Premium Credit",
|
||||||
items: [
|
items: [
|
||||||
{
|
{
|
||||||
question: 'What are "premium tokens"?',
|
question: 'What is "premium credit"?',
|
||||||
answer:
|
answer:
|
||||||
"Premium tokens are the billing unit for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request consumes tokens based on the length of your conversation. Non-premium models (such as free-tier models available without login) do not consume premium tokens.",
|
"Premium credit is your USD balance for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request debits the actual USD cost the provider charges, so cheap and expensive models bill proportionally. Non-premium models (such as the free-tier models available without login) don't touch your premium credit.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
question: "How many premium tokens do I get for free?",
|
question: "How much premium credit do I get for free?",
|
||||||
answer:
|
answer:
|
||||||
"Every registered SurfSense account starts with 3 million premium tokens at no cost. Anonymous users (no login) get 500,000 free tokens across all models. Once your free tokens are used up, you can purchase more at any time.",
|
"Every registered SurfSense account starts with $5 of premium credit at no cost. Anonymous users (no login) get 500,000 free tokens across all free models. Once your free credit runs out, you can top up at any time.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
question: "How does purchasing premium tokens work?",
|
question: "How does buying premium credit work?",
|
||||||
answer:
|
answer:
|
||||||
"Just like pages, there's no subscription. You buy 1-million-token packs at $1 each whenever you need more. Purchased tokens are added to your account immediately. You can buy up to 100 packs at a time.",
|
"Just like pages, there's no subscription. Top-ups buy $1 of credit for $1 — every cent you pay is spent at provider cost, no markup. Purchased credit is added to your account immediately. You can buy up to $100 at a time.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
question: "What happens if I run out of premium tokens?",
|
question: "What happens if I run out of premium credit?",
|
||||||
answer:
|
answer:
|
||||||
"When your premium token balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you purchase more tokens. You can always switch to non-premium models which don't consume premium tokens.",
|
"When your premium credit balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you top up. You can always switch to non-premium models, which don't touch your premium credit.",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
@ -157,9 +157,9 @@ const faqData: FAQSection[] = [
|
||||||
title: "Self-Hosting",
|
title: "Self-Hosting",
|
||||||
items: [
|
items: [
|
||||||
{
|
{
|
||||||
question: "Can I self-host SurfSense with unlimited pages and tokens?",
|
question: "Can I self-host SurfSense with unlimited pages and credit?",
|
||||||
answer:
|
answer:
|
||||||
"Yes! When self-hosting, you have full control over your page and token limits. The default self-hosted setup gives you effectively unlimited pages and tokens, so you can index as much data and use as many AI queries as your infrastructure supports.",
|
"Yes! When self-hosting, you have full control over your page and premium-credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credit, so you can index as much data and use as many AI queries as your infrastructure supports.",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
@ -250,8 +250,8 @@ function PricingFAQ() {
|
||||||
Frequently Asked Questions
|
Frequently Asked Questions
|
||||||
</h2>
|
</h2>
|
||||||
<p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground">
|
<p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground">
|
||||||
Everything you need to know about SurfSense pages, premium tokens, and billing. Can't
|
Everything you need to know about SurfSense pages, premium credit, and billing.
|
||||||
find what you need? Reach out at{" "}
|
Can't find what you need? Reach out at{" "}
|
||||||
<a href="mailto:rohan@surfsense.com" className="text-blue-500 underline">
|
<a href="mailto:rohan@surfsense.com" className="text-blue-500 underline">
|
||||||
rohan@surfsense.com
|
rohan@surfsense.com
|
||||||
</a>
|
</a>
|
||||||
|
|
@ -335,7 +335,7 @@ function PricingBasic() {
|
||||||
<Pricing
|
<Pricing
|
||||||
plans={demoPlans}
|
plans={demoPlans}
|
||||||
title="SurfSense Pricing"
|
title="SurfSense Pricing"
|
||||||
description="Start free with 500 pages & 3M premium tokens. Pay as you go."
|
description="Start free with 500 pages & $5 of premium credit. Pay as you go, billed at provider cost."
|
||||||
/>
|
/>
|
||||||
<PricingFAQ />
|
<PricingFAQ />
|
||||||
</>
|
</>
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,23 @@ import { AppError } from "@/lib/error";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { queries } from "@/zero/queries";
|
import { queries } from "@/zero/queries";
|
||||||
|
|
||||||
const TOKEN_PACK_SIZE = 1_000_000;
|
// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the
|
||||||
|
// backend. Premium turns are debited at the actual provider cost
|
||||||
|
// reported by LiteLLM, so $1 of credit always buys $1 of provider
|
||||||
|
// usage at cost.
|
||||||
|
const CREDIT_PER_PACK_MICROS = 1_000_000;
|
||||||
const PRICE_PER_PACK_USD = 1;
|
const PRICE_PER_PACK_USD = 1;
|
||||||
const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const;
|
const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const;
|
||||||
|
|
||||||
|
const formatUsd = (micros: number, options?: { compact?: boolean }) => {
|
||||||
|
const dollars = micros / 1_000_000;
|
||||||
|
if (options?.compact && dollars >= 1) return `$${dollars.toFixed(2)}`;
|
||||||
|
if (dollars >= 100) return `$${dollars.toFixed(0)}`;
|
||||||
|
if (dollars >= 1) return `$${dollars.toFixed(2)}`;
|
||||||
|
if (dollars > 0) return `$${dollars.toFixed(3)}`;
|
||||||
|
return "$0";
|
||||||
|
};
|
||||||
|
|
||||||
export function BuyTokensContent() {
|
export function BuyTokensContent() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const searchSpaceId = Number(params?.search_space_id);
|
const searchSpaceId = Number(params?.search_space_id);
|
||||||
|
|
@ -29,7 +42,7 @@ export function BuyTokensContent() {
|
||||||
queryFn: () => stripeApiService.getTokenStatus(),
|
queryFn: () => stripeApiService.getTokenStatus(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Live per-user usage via Zero.
|
// Live per-user balance via Zero.
|
||||||
const [me] = useZeroQuery(queries.user.me({}));
|
const [me] = useZeroQuery(queries.user.me({}));
|
||||||
|
|
||||||
const purchaseMutation = useMutation({
|
const purchaseMutation = useMutation({
|
||||||
|
|
@ -46,44 +59,46 @@ export function BuyTokensContent() {
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const totalTokens = quantity * TOKEN_PACK_SIZE;
|
const totalCreditMicros = quantity * CREDIT_PER_PACK_MICROS;
|
||||||
const totalPrice = quantity * PRICE_PER_PACK_USD;
|
const totalPrice = quantity * PRICE_PER_PACK_USD;
|
||||||
|
|
||||||
if (tokenStatus && !tokenStatus.token_buying_enabled) {
|
if (tokenStatus && !tokenStatus.token_buying_enabled) {
|
||||||
return (
|
return (
|
||||||
<div className="w-full space-y-3 text-center">
|
<div className="w-full space-y-3 text-center">
|
||||||
<h2 className="text-xl font-bold tracking-tight">Buy Premium Tokens</h2>
|
<h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2>
|
||||||
<p className="text-sm text-muted-foreground">
|
<p className="text-sm text-muted-foreground">
|
||||||
Token purchases are temporarily unavailable.
|
Credit purchases are temporarily unavailable.
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const used = me?.premiumTokensUsed ?? 0;
|
const used = me?.premiumCreditMicrosUsed ?? 0;
|
||||||
const limit = me?.premiumTokensLimit ?? 0;
|
const limit = me?.premiumCreditMicrosLimit ?? 0;
|
||||||
// Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)).
|
// Mirrors the backend formula in stripe_routes.py (max(0, limit - used)).
|
||||||
const remaining = Math.max(0, limit - used);
|
const remaining = Math.max(0, limit - used);
|
||||||
const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
|
const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full space-y-5">
|
<div className="w-full space-y-5">
|
||||||
<div className="text-center">
|
<div className="text-center">
|
||||||
<h2 className="text-xl font-bold tracking-tight">Buy Premium Tokens</h2>
|
<h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2>
|
||||||
<p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p>
|
<p className="mt-1 text-sm text-muted-foreground">
|
||||||
|
$1 buys $1 of credit, billed at provider cost
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{me && (
|
{me && (
|
||||||
<div className="rounded-lg border bg-muted/20 p-3 space-y-1.5">
|
<div className="rounded-lg border bg-muted/20 p-3 space-y-1.5">
|
||||||
<div className="flex justify-between items-center text-xs">
|
<div className="flex justify-between items-center text-xs">
|
||||||
<span className="text-muted-foreground">
|
<span className="text-muted-foreground">
|
||||||
{used.toLocaleString()} / {limit.toLocaleString()} premium tokens
|
{formatUsd(used)} / {formatUsd(limit)} of credit
|
||||||
</span>
|
</span>
|
||||||
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
|
<span className="font-medium">{usagePercentage.toFixed(0)}%</span>
|
||||||
</div>
|
</div>
|
||||||
<Progress value={usagePercentage} className="h-1.5" />
|
<Progress value={usagePercentage} className="h-1.5" />
|
||||||
<p className="text-[11px] text-muted-foreground">
|
<p className="text-[11px] text-muted-foreground">
|
||||||
{remaining.toLocaleString()} tokens remaining
|
{formatUsd(remaining)} of credit remaining
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
@ -99,7 +114,7 @@ export function BuyTokensContent() {
|
||||||
<Minus className="h-3.5 w-3.5" />
|
<Minus className="h-3.5 w-3.5" />
|
||||||
</button>
|
</button>
|
||||||
<span className="min-w-32 text-center text-lg font-semibold tabular-nums">
|
<span className="min-w-32 text-center text-lg font-semibold tabular-nums">
|
||||||
{(totalTokens / 1_000_000).toFixed(0)}M tokens
|
${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
|
||||||
</span>
|
</span>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
|
|
@ -125,14 +140,14 @@ export function BuyTokensContent() {
|
||||||
: "border-border hover:border-purple-500/40 hover:bg-muted/40"
|
: "border-border hover:border-purple-500/40 hover:bg-muted/40"
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{m}M
|
${m}
|
||||||
</button>
|
</button>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2">
|
<div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2">
|
||||||
<span className="text-sm font-medium tabular-nums">
|
<span className="text-sm font-medium tabular-nums">
|
||||||
{(totalTokens / 1_000_000).toFixed(0)}M premium tokens
|
${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
|
||||||
</span>
|
</span>
|
||||||
<span className="text-sm font-semibold tabular-nums">${totalPrice}</span>
|
<span className="text-sm font-semibold tabular-nums">${totalPrice}</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -149,7 +164,7 @@ export function BuyTokensContent() {
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
Buy {(totalTokens / 1_000_000).toFixed(0)}M Tokens for ${totalPrice}
|
Buy ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for ${totalPrice}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
||||||
? "model"
|
? "model"
|
||||||
: "models"}
|
: "models"}
|
||||||
</span>{" "}
|
</span>{" "}
|
||||||
available from your administrator.
|
available from your administrator.{" "}
|
||||||
|
{(() => {
|
||||||
|
const nonAuto = globalConfigs.filter(
|
||||||
|
(g) => !("is_auto_mode" in g && g.is_auto_mode)
|
||||||
|
);
|
||||||
|
const premium = nonAuto.filter(
|
||||||
|
(g) =>
|
||||||
|
"billing_tier" in g &&
|
||||||
|
(g as { billing_tier?: string }).billing_tier === "premium"
|
||||||
|
).length;
|
||||||
|
const free = nonAuto.length - premium;
|
||||||
|
if (premium > 0 && free > 0) {
|
||||||
|
return `${premium} premium, ${free} free.`;
|
||||||
|
}
|
||||||
|
if (premium > 0) {
|
||||||
|
return `All ${premium} premium — debits your shared credit pool.`;
|
||||||
|
}
|
||||||
|
return `All ${free} free.`;
|
||||||
|
})()}
|
||||||
</p>
|
</p>
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
|
|
|
||||||
|
|
@ -371,6 +371,17 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
</SelectLabel>
|
</SelectLabel>
|
||||||
{roleGlobalConfigs.map((config) => {
|
{roleGlobalConfigs.map((config) => {
|
||||||
const isAuto = "is_auto_mode" in config && config.is_auto_mode;
|
const isAuto = "is_auto_mode" in config && config.is_auto_mode;
|
||||||
|
// Read billing_tier from the global config; default to "free"
|
||||||
|
// for legacy YAMLs / Auto stub. Premium gets a purple badge,
|
||||||
|
// free gets an emerald one — same palette as the chat
|
||||||
|
// model selector so the meaning is consistent across
|
||||||
|
// surfaces (issues E, H).
|
||||||
|
const billingTier =
|
||||||
|
("billing_tier" in config &&
|
||||||
|
typeof config.billing_tier === "string" &&
|
||||||
|
config.billing_tier) ||
|
||||||
|
"free";
|
||||||
|
const isPremium = billingTier === "premium";
|
||||||
return (
|
return (
|
||||||
<SelectItem
|
<SelectItem
|
||||||
key={config.id}
|
key={config.id}
|
||||||
|
|
@ -382,13 +393,27 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||||
<span className="truncate text-xs md:text-sm">
|
<span className="truncate text-xs md:text-sm">
|
||||||
{config.name}
|
{config.name}
|
||||||
</span>
|
</span>
|
||||||
{isAuto && (
|
{isAuto ? (
|
||||||
<Badge
|
<Badge
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
className="text-[8px] md:text-[9px] shrink-0 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 [[data-slot=select-trigger]_&]:hidden"
|
className="text-[8px] md:text-[9px] shrink-0 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 [[data-slot=select-trigger]_&]:hidden"
|
||||||
>
|
>
|
||||||
Recommended
|
Recommended
|
||||||
</Badge>
|
</Badge>
|
||||||
|
) : isPremium ? (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0 [[data-slot=select-trigger]_&]:hidden"
|
||||||
|
>
|
||||||
|
Premium
|
||||||
|
</Badge>
|
||||||
|
) : (
|
||||||
|
<Badge
|
||||||
|
variant="secondary"
|
||||||
|
className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0 [[data-slot=select-trigger]_&]:hidden"
|
||||||
|
>
|
||||||
|
Free
|
||||||
|
</Badge>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
||||||
? "model"
|
? "model"
|
||||||
: "models"}
|
: "models"}
|
||||||
</span>{" "}
|
</span>{" "}
|
||||||
available from your administrator.
|
available from your administrator.{" "}
|
||||||
|
{(() => {
|
||||||
|
const nonAuto = globalConfigs.filter(
|
||||||
|
(g) => !("is_auto_mode" in g && g.is_auto_mode)
|
||||||
|
);
|
||||||
|
const premium = nonAuto.filter(
|
||||||
|
(g) =>
|
||||||
|
"billing_tier" in g &&
|
||||||
|
(g as { billing_tier?: string }).billing_tier === "premium"
|
||||||
|
).length;
|
||||||
|
const free = nonAuto.length - premium;
|
||||||
|
if (premium > 0 && free > 0) {
|
||||||
|
return `${premium} premium, ${free} free.`;
|
||||||
|
}
|
||||||
|
if (premium > 0) {
|
||||||
|
return `All ${premium} premium — debits your shared credit pool.`;
|
||||||
|
}
|
||||||
|
return `All ${free} free.`;
|
||||||
|
})()}
|
||||||
</p>
|
</p>
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
|
|
|
||||||
|
|
@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
|
||||||
<DialogHeader>
|
<DialogHeader>
|
||||||
<DialogTitle>Create a free account to {feature}</DialogTitle>
|
<DialogTitle>Create a free account to {feature}</DialogTitle>
|
||||||
<DialogDescription>
|
<DialogDescription>
|
||||||
Get 3 million tokens, save chat history, upload documents, use all AI tools, and
|
Get $5 of premium credit, save chat history, upload documents, use all AI tools,
|
||||||
connect 30+ integrations.
|
and connect 30+ integrations.
|
||||||
</DialogDescription>
|
</DialogDescription>
|
||||||
</DialogHeader>
|
</DialogHeader>
|
||||||
<DialogFooter className="flex flex-col gap-2 sm:flex-row">
|
<DialogFooter className="flex flex-col gap-2 sm:flex-row">
|
||||||
|
|
|
||||||
|
|
@ -258,6 +258,8 @@ export const globalImageGenConfig = z.object({
|
||||||
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
|
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
|
||||||
is_global: z.literal(true),
|
is_global: z.literal(true),
|
||||||
is_auto_mode: z.boolean().optional().default(false),
|
is_auto_mode: z.boolean().optional().default(false),
|
||||||
|
billing_tier: z.string().default("free"),
|
||||||
|
quota_reserve_micros: z.number().nullable().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
|
export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
|
||||||
|
|
@ -338,6 +340,10 @@ export const globalVisionLLMConfig = z.object({
|
||||||
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
|
litellm_params: z.record(z.string(), z.any()).nullable().optional(),
|
||||||
is_global: z.literal(true),
|
is_global: z.literal(true),
|
||||||
is_auto_mode: z.boolean().optional().default(false),
|
is_auto_mode: z.boolean().optional().default(false),
|
||||||
|
billing_tier: z.string().default("free"),
|
||||||
|
quota_reserve_tokens: z.number().nullable().optional(),
|
||||||
|
input_cost_per_token: z.number().nullable().optional(),
|
||||||
|
output_cost_per_token: z.number().nullable().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig);
|
export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig);
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ export const getPagePurchasesResponse = z.object({
|
||||||
purchases: z.array(pagePurchase),
|
purchases: z.array(pagePurchase),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Premium token purchases
|
// Premium credit purchases
|
||||||
export const createTokenCheckoutSessionRequest = z.object({
|
export const createTokenCheckoutSessionRequest = z.object({
|
||||||
quantity: z.number().int().min(1).max(100),
|
quantity: z.number().int().min(1).max(100),
|
||||||
search_space_id: z.number().int().min(1),
|
search_space_id: z.number().int().min(1),
|
||||||
|
|
@ -42,11 +42,16 @@ export const createTokenCheckoutSessionResponse = z.object({
|
||||||
checkout_url: z.string(),
|
checkout_url: z.string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Premium credit balance + purchase records.
|
||||||
|
//
|
||||||
|
// The unit is integer micro-USD (1_000_000 == $1.00). The schema names
|
||||||
|
// kept the ``Token`` prefix for API back-compat with pinned clients;
|
||||||
|
// the field names below are authoritative.
|
||||||
export const tokenStripeStatusResponse = z.object({
|
export const tokenStripeStatusResponse = z.object({
|
||||||
token_buying_enabled: z.boolean(),
|
token_buying_enabled: z.boolean(),
|
||||||
premium_tokens_used: z.number().default(0),
|
premium_credit_micros_used: z.number().default(0),
|
||||||
premium_tokens_limit: z.number().default(0),
|
premium_credit_micros_limit: z.number().default(0),
|
||||||
premium_tokens_remaining: z.number().default(0),
|
premium_credit_micros_remaining: z.number().default(0),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum;
|
export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum;
|
||||||
|
|
@ -56,7 +61,7 @@ export const tokenPurchase = z.object({
|
||||||
stripe_checkout_session_id: z.string(),
|
stripe_checkout_session_id: z.string(),
|
||||||
stripe_payment_intent_id: z.string().nullable(),
|
stripe_payment_intent_id: z.string().nullable(),
|
||||||
quantity: z.number(),
|
quantity: z.number(),
|
||||||
tokens_granted: z.number(),
|
credit_micros_granted: z.number(),
|
||||||
amount_total: z.number().nullable(),
|
amount_total: z.number().nullable(),
|
||||||
currency: z.string().nullable(),
|
currency: z.string().nullable(),
|
||||||
status: tokenPurchaseStatusEnum,
|
status: tokenPurchaseStatusEnum,
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ export interface RawChatErrorInput {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
|
export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
|
||||||
"I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue.";
|
"I can’t continue with the current premium model because your premium credit is exhausted. Switch to a free model or top up your credit to continue.";
|
||||||
|
|
||||||
function getErrorMessage(error: unknown): string {
|
function getErrorMessage(error: unknown): string {
|
||||||
if (error instanceof Error) return error.message;
|
if (error instanceof Error) return error.message;
|
||||||
|
|
|
||||||
|
|
@ -541,16 +541,23 @@ export type SSEEvent =
|
||||||
data: {
|
data: {
|
||||||
usage: Record<
|
usage: Record<
|
||||||
string,
|
string,
|
||||||
{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
|
{
|
||||||
|
prompt_tokens: number;
|
||||||
|
completion_tokens: number;
|
||||||
|
total_tokens: number;
|
||||||
|
cost_micros?: number;
|
||||||
|
}
|
||||||
>;
|
>;
|
||||||
prompt_tokens: number;
|
prompt_tokens: number;
|
||||||
completion_tokens: number;
|
completion_tokens: number;
|
||||||
total_tokens: number;
|
total_tokens: number;
|
||||||
|
cost_micros?: number;
|
||||||
call_details: Array<{
|
call_details: Array<{
|
||||||
model: string;
|
model: string;
|
||||||
prompt_tokens: number;
|
prompt_tokens: number;
|
||||||
completion_tokens: number;
|
completion_tokens: number;
|
||||||
total_tokens: number;
|
total_tokens: number;
|
||||||
|
cost_micros?: number;
|
||||||
}>;
|
}>;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,20 @@ export interface TokenUsageSummary {
|
||||||
prompt_tokens: number;
|
prompt_tokens: number;
|
||||||
completion_tokens: number;
|
completion_tokens: number;
|
||||||
total_tokens: number;
|
total_tokens: number;
|
||||||
|
/**
|
||||||
|
* Total provider USD cost for this assistant turn, in micro-USD
|
||||||
|
* (1_000_000 = $1.00). Optional because rows persisted before the
|
||||||
|
* cost-credits migration won't have it.
|
||||||
|
*/
|
||||||
|
cost_micros?: number;
|
||||||
model_breakdown?: Record<
|
model_breakdown?: Record<
|
||||||
string,
|
string,
|
||||||
{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
|
{
|
||||||
|
prompt_tokens: number;
|
||||||
|
completion_tokens: number;
|
||||||
|
total_tokens: number;
|
||||||
|
cost_micros?: number;
|
||||||
|
}
|
||||||
> | null;
|
> | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,20 @@
|
||||||
import { number, string, table } from "@rocicorp/zero";
|
import { number, string, table } from "@rocicorp/zero";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Live-meter slice of the ``user`` table replicated through Zero.
|
||||||
|
*
|
||||||
|
* ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored
|
||||||
|
* as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M
|
||||||
|
* when displaying. Sensitive fields (email, hashed_password, oauth, etc.)
|
||||||
|
* are intentionally omitted via the Postgres column-list publication so
|
||||||
|
* they never enter WAL replication.
|
||||||
|
*/
|
||||||
export const userTable = table("user")
|
export const userTable = table("user")
|
||||||
.columns({
|
.columns({
|
||||||
id: string(),
|
id: string(),
|
||||||
pagesLimit: number().from("pages_limit"),
|
pagesLimit: number().from("pages_limit"),
|
||||||
pagesUsed: number().from("pages_used"),
|
pagesUsed: number().from("pages_used"),
|
||||||
premiumTokensLimit: number().from("premium_tokens_limit"),
|
premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"),
|
||||||
premiumTokensUsed: number().from("premium_tokens_used"),
|
premiumCreditMicrosUsed: number().from("premium_credit_micros_used"),
|
||||||
})
|
})
|
||||||
.primaryKey("id");
|
.primaryKey("id");
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue