feat: unified credits and cost calculations

DESKTOP-RTLN3BA\$punk 2026-05-02 14:34:23 -07:00
parent 451a98936e
commit ae9d36d77f
61 changed files with 5835 additions and 272 deletions


@@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
 # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
 # STRIPE_RECONCILIATION_BATCH_SIZE=100
-# Premium token purchases ($1 per 1M tokens for premium-tier models)
+# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
+# credit; premium turns debit the actual per-call provider cost
+# reported by LiteLLM, so cheap and expensive models bill proportionally)
 # STRIPE_TOKEN_BUYING_ENABLED=FALSE
 # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-# STRIPE_TOKENS_PER_UNIT=1000000
+# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
 # ------------------------------------------------------------------------------
 # TTS & STT (Text-to-Speech / Speech-to-Text)
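Worth spelling the unit out once: 1 micro-USD is $0.000001, so one $1.00 pack is 1_000_000 micros and a typical premium call debits a few thousand. A quick sketch (the helper name is illustrative, not part of the codebase):

def usd_to_micros(usd: float) -> int:
    # 1 micro-USD == $0.000001, so $1.00 == 1_000_000 micros
    return round(usd * 1_000_000)

# a call LiteLLM prices at $0.0123 debits 12_300 micros from the
# default 5_000_000-micro ($5.00) per-user quota
assert usd_to_micros(0.0123) == 12_300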
@@ -315,9 +318,24 @@ STT_SERVICE=local/base
 # Pages limit per user for ETL (default: unlimited)
 # PAGES_LIMIT=500
-# Premium token quota per registered user (default: 5M)
-# Only applies to models with billing_tier=premium in global_llm_config.yaml
-# PREMIUM_TOKEN_LIMIT=5000000
+# Premium credit quota per registered user, in micro-USD (default: $5).
+# Premium turns are debited at the actual per-call provider cost reported
+# by LiteLLM. Only applies to models with billing_tier=premium.
+# PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
+# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
+# QUOTA_MAX_RESERVE_MICROS=1000000
+# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
+# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+# Per-podcast reservation for the podcast Celery task ($0.20 default).
+# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+# Per-video-presentation reservation for the video Celery task ($1.00 default).
+# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
 # No-login (anonymous) mode — public users can chat without an account
 # Set TRUE to enable /free pages and anonymous chat API


@@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
 # Set FALSE to disable new checkout session creation temporarily
 STRIPE_PAGE_BUYING_ENABLED=TRUE
-# Premium token purchases via Stripe (for premium-tier model usage)
-# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
+# Premium credit purchases via Stripe (for premium-tier model usage).
+# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
+# (default 1_000_000 = $1.00). Premium turns are billed at the actual
+# per-call provider cost reported by LiteLLM.
 STRIPE_TOKEN_BUYING_ENABLED=FALSE
 STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-STRIPE_TOKENS_PER_UNIT=1000000
+STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
+# STRIPE_TOKENS_PER_UNIT=1000000
 # Periodic Stripe safety net for purchases left in PENDING (minutes old)
 STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
 # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
 PAGES_LIMIT=500
-# Premium token quota per registered user (default: 3,000,000)
-# Applies only to models with billing_tier=premium in global_llm_config.yaml
-PREMIUM_TOKEN_LIMIT=3000000
+# Premium credit quota per registered user, in micro-USD
+# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
+# actual per-call provider cost reported by LiteLLM, so cheap and expensive
+# models bill proportionally. Applies only to models with
+# billing_tier=premium in global_llm_config.yaml.
+PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
+# PREMIUM_TOKEN_LIMIT=5000000
+# Safety ceiling on per-call premium reservation, in micro-USD.
+# stream_new_chat estimates an upper-bound cost from the model's
+# litellm-published per-token rates × the config's quota_reserve_tokens
+# and clamps to this value so a misconfigured model can't lock the
+# user's whole balance on one call. Default $1.00.
+QUOTA_MAX_RESERVE_MICROS=1000000
+# Per-image reservation (in micro-USD) for the POST /image-generations
+# endpoint. Bypassed for free configs. Default $0.05.
+QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
+# Single envelope covers one transcript-generation LLM call. Default $0.20.
+QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+# Per-video-presentation reservation (in micro-USD) used by the video
+# presentation Celery task. Covers worst-case fan-out of N slide-scene
+# generations + refines. Default $1.00. NOTE: tasks using the override
+# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
 # No-login (anonymous) mode — allows public users to chat without an account
 # Set TRUE to enable /free pages and anonymous chat API


@@ -0,0 +1,291 @@
"""rename premium token columns to credit micros and add cost_micros to token_usage
Migrates the premium quota system from a flat token counter to a USD-cost
based credit system, where 1 credit = 1 micro-USD ($0.000001).
Column renames (1:1 numerical mapping: the prior $1-per-1M-tokens Stripe
price means every existing value is already correct in the new unit, so no
data transformation is needed):
user.premium_tokens_limit -> premium_credit_micros_limit
user.premium_tokens_used -> premium_credit_micros_used
user.premium_tokens_reserved -> premium_credit_micros_reserved
premium_token_purchases.tokens_granted -> credit_micros_granted
New column for cost auditing per turn:
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
The "user" table is in zero_publication's column list (added in 139), so
this migration must drop and recreate the publication with the renamed
column names, otherwise zero-cache will replicate stale column names and
the FE Zero schema will fail to bind.
IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
column-not-found errors from a stale catalog snapshot.
Revision ID: 140
Revises: 139
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "140"
down_revision: str | None = "139"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
# Replicates 139's document column list verbatim — must stay in sync.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Same five live-meter fields as 139, with the renamed column names.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_credit_micros_limit",
"premium_credit_micros_used",
]
def _terminate_blocked_pids(conn, table: str) -> None:
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
conn.execute(
sa.text(
"SELECT pg_terminate_backend(l.pid) "
"FROM pg_locks l "
"JOIN pg_class c ON c.oid = l.relation "
"WHERE c.relname = :tbl "
" AND l.pid != pg_backend_pid()"
),
{"tbl": table},
)
def _has_zero_version(conn, table: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = '_0_version'"
),
{"tbl": table},
).fetchone()
is not None
)
def _column_exists(conn, table: str, column: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = :col"
),
{"tbl": table, "col": column},
).fetchone()
is not None
)
def _build_publication_ddl(
user_cols: list[str],
*,
documents_has_zero_ver: bool,
user_has_zero_ver: bool,
) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
user_col_list_with_meta = user_cols + (
['"_0_version"'] if user_has_zero_ver else []
)
doc_col_list = ", ".join(doc_cols)
user_col_list = ", ".join(user_col_list_with_meta)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state, "
f'"user" ({user_col_list})'
)
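# Illustrative only, not part of the migration: with both tables carrying
# zero-cache's "_0_version" metadata column, the helper above yields roughly:
#
#   CREATE PUBLICATION zero_publication FOR TABLE notifications,
#     documents (id, title, document_type, search_space_id, folder_id,
#                created_by_id, status, created_at, updated_at, "_0_version"),
#     folders, search_source_connectors, new_chat_messages, chat_comments,
#     chat_session_state,
#     "user" (id, pages_limit, pages_used, premium_credit_micros_limit,
#             premium_credit_micros_used, "_0_version")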
def upgrade() -> None:
conn = op.get_bind()
# ------------------------------------------------------------------
# 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
# dev environments are safe.
# ------------------------------------------------------------------
if not _column_exists(conn, "token_usage", "cost_micros"):
op.add_column(
"token_usage",
sa.Column(
"cost_micros",
sa.BigInteger(),
nullable=False,
server_default="0",
),
)
# ------------------------------------------------------------------
# 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
# ------------------------------------------------------------------
if _column_exists(
conn, "premium_token_purchases", "tokens_granted"
) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
op.alter_column(
"premium_token_purchases",
"tokens_granted",
new_column_name="credit_micros_granted",
)
# ------------------------------------------------------------------
# 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
#
# We must drop the publication first (it references the old column
# names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
# in a transaction block; alembic's outer transaction already holds
# one, but a SAVEPOINT keeps the LOCK + DDL atomic.
# ------------------------------------------------------------------
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
# Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
# publications require at least the PK to be in the column list,
# which is true for both the old and new shape.
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
# Drop the publication BEFORE renaming columns, otherwise Postgres
# rejects the rename: "cannot drop column ... referenced by
# publication".
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
for old, new in (
("premium_tokens_limit", "premium_credit_micros_limit"),
("premium_tokens_used", "premium_credit_micros_used"),
("premium_tokens_reserved", "premium_credit_micros_reserved"),
):
if _column_exists(conn, "user", old) and not _column_exists(
conn, "user", new
):
op.alter_column("user", old, new_column_name=new)
# Update the server_default on the renamed limit column so newly
# inserted users get $5 of credit (== 5_000_000 micros) by
# default. Existing rows are unaffected.
op.alter_column(
"user",
"premium_credit_micros_limit",
server_default="5000000",
)
# Recreate the publication with the new column names.
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(
_build_publication_ddl(
USER_COLS,
documents_has_zero_ver=documents_has_zero_ver,
user_has_zero_ver=user_has_zero_ver,
)
)
)
def downgrade() -> None:
"""Revert the rename and drop ``cost_micros``.
Mirrors ``upgrade``: drop the publication, rename columns back, drop
the new column, recreate the publication with the old column list.
Same zero-cache stop/reset runbook applies in reverse.
"""
conn = op.get_bind()
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
for new, old in (
("premium_credit_micros_limit", "premium_tokens_limit"),
("premium_credit_micros_used", "premium_tokens_used"),
("premium_credit_micros_reserved", "premium_tokens_reserved"),
):
if _column_exists(conn, "user", new) and not _column_exists(
conn, "user", old
):
op.alter_column("user", new, new_column_name=old)
op.alter_column(
"user",
"premium_tokens_limit",
server_default="5000000",
)
legacy_user_cols = [
"id",
"pages_limit",
"pages_used",
"premium_tokens_limit",
"premium_tokens_used",
]
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(
_build_publication_ddl(
legacy_user_cols,
documents_has_zero_ver=documents_has_zero_ver,
user_has_zero_ver=user_has_zero_ver,
)
)
)
if _column_exists(
conn, "premium_token_purchases", "credit_micros_granted"
) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
op.alter_column(
"premium_token_purchases",
"credit_micros_granted",
new_column_name="tokens_granted",
)
if _column_exists(conn, "token_usage", "cost_micros"):
op.drop_column("token_usage", "cost_micros")


@@ -31,6 +31,7 @@ from app.config import (
     initialize_image_gen_router,
     initialize_llm_router,
     initialize_openrouter_integration,
+    initialize_pricing_registration,
     initialize_vision_llm_router,
 )
 from app.db import User, create_db_and_tables, get_async_session
@@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
     await setup_checkpointer_tables()
     initialize_openrouter_integration()
     _start_openrouter_background_refresh()
+    initialize_pricing_registration()
     initialize_llm_router()
     initialize_image_gen_router()
     initialize_vision_llm_router()


@@ -22,10 +22,12 @@ def init_worker(**kwargs):
         initialize_image_gen_router,
         initialize_llm_router,
         initialize_openrouter_integration,
+        initialize_pricing_registration,
         initialize_vision_llm_router,
     )

     initialize_openrouter_integration()
+    initialize_pricing_registration()
     initialize_llm_router()
     initialize_image_gen_router()
     initialize_vision_llm_router()


@@ -138,7 +138,11 @@ def load_global_image_gen_configs():
     try:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
-        return data.get("global_image_generation_configs", [])
+        configs = data.get("global_image_generation_configs", []) or []
+        for cfg in configs:
+            if isinstance(cfg, dict):
+                cfg.setdefault("billing_tier", "free")
+        return configs
     except Exception as e:
         print(f"Warning: Failed to load global image generation configs: {e}")
         return []
@@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
     try:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
-        return data.get("global_vision_llm_configs", [])
+        configs = data.get("global_vision_llm_configs", []) or []
+        for cfg in configs:
+            if isinstance(cfg, dict):
+                cfg.setdefault("billing_tier", "free")
+        return configs
     except Exception as e:
         print(f"Warning: Failed to load global vision LLM configs: {e}")
         return []
@@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
                 "anonymous_enabled_free", settings["anonymous_enabled"]
             )
+            # Image generation + vision LLM emission are opt-in (issue L).
+            # OpenRouter's catalogue contains hundreds of image / vision
+            # capable models; auto-injecting all of them into every
+            # deployment would explode the model selector and surprise
+            # operators upgrading from prior versions. Default to False so
+            # admins must explicitly turn them on.
+            settings.setdefault("image_generation_enabled", False)
+            settings.setdefault("vision_enabled", False)
             return settings
     except Exception as e:
         print(f"Warning: Failed to load OpenRouter integration settings: {e}")
@@ -296,10 +313,60 @@
             )
         else:
             print("Info: OpenRouter integration enabled but no models fetched")
+
+        # Image generation + vision LLM emissions are opt-in (issue L).
+        # Both reuse the catalogue already cached by ``service.initialize``
+        # so we don't make additional network calls here.
+        if settings.get("image_generation_enabled"):
+            try:
+                image_configs = service.get_image_generation_configs()
+                if image_configs:
+                    config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
+                    print(
+                        f"Info: OpenRouter integration added {len(image_configs)} "
+                        f"image-generation models"
+                    )
+            except Exception as e:
+                print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
+        if settings.get("vision_enabled"):
+            try:
+                vision_configs = service.get_vision_llm_configs()
+                if vision_configs:
+                    config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
+                    print(
+                        f"Info: OpenRouter integration added {len(vision_configs)} "
+                        f"vision LLM models"
+                    )
+            except Exception as e:
+                print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
     except Exception as e:
         print(f"Warning: Failed to initialize OpenRouter integration: {e}")
+
+
+def initialize_pricing_registration():
+    """
+    Teach LiteLLM the per-token cost of every deployment in
+    ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
+    from the OpenRouter catalogue + any operator-declared YAML pricing).
+
+    Must run AFTER ``initialize_openrouter_integration()`` so the
+    OpenRouter catalogue is populated, and BEFORE the first LLM call so
+    ``response_cost`` is available in ``TokenTrackingCallback``.
+
+    Failures are logged but never raised; startup must not be blocked
+    by a missing pricing entry. The worst case is that the model debits 0.
+    """
+    try:
+        from app.services.pricing_registration import (
+            register_pricing_from_global_configs,
+        )
+
+        register_pricing_from_global_configs()
+    except Exception as e:
+        print(f"Warning: Failed to register LiteLLM pricing: {e}")
+
+
 def initialize_llm_router():
     """
     Initialize the LLM Router service for Auto mode.
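``register_pricing_from_global_configs`` itself is not shown in this diff; a minimal sketch of the idea, assuming LiteLLM's public ``litellm.register_model`` API and the inline ``input_cost_per_token`` / ``output_cost_per_token`` keys described in the YAML comments later in this commit:

import litellm

from app.config import config


def register_pricing_from_global_configs() -> None:
    # Sketch only: push operator-declared per-token prices into LiteLLM's
    # model-cost map so response_cost can resolve custom deployment names.
    for cfg in config.GLOBAL_LLM_CONFIGS:
        params = cfg.get("litellm_params") or {}
        model = params.get("base_model") or cfg.get("model_name")
        in_cost = params.get("input_cost_per_token")
        out_cost = params.get("output_cost_per_token")
        if model and in_cost is not None and out_cost is not None:
            litellm.register_model(
                {
                    model: {
                        "input_cost_per_token": float(in_cost),
                        "output_cost_per_token": float(out_cost),
                        "mode": "chat",  # assumed default
                    }
                }
            )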
@@ -444,14 +511,54 @@ class Config:
         os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
     )
-    # Premium token quota settings
-    PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
+    # Premium credit (micro-USD) quota settings.
+    #
+    # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
+    # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
+    # still honoured for one release as fall-back values — the prior
+    # $1-per-1M-tokens Stripe price means every existing value maps 1:1
+    # to micros, so operators upgrading without changing their .env still
+    # get correct behaviour. A startup deprecation warning fires below if
+    # they're set.
+    PREMIUM_CREDIT_MICROS_LIMIT = int(
+        os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
+        or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
+    )
     STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
-    STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
+    STRIPE_CREDIT_MICROS_PER_UNIT = int(
+        os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
+        or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
+    )
     STRIPE_TOKEN_BUYING_ENABLED = (
         os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
     )
+    # Safety ceiling on the per-call premium reservation. ``stream_new_chat``
+    # estimates an upper-bound cost from ``litellm.get_model_info`` × the
+    # config's ``quota_reserve_tokens`` and clamps the result to this value
+    # so a misconfigured "$1000/M" model can't lock the user's whole balance
+    # on one call. Default $1.00 covers realistic worst cases (Opus + 4K
+    # reserve_tokens ≈ $0.36) with headroom.
+    QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
+
+    if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
+        "PREMIUM_CREDIT_MICROS_LIMIT"
+    ):
+        print(
+            "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
+            "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
+            "current Stripe price). The old key will be removed in a "
+            "future release."
+        )
+    if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
+        "STRIPE_CREDIT_MICROS_PER_UNIT"
+    ):
+        print(
+            "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
+            "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
+            "The old key will be removed in a future release."
+        )
+
     # Anonymous / no-login mode settings
     NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
     ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
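A sketch of the clamp described in the comment above (assumed shape; the real ``estimate_call_reserve_micros`` lives in ``app.services.token_quota_service`` and is not part of this hunk):

import litellm


def estimate_call_reserve_micros_sketch(base_model: str, reserve_tokens: int) -> int:
    # Upper bound: assume reserve_tokens of input AND output at the model's
    # published rates. Opus-class pricing ($15/M in, $75/M out) with a
    # 4_000-token reserve gives (0.000015 + 0.000075) * 4000 = $0.36.
    info = litellm.get_model_info(base_model)
    rate = (info.get("input_cost_per_token") or 0.0) + (
        info.get("output_cost_per_token") or 0.0
    )
    estimate_micros = round(rate * reserve_tokens * 1_000_000)
    # Clamp so a mispriced model can't lock the whole balance on one call.
    return min(estimate_micros, 1_000_000)  # QUOTA_MAX_RESERVE_MICROS default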
@@ -464,6 +571,35 @@ class Config:
     # Default quota reserve tokens when not specified per-model
     QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
+    # Per-image reservation (in micro-USD) used by ``billable_call`` for the
+    # ``POST /image-generations`` endpoint when the global config does not
+    # override it. $0.05 covers realistic worst cases for current OpenAI /
+    # OpenRouter image-gen pricing. Bypassed entirely for free configs.
+    QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
+    )
+    # Per-podcast reservation (in micro-USD). One agent LLM call generating
+    # a transcript, typically 5k-20k completion tokens. $0.20 covers a long
+    # premium-model run. Tune via env.
+    QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
+    )
+    # Per-video-presentation reservation (in micro-USD). Fan-out of N
+    # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
+    # plus refine retries; can produce many premium completions. $1.00
+    # covers the worst case. Tune via env.
+    #
+    # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
+    # 1_000_000. The override path in ``billable_call`` bypasses the
+    # per-call clamp in ``estimate_call_reserve_micros``, so this is the
+    # *actual* hold — raising it via env is fine but means a single video
+    # task can lock $1+ of credit.
+    QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
+    )
+
     # Abuse prevention: concurrent stream cap and CAPTCHA
     ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
     ANON_CAPTCHA_REQUEST_THRESHOLD = int(

@@ -19,6 +19,24 @@
 # Structure matches NewLLMConfig:
 # - Model configuration (provider, model_name, api_key, etc.)
 # - Prompt configuration (system_instructions, citations_enabled)
+#
+# COST-BASED PREMIUM CREDITS:
+# Each premium config bills the user's USD-credit balance based on the
+# actual provider cost reported by LiteLLM. For models LiteLLM already
+# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
+# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
+# or any model LiteLLM doesn't have in its built-in pricing table, declare
+# per-token costs inline so they bill correctly:
+#
+#   litellm_params:
+#     base_model: "my-custom-azure-deploy"
+#     # USD per token; e.g. 0.000003 == $3.00 per million input tokens
+#     input_cost_per_token: 0.000003
+#     output_cost_per_token: 0.000015
+#
+# OpenRouter dynamic models pull pricing automatically from OpenRouter's
+# API — no inline declaration needed. Models without resolvable pricing
+# debit $0 from the user's balance and log a WARNING.

 # Router Settings for Auto Mode
 # These settings control how the LiteLLM Router distributes requests across models
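To make the inline-pricing arithmetic from the comment block above concrete (example rates from that block; token counts illustrative):

prompt_tokens, completion_tokens = 1_200, 350
input_cost_per_token, output_cost_per_token = 0.000003, 0.000015

cost_usd = (
    prompt_tokens * input_cost_per_token
    + completion_tokens * output_cost_per_token
)  # 0.0036 + 0.00525 = $0.00885
cost_micros = round(cost_usd * 1_000_000)  # 8_850 micro-USD debited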
@@ -292,6 +310,17 @@ openrouter_integration:
   free_rpm: 20
   free_tpm: 100000
+  # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
+  # contains hundreds of image- and vision-capable models; turning these on
+  # injects them into the global Image-Generation / Vision-LLM model
+  # selectors alongside any static configs. Tier (free/premium) is derived
+  # per model the same way it is for chat (`:free` suffix or zero pricing).
+  # When a user picks a premium image/vision model the call debits the
+  # shared $5 USD-cost-based premium credit pool — so leaving these off
+  # avoids surprise quota burn on existing deployments. Default: false.
+  image_generation_enabled: false
+  vision_enabled: false

 litellm_params:
   max_tokens: 16384
   system_instructions: ""


@@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
     prompt_tokens = Column(Integer, nullable=False, default=0)
     completion_tokens = Column(Integer, nullable=False, default=0)
     total_tokens = Column(Integer, nullable=False, default=0)
+    cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
     model_breakdown = Column(JSONB, nullable=True)
     call_details = Column(JSONB, nullable=True)
@@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):

 class PremiumTokenPurchase(Base, TimestampMixin):
-    """Tracks Stripe checkout sessions used to grant additional premium token credits."""
+    """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
+
+    Note: the table name is preserved (``premium_token_purchases``) for
+    operational continuity even though the unit is now USD micro-credits
+    instead of raw tokens. The ``credit_micros_granted`` column replaced
+    the legacy ``tokens_granted`` in migration 140; the stored values
+    were not transformed because the prior $1 = 1M tokens Stripe price
+    makes the unit conversion 1:1 numerically.
+    """

     __tablename__ = "premium_token_purchases"
     __allow_unmapped__ = True
@@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
     )
     stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
     quantity = Column(Integer, nullable=False)
-    tokens_granted = Column(BigInteger, nullable=False)
+    credit_micros_granted = Column(BigInteger, nullable=False)
     amount_total = Column(Integer, nullable=True)
     currency = Column(String(10), nullable=True)
     status = Column(
@@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
         )
         pages_used = Column(Integer, nullable=False, default=0, server_default="0")
-        premium_tokens_limit = Column(
+        premium_credit_micros_limit = Column(
             BigInteger,
             nullable=False,
-            default=config.PREMIUM_TOKEN_LIMIT,
-            server_default=str(config.PREMIUM_TOKEN_LIMIT),
+            default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+            server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
         )
-        premium_tokens_used = Column(
+        premium_credit_micros_used = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
-        premium_tokens_reserved = Column(
+        premium_credit_micros_reserved = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
@@ -2241,16 +2250,16 @@ else:
         )
         pages_used = Column(Integer, nullable=False, default=0, server_default="0")
-        premium_tokens_limit = Column(
+        premium_credit_micros_limit = Column(
            BigInteger,
            nullable=False,
-            default=config.PREMIUM_TOKEN_LIMIT,
-            server_default=str(config.PREMIUM_TOKEN_LIMIT),
+            default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+            server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
         )
-        premium_tokens_used = Column(
+        premium_credit_micros_used = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
-        premium_tokens_reserved = Column(
+        premium_credit_micros_reserved = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )


@@ -68,12 +68,25 @@ class EtlPipelineService:
                     etl_service="VISION_LLM",
                     content_type="image",
                 )
-            except Exception:
-                logging.warning(
-                    "Vision LLM failed for %s, falling back to document parser",
-                    request.filename,
-                    exc_info=True,
-                )
+            except Exception as exc:
+                # Special-case quota exhaustion so we log a clearer message
+                # — the vision LLM didn't "fail", the user just ran out of
+                # premium credit. Falling through to the document parser
+                # is a graceful degradation: OCR/Unstructured still
+                # extracts text from the image without burning credit.
+                from app.services.billable_calls import QuotaInsufficientError
+
+                if isinstance(exc, QuotaInsufficientError):
+                    logging.info(
+                        "Vision LLM quota exhausted for %s; falling back to document parser",
+                        request.filename,
+                    )
+                else:
+                    logging.warning(
+                        "Vision LLM failed for %s, falling back to document parser",
+                        request.filename,
+                        exc_info=True,
+                    )
         else:
             logging.info(
                 "No vision LLM provided, falling back to document parser for %s",


@@ -36,6 +36,11 @@ from app.schemas import (
     ImageGenerationListRead,
     ImageGenerationRead,
 )
+from app.services.billable_calls import (
+    DEFAULT_IMAGE_RESERVE_MICROS,
+    QuotaInsufficientError,
+    billable_call,
+)
 from app.services.image_gen_router_service import (
     IMAGE_GEN_AUTO_MODE_ID,
     ImageGenRouterService,
@@ -92,6 +97,50 @@ def _build_model_string(
     return f"{prefix}/{model_name}"


+async def _resolve_billing_for_image_gen(
+    session: AsyncSession,
+    config_id: int | None,
+    search_space: SearchSpace,
+) -> tuple[str, str, int]:
+    """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
+
+    The resolution mirrors ``_execute_image_generation``'s lookup tree but
+    only extracts the fields needed for billing; we do this *before*
+    ``billable_call`` so the reservation is correctly sized for the
+    config that will actually run, and so we don't open an
+    ``ImageGeneration`` row for a request that's about to 402.
+
+    User-owned (positive ID) BYOK configs are always free: they cost
+    the user nothing on our side. Auto mode is currently treated as free
+    because the underlying router can dispatch to either premium or
+    free YAML configs and we don't surface the resolved deployment up
+    here yet. Bringing Auto under premium billing would require
+    threading the chosen deployment back from ``ImageGenRouterService``.
+    """
+    resolved_id = config_id
+    if resolved_id is None:
+        resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
+    if is_image_gen_auto_mode(resolved_id):
+        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+    if resolved_id < 0:
+        cfg = _get_global_image_gen_config(resolved_id) or {}
+        billing_tier = str(cfg.get("billing_tier", "free")).lower()
+        base_model = _build_model_string(
+            cfg.get("provider", ""),
+            cfg.get("model_name", ""),
+            cfg.get("custom_provider"),
+        )
+        reserve_micros = int(
+            cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
+        )
+        return (billing_tier, base_model, reserve_micros)
+    # Positive ID = user-owned BYOK image-gen config — always free.
+    return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
+
+
 async def _execute_image_generation(
     session: AsyncSession,
     image_gen: ImageGeneration,
@@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
                 "litellm_params": {},
                 "is_global": True,
                 "is_auto_mode": True,
+                # Auto mode currently treated as free until per-deployment
+                # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
+                "billing_tier": "free",
             }
         )
@@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
                 "api_version": cfg.get("api_version") or None,
                 "litellm_params": cfg.get("litellm_params", {}),
                 "is_global": True,
+                "billing_tier": cfg.get("billing_tier", "free"),
+                "quota_reserve_micros": cfg.get("quota_reserve_micros"),
             }
         )
@@ -454,7 +508,26 @@
     session: AsyncSession = Depends(get_async_session),
     user: User = Depends(current_active_user),
 ):
-    """Create and execute an image generation request."""
+    """Create and execute an image generation request.
+
+    Premium configs are gated by the user's shared premium credit pool.
+    The flow is:
+
+    1. Permission check + load the search space (cheap, no provider call).
+    2. Resolve which config will run so we know its billing tier and the
+       worst-case reservation size *before* opening any DB rows.
+    3. Wrap the entire ImageGeneration row insert + provider call in
+       ``billable_call``. If quota is denied, ``billable_call`` raises
+       ``QuotaInsufficientError`` *before* we flush a row, which we
+       translate to HTTP 402 (no orphaned rows on the user's account,
+       no inserted error rows for "you ran out of credit").
+    4. On success, the actual ``response_cost`` flows through the
+       LiteLLM callback into the accumulator, and ``billable_call``
+       finalizes the debit at exit. Inner ``try/except`` still catches
+       provider errors and stores them on ``error_message`` (HTTP 200
+       with ``error_message`` set is preserved for failed-but-not-quota
+       scenarios; clients already know how to surface those).
+    """
     try:
         await check_permission(
             session,
@@ -471,33 +544,70 @@
         if not search_space:
             raise HTTPException(status_code=404, detail="Search space not found")

-        db_image_gen = ImageGeneration(
-            prompt=data.prompt,
-            model=data.model,
-            n=data.n,
-            quality=data.quality,
-            size=data.size,
-            style=data.style,
-            response_format=data.response_format,
-            image_generation_config_id=data.image_generation_config_id,
-            search_space_id=data.search_space_id,
-            created_by_id=user.id,
-        )
-        session.add(db_image_gen)
-        await session.flush()
-
-        try:
-            await _execute_image_generation(session, db_image_gen, search_space)
-        except Exception as e:
-            logger.exception("Image generation call failed")
-            db_image_gen.error_message = str(e)
-
-        await session.commit()
-        await session.refresh(db_image_gen)
-        return db_image_gen
+        billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
+            session, data.image_generation_config_id, search_space
+        )
+
+        # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
+        # propagates to the outer ``except QuotaInsufficientError`` handler
+        # below as HTTP 402 — it is intentionally NOT swallowed into
+        # ``error_message`` because that would (1) imply a successful row
+        # exists when none does, and (2) return HTTP 200 to a client
+        # whose request was actively *denied* (issue K).
+        async with billable_call(
+            user_id=search_space.user_id,
+            search_space_id=data.search_space_id,
+            billing_tier=billing_tier,
+            base_model=base_model,
+            quota_reserve_micros_override=reserve_micros,
+            usage_type="image_generation",
+            call_details={"model": base_model, "prompt": data.prompt[:100]},
+        ):
+            db_image_gen = ImageGeneration(
+                prompt=data.prompt,
+                model=data.model,
+                n=data.n,
+                quality=data.quality,
+                size=data.size,
+                style=data.style,
+                response_format=data.response_format,
+                image_generation_config_id=data.image_generation_config_id,
+                search_space_id=data.search_space_id,
+                created_by_id=user.id,
+            )
+            session.add(db_image_gen)
+            await session.flush()
+
+            try:
+                await _execute_image_generation(session, db_image_gen, search_space)
+            except Exception as e:
+                logger.exception("Image generation call failed")
+                db_image_gen.error_message = str(e)
+
+            await session.commit()
+            await session.refresh(db_image_gen)
+            return db_image_gen
     except HTTPException:
         raise
+    except QuotaInsufficientError as exc:
+        # The user's premium credit pool is empty. No DB row is created
+        # because ``billable_call`` denies before yielding (issue K).
+        await session.rollback()
+        raise HTTPException(
+            status_code=402,
+            detail={
+                "error_code": "premium_quota_exhausted",
+                "usage_type": exc.usage_type,
+                "used_micros": exc.used_micros,
+                "limit_micros": exc.limit_micros,
+                "remaining_micros": exc.remaining_micros,
+                "message": (
+                    "Out of premium credits for image generation. "
+                    "Purchase additional credits or switch to a free model."
+                ),
+            },
+        ) from exc
     except SQLAlchemyError:
         await session.rollback()
         raise HTTPException(
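What a client sees on denial, sketched with httpx (the URL path and auth header are assumptions; the detail fields match the 402 payload above):

import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/image-generations",  # assumed route prefix
    json={"prompt": "a lighthouse at dusk", "search_space_id": 1, "n": 1},
    headers={"Authorization": "Bearer <token>"},
)
if resp.status_code == 402:
    detail = resp.json()["detail"]
    # e.g. remaining_micros == 0: prompt the user to buy a credit pack
    print(detail["error_code"], detail["remaining_micros"])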


@@ -1366,7 +1366,11 @@ async def append_message(
         # flush assigns the PK/defaults without a round-trip SELECT
         await session.flush()

-        # Persist token usage if provided (for assistant messages)
+        # Persist token usage if provided (for assistant messages).
+        # ``cost_micros`` is the provider USD cost reported by LiteLLM,
+        # forwarded by the FE through the appendMessage round-trip so
+        # the historical TokenUsage row matches the credit debit applied
+        # at finalize time.
         token_usage_data = raw_body.get("token_usage")
         if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
             await record_token_usage(
@@ -1377,6 +1381,7 @@
                 prompt_tokens=token_usage_data.get("prompt_tokens", 0),
                 completion_tokens=token_usage_data.get("completion_tokens", 0),
                 total_tokens=token_usage_data.get("total_tokens", 0),
+                cost_micros=token_usage_data.get("cost_micros", 0),
                 model_breakdown=token_usage_data.get("usage"),
                 call_details=token_usage_data.get("call_details"),
                 thread_id=thread_id,


@@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
             "model_name": "auto",
             "is_global": True,
             "is_auto_mode": True,
+            "billing_tier": "free",
         }
     if config_id < 0:
@@ -610,6 +611,7 @@
             "api_version": cfg.get("api_version") or None,
             "litellm_params": cfg.get("litellm_params", {}),
             "is_global": True,
+            "billing_tier": cfg.get("billing_tier", "free"),
         }
     return None
@@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
             "model_name": "auto",
             "is_global": True,
             "is_auto_mode": True,
+            "billing_tier": "free",
         }
     if config_id < 0:
@@ -668,6 +671,7 @@
             "api_version": cfg.get("api_version") or None,
             "litellm_params": cfg.get("litellm_params", {}),
             "is_global": True,
+            "billing_tier": cfg.get("billing_tier", "free"),
         }
     return None


@@ -251,9 +251,16 @@
     metadata = _get_metadata(checkout_session)
     user_id = metadata.get("user_id")
     quantity = int(metadata.get("quantity", "0"))
-    tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
+    # Read the new metadata key first, fall back to the legacy one so
+    # in-flight checkout sessions created before the cost-credits
+    # release still fulfil correctly (the unit is numerically the
+    # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
+    credit_micros_per_unit = int(
+        metadata.get("credit_micros_per_unit")
+        or metadata.get("tokens_per_unit", "0")
+    )

-    if not user_id or quantity <= 0 or tokens_per_unit <= 0:
+    if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
         logger.error(
             "Skipping token fulfillment for session %s: incomplete metadata %s",
             checkout_session_id,
@@ -268,7 +275,7 @@
             getattr(checkout_session, "payment_intent", None)
         ),
         quantity=quantity,
-        tokens_granted=quantity * tokens_per_unit,
+        credit_micros_granted=quantity * credit_micros_per_unit,
         amount_total=getattr(checkout_session, "amount_total", None),
         currency=getattr(checkout_session, "currency", None),
         status=PremiumTokenPurchaseStatus.PENDING,
@@ -303,9 +310,14 @@
     purchase.stripe_payment_intent_id = _normalize_optional_string(
         getattr(checkout_session, "payment_intent", None)
     )
-    user.premium_tokens_limit = (
-        max(user.premium_tokens_used, user.premium_tokens_limit)
-        + purchase.tokens_granted
+    # Top up the user's credit balance by the granted micro-USD amount.
+    # ``max(used, limit)`` clamps the case where the legacy code wrote a
+    # used value above the limit (e.g. underbilling rounding), so adding
+    # ``credit_micros_granted`` always lifts the limit by the full pack
+    # size rather than disappearing into past overuse.
+    user.premium_credit_micros_limit = (
+        max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+        + purchase.credit_micros_granted
     )

     await db_session.commit()
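A quick worked example of why the top-up uses ``max(used, limit)`` (values illustrative):

# User previously overran their limit by 200_000 micros:
used, limit = 5_200_000, 5_000_000
granted = 1_000_000  # one $1.00 pack

# naive: limit + granted = 6_000_000, remaining = 800_000 (pack shortchanged)
# clamped: max(used, limit) + granted = 6_200_000, remaining = exactly 1_000_000
new_limit = max(used, limit) + granted
assert new_limit - used == granted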
@@ -532,12 +544,18 @@
     user: User = Depends(current_active_user),
     db_session: AsyncSession = Depends(get_async_session),
 ):
-    """Create a Stripe Checkout Session for buying premium token packs."""
+    """Create a Stripe Checkout Session for buying premium credit packs.
+
+    Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
+    credit (default 1_000_000 = $1.00). The user's balance is debited
+    at the actual provider cost reported by LiteLLM at finalize time,
+    so $1 of credit always buys $1 worth of provider usage at cost.
+    """
     _ensure_token_buying_enabled()
     stripe_client = get_stripe_client()
     price_id = _get_required_token_price_id()
     success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
-    tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
+    credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT

     try:
         checkout_session = stripe_client.v1.checkout.sessions.create(
@@ -556,8 +574,8 @@
                 "metadata": {
                     "user_id": str(user.id),
                     "quantity": str(body.quantity),
-                    "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
-                    "purchase_type": "premium_tokens",
+                    "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
+                    "purchase_type": "premium_credit",
                 },
             }
         )
@@ -583,7 +601,7 @@
             getattr(checkout_session, "payment_intent", None)
         ),
         quantity=body.quantity,
-        tokens_granted=tokens_granted,
+        credit_micros_granted=credit_micros_granted,
         amount_total=getattr(checkout_session, "amount_total", None),
         currency=getattr(checkout_session, "currency", None),
         status=PremiumTokenPurchaseStatus.PENDING,
@@ -598,14 +616,19 @@
 async def get_token_status(
     user: User = Depends(current_active_user),
 ):
-    """Return token-buying availability and current premium quota for frontend."""
-    used = user.premium_tokens_used
-    limit = user.premium_tokens_limit
+    """Return token-buying availability and current premium credit quota for frontend.
+
+    Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
+    when displaying. The route name is preserved for back-compat with
+    pinned client deployments.
+    """
+    used = user.premium_credit_micros_used
+    limit = user.premium_credit_micros_limit
     return TokenStripeStatusResponse(
         token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
-        premium_tokens_used=used,
-        premium_tokens_limit=limit,
-        premium_tokens_remaining=max(0, limit - used),
+        premium_credit_micros_used=used,
+        premium_credit_micros_limit=limit,
+        premium_credit_micros_remaining=max(0, limit - used),
     )


@@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
                 "litellm_params": {},
                 "is_global": True,
                 "is_auto_mode": True,
+                # Auto mode treated as free until per-deployment billing-tier
+                # surfacing lands; see ``get_vision_llm`` for parity.
+                "billing_tier": "free",
             }
         )
@@ -98,6 +101,10 @@
                 "api_version": cfg.get("api_version") or None,
                 "litellm_params": cfg.get("litellm_params", {}),
                 "is_global": True,
+                "billing_tier": cfg.get("billing_tier", "free"),
+                "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+                "input_cost_per_token": cfg.get("input_cost_per_token"),
+                "output_cost_per_token": cfg.get("output_cost_per_token"),
             }
         )


@@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
     Schema for reading global image generation configs from YAML.
     Global configs have negative IDs. API key is hidden.
     ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
+
+    The ``billing_tier`` field allows the frontend to show a Premium/Free
+    badge and (more importantly) tells the backend whether to debit the
+    user's premium credit pool when this config is used. ``"free"`` is
+    the default for backward compatibility; admins must explicitly opt
+    a global config into ``"premium"``.
     """

     id: int = Field(
@@ -231,3 +237,15 @@
     litellm_params: dict[str, Any] | None = None
     is_global: bool = True
     is_auto_mode: bool = False
+    billing_tier: str = Field(
+        default="free",
+        description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+    )
+    quota_reserve_micros: int | None = Field(
+        default=None,
+        description=(
+            "Optional override for the reservation amount (in micro-USD) used when "
+            "this image generation is premium. Falls back to "
+            "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
+        ),
+    )
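For orientation, a premium global image-gen entry as the FE would receive it through this schema (plain dict; field values illustrative):

example = {
    "id": -3,                        # negative = global YAML config
    "name": "DALL-E 3 (Premium)",
    "provider": "OPENAI",
    "model_name": "dall-e-3",
    "is_global": True,
    "is_auto_mode": False,
    "billing_tier": "premium",       # FE shows a Premium badge
    "quota_reserve_micros": 80_000,  # hold $0.08 instead of the $0.05 default
}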


@@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: int = 0
     total_tokens: int = 0
+    cost_micros: int = 0
     model_breakdown: dict | None = None

     model_config = ConfigDict(from_attributes=True)


@@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):

 class TokenPurchaseRead(BaseModel):
-    """Serialized premium token purchase record."""
+    """Serialized premium credit purchase record.
+
+    ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
+    schema name keeps ``Token`` for API back-compat with pinned clients.
+    """

     id: uuid.UUID
     stripe_checkout_session_id: str
     stripe_payment_intent_id: str | None = None
     quantity: int
-    tokens_granted: int
+    credit_micros_granted: int
     amount_total: int | None = None
     currency: str | None = None
     status: str
@@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):

 class TokenPurchaseHistoryResponse(BaseModel):
-    """Response containing the user's premium token purchases."""
+    """Response containing the user's premium credit purchases."""

     purchases: list[TokenPurchaseRead]


 class TokenStripeStatusResponse(BaseModel):
-    """Response describing token-buying availability and current quota."""
+    """Response describing premium-credit-buying availability and balance.
+
+    All ``premium_credit_micros_*`` fields are in micro-USD; the FE
+    divides by 1_000_000 to display USD.
+    """

     token_buying_enabled: bool
-    premium_tokens_used: int = 0
-    premium_tokens_limit: int = 0
-    premium_tokens_remaining: int = 0
+    premium_credit_micros_used: int = 0
+    premium_credit_micros_limit: int = 0
+    premium_credit_micros_remaining: int = 0


@@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):

 class GlobalVisionLLMConfigRead(BaseModel):
+    """Schema for reading global vision LLM configs from YAML.
+
+    The ``billing_tier`` field allows the frontend to show a Premium/Free
+    badge and (more importantly) tells the backend whether to debit the
+    user's premium credit pool when this config is used. ``"free"`` is
+    the default for backward compatibility; admins must explicitly opt
+    a global config into ``"premium"``.
+    """
+
     id: int = Field(...)
     name: str
     description: str | None = None
@@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
     litellm_params: dict[str, Any] | None = None
     is_global: bool = True
     is_auto_mode: bool = False
+    billing_tier: str = Field(
+        default="free",
+        description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+    )
+    quota_reserve_tokens: int | None = Field(
+        default=None,
+        description=(
+            "Optional override for the per-call reservation in *tokens* — "
+            "converted to micro-USD via the model's input/output prices at "
+            "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
+        ),
+    )
+    input_cost_per_token: float | None = Field(
+        default=None,
+        description=(
+            "Optional input price in USD/token. Used by pricing_registration to "
+            "register custom Azure / OpenRouter aliases with LiteLLM at startup."
+        ),
+    )
+    output_cost_per_token: float | None = Field(
+        default=None,
+        description="Optional output price in USD/token. Pair with input_cost_per_token.",
+    )
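And the matching YAML-side entry for a premium vision config, sketched as the equivalent Python dict (keys mirror the fields above; values illustrative):

example_premium_vision_cfg = {
    "id": -7,
    "name": "Custom Azure Vision (Premium)",
    "provider": "AZURE_OPENAI",
    "model_name": "my-vision-deploy",
    "billing_tier": "premium",
    "quota_reserve_tokens": 2_000,     # reservation sized in tokens
    "input_cost_per_token": 0.000003,  # $3.00 / M input tokens
    "output_cost_per_token": 0.000015, # $15.00 / M output tokens
}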


@@ -0,0 +1,430 @@
"""
Per-call billable wrapper for image generation, vision LLM extraction, and
any other short-lived premium operation that must charge against the user's
shared premium credit pool.
The ``billable_call`` async context manager encapsulates the standard
"reserve → execute → finalize / release → record audit row" lifecycle in a
single primitive so callers (the image-generation REST route and the
vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issues A, B):
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
insert each run inside their own ``shielded_async_session()``. This
guarantees that a quota commit/rollback can never accidentally flush or
roll back rows the caller has staged in the request's main session
(e.g. a freshly-created ``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
nested ``billable_call`` inside an outer chat turn cannot corrupt the
chat turn's accumulator.
3. **Free configs are still audited.** Free calls bypass the reserve /
finalize dance entirely but still record a ``TokenUsage`` audit row with
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
pipeline complete for analytics even when nothing is debited.
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
responsible for translating that into HTTP 402. We *do not* catch the
denial inside ``billable_call``; letting it propagate also prevents
the image-generation route from creating an ``ImageGeneration`` row
for a request that never actually ran.
"""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import shielded_async_session
from app.services.token_quota_service import (
TokenQuotaService,
estimate_call_reserve_micros,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
record_token_usage,
scoped_turn,
)
logger = logging.getLogger(__name__)
class QuotaInsufficientError(Exception):
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
call because the user has exhausted their premium credit pool.
The route handler should catch this and return HTTP 402 Payment
Required (or the equivalent for the surface area). Outside of the HTTP
layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
callers may catch this and degrade gracefully, e.g. fall back to OCR
when vision is unavailable.
"""
def __init__(
self,
*,
usage_type: str,
used_micros: int,
limit_micros: int,
remaining_micros: int,
) -> None:
self.usage_type = usage_type
self.used_micros = used_micros
self.limit_micros = limit_micros
self.remaining_micros = remaining_micros
super().__init__(
f"Premium credit exhausted for {usage_type}: "
f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
)
@asynccontextmanager
async def billable_call(
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None = None,
quota_reserve_micros_override: int | None = None,
usage_type: str,
thread_id: int | None = None,
message_id: int | None = None,
call_details: dict[str, Any] | None = None,
) -> AsyncIterator[TurnTokenAccumulator]:
"""Wrap a single billable LLM/image call.
Args:
user_id: Owner of the credit pool to debit. For vision-LLM during
indexing this is the *search-space owner* (issue M), not the
triggering user.
search_space_id: Required; recorded on the ``TokenUsage`` audit row.
billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
the reserve/finalize dance but still records an audit row with
the captured cost.
base_model: Used by :func:`estimate_call_reserve_micros` to compute
a worst-case reservation from LiteLLM's pricing table.
quota_reserve_tokens: Optional per-config override for the chat-style
reserve estimator (vision LLM uses this).
quota_reserve_micros_override: Optional flat micro-USD reservation
(image generation uses this; its cost shape is per-image, not
per-token).
usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
Recorded on the ``TokenUsage`` row.
thread_id, message_id: Optional FK columns on ``TokenUsage``.
call_details: Optional per-call metadata (model name, parameters)
forwarded to ``record_token_usage``.
Yields:
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
the underlying LLM/image API while inside the ``async with``; the
``TokenTrackingCallback`` populates the accumulator automatically.
Raises:
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
"""
is_premium = billing_tier == "premium"
async with scoped_turn() as acc:
# ---------- Free path: just audit -------------------------------
if not is_premium:
try:
yield acc
finally:
# Always audit, even on exception, so we capture cost when
# the provider returns successfully but the caller raises later.
try:
async with shielded_async_session() as audit_session:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=acc.total_cost_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
)
await audit_session.commit()
except Exception:
logger.exception(
"[billable_call] free-path audit insert failed for "
"usage_type=%s user_id=%s",
usage_type,
user_id,
)
return
# ---------- Premium path: reserve → execute → finalize ----------
if quota_reserve_micros_override is not None:
reserve_micros = max(1, int(quota_reserve_micros_override))
else:
reserve_micros = estimate_call_reserve_micros(
base_model=base_model or "",
quota_reserve_tokens=quota_reserve_tokens,
)
request_id = str(uuid4())
async with shielded_async_session() as quota_session:
reserve_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=user_id,
request_id=request_id,
reserve_micros=reserve_micros,
)
if not reserve_result.allowed:
logger.info(
"[billable_call] reserve DENIED user=%s usage_type=%s "
"reserve=%d used=%d limit=%d remaining=%d",
user_id,
usage_type,
reserve_micros,
reserve_result.used,
reserve_result.limit,
reserve_result.remaining,
)
raise QuotaInsufficientError(
usage_type=usage_type,
used_micros=reserve_result.used,
limit_micros=reserve_result.limit,
remaining_micros=reserve_result.remaining,
)
logger.info(
"[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
"(remaining=%d)",
user_id,
usage_type,
reserve_micros,
reserve_result.remaining,
)
try:
yield acc
except BaseException:
# Release on any failure (including QuotaInsufficientError raised
# from a downstream call, asyncio cancellation, etc.). We use
# BaseException so cancellation also releases.
try:
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
reserved_micros=reserve_micros,
)
except Exception:
logger.exception(
"[billable_call] premium_release failed for user=%s "
"reserve_micros=%d (reservation will be GC'd by quota "
"reconciliation if/when implemented)",
user_id,
reserve_micros,
)
raise
# ---------- Success: finalize + audit ----------------------------
actual_micros = acc.total_cost_micros
try:
async with shielded_async_session() as quota_session:
final_result = await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=user_id,
request_id=request_id,
actual_micros=actual_micros,
reserved_micros=reserve_micros,
)
logger.info(
"[billable_call] finalize user=%s usage_type=%s actual=%d "
"reserved=%d → used=%d/%d (remaining=%d)",
user_id,
usage_type,
actual_micros,
reserve_micros,
final_result.used,
final_result.limit,
final_result.remaining,
)
except Exception:
# Last-ditch: if finalize itself fails, we must at least release
# so the reservation doesn't leak.
logger.exception(
"[billable_call] premium_finalize failed for user=%s; "
"attempting release",
user_id,
)
try:
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
reserved_micros=reserve_micros,
)
except Exception:
logger.exception(
"[billable_call] release after finalize failure ALSO failed "
"for user=%s",
user_id,
)
try:
async with shielded_async_session() as audit_session:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=actual_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
)
await audit_session.commit()
except Exception:
logger.exception(
"[billable_call] premium-path audit insert failed for "
"usage_type=%s user_id=%s (debit was applied)",
usage_type,
user_id,
)
async def _resolve_agent_billing_for_search_space(
session: AsyncSession,
search_space_id: int,
*,
thread_id: int | None = None,
) -> tuple[UUID, str, str]:
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
agent LLM.
Used by Celery tasks (podcast generation, video presentation) to bill the
search-space owner's premium credit pool when the agent LLM is premium.
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
* ``thread_id`` is set: delegate to
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
recurse into the resolved id. Reuses chat's existing pin if present
so the same model bills for chat + downstream podcast/video. If the
user is not premium-eligible, the pin service auto-restricts to free
deployments; denial only happens later in
``billable_call.premium_reserve`` if the pin really is premium and
credit ran out mid-flow.
* ``thread_id`` is None: fall back to ``("free", "auto")``. Forward-compat
for any future direct-API path; today both Celery tasks always pass
``thread_id``.
- **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
``base_model = litellm_params.get("base_model") or model_name``,
NOT provider-prefixed (matching chat's cost-map lookup convention).
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
``base_model`` from ``litellm_params`` or ``model_name``.
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
``llm_router_service`` are imported lazily inside the function body to
avoid hoisting litellm side-effects (``litellm.callbacks =
[token_tracker]``, ``litellm.drop_params``, etc.) into
``billable_calls.py``'s module load path.
"""
from sqlalchemy import select
from app.db import NewLLMConfig, SearchSpace
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if search_space is None:
raise ValueError(f"Search space {search_space_id} not found")
agent_llm_id = search_space.agent_llm_id
if agent_llm_id is None:
raise ValueError(
f"Search space {search_space_id} has no agent_llm_id configured"
)
owner_user_id: UUID = search_space.user_id
from app.services.auto_model_pin_service import (
AUTO_FASTEST_ID,
resolve_or_get_pinned_llm_config_id,
)
if agent_llm_id == AUTO_FASTEST_ID:
if thread_id is None:
return owner_user_id, "free", "auto"
try:
resolution = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=str(owner_user_id),
selected_llm_config_id=AUTO_FASTEST_ID,
)
except ValueError:
logger.warning(
"[agent_billing] Auto-mode pin resolution failed for "
"search_space=%s thread=%s; falling back to free",
search_space_id,
thread_id,
exc_info=True,
)
return owner_user_id, "free", "auto"
agent_llm_id = resolution.resolved_llm_config_id
if agent_llm_id < 0:
from app.services.llm_service import get_global_llm_config
cfg = get_global_llm_config(agent_llm_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
litellm_params = cfg.get("litellm_params") or {}
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
return owner_user_id, billing_tier, base_model
nlc_result = await session.execute(
select(NewLLMConfig).where(
NewLLMConfig.id == agent_llm_id,
NewLLMConfig.search_space_id == search_space_id,
)
)
nlc = nlc_result.scalars().first()
base_model = ""
if nlc is not None:
litellm_params = nlc.litellm_params or {}
base_model = litellm_params.get("base_model") or nlc.model_name or ""
return owner_user_id, "free", base_model
__all__ = [
"QuotaInsufficientError",
"_resolve_agent_billing_for_search_space",
"billable_call",
]
# Re-export the config knob so callers don't have to import config just for
# the default image reserve.
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
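And the caller's view of the happy path, roughly what the image-generation
route does (the provider call is a placeholder):
    async with billable_call(
        user_id=user_id,
        search_space_id=search_space_id,
        billing_tier="premium",
        base_model="",  # unused when a flat micros override is given
        quota_reserve_micros_override=DEFAULT_IMAGE_RESERVE_MICROS,
        usage_type="image_generation",
    ) as acc:
        image = await provider_generate(prompt)  # placeholder
    # On clean exit the reservation was finalized at acc.total_cost_micros
    # and a TokenUsage audit row was written; on exception it was released.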

View file

@ -134,42 +134,16 @@ PROVIDER_MAP = {
} }
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when # ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM # hoisted to ``app.services.provider_api_base`` so vision and image-gen
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``, # call sites can share the exact same defense (OpenRouter / Groq / etc.
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku`` # 404-ing against an inherited Azure endpoint). Re-exported here for
# request to an Azure endpoint, which then 404s with ``Resource not found``. # backward compatibility with any external import.
# Only providers with a well-known, stable public base URL are listed here — from app.services.provider_api_base import ( # noqa: E402
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, PROVIDER_DEFAULT_API_BASE,
# huggingface, databricks, cloudflare, replicate) are intentionally omitted PROVIDER_KEY_DEFAULT_API_BASE,
# so their existing config-driven behaviour is preserved. resolve_api_base,
PROVIDER_DEFAULT_API_BASE = { )
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
class LLMRouterService: class LLMRouterService:
@ -466,14 +440,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a # Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently # provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE`` # requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing # docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint). # against an Azure endpoint).
api_base = config.get("api_base") api_base = resolve_api_base(
if not api_base: provider=provider,
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider) provider_prefix=provider_prefix,
if not api_base: config_api_base=config.get("api_base"),
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix) )
if api_base: if api_base:
litellm_params["api_base"] = api_base litellm_params["api_base"] = api_base

View file

@ -496,8 +496,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService - Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs - Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table - DB (positive ID): VisionLLMConfig table
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
so each ``ainvoke`` debits the search-space owner's premium credit
pool. User-owned BYOK configs and free global configs are returned
unwrapped they don't consume premium credit (issue M).
""" """
from app.db import VisionLLMConfig from app.db import VisionLLMConfig
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import ( from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP, VISION_PROVIDER_MAP,
VisionLLMRouterService, VisionLLMRouterService,
@ -519,6 +525,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}") logger.error(f"No vision LLM configured for search space {search_space_id}")
return None return None
owner_user_id = search_space.user_id
if is_vision_auto_mode(config_id): if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized(): if not VisionLLMRouterService.is_initialized():
logger.error( logger.error(
@ -526,6 +534,13 @@ async def get_vision_llm(
) )
return None return None
try: try:
# Auto mode is currently treated as free at the wrapper
# level — the underlying router can dispatch to either
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter( return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(), router=VisionLLMRouterService.get_router(),
streaming=True, streaming=True,
@ -562,8 +577,21 @@ async def get_vision_llm(
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs) inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
if billing_tier == "premium":
return QuotaCheckedVisionLLM(
inner_llm,
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=model_string,
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
)
return inner_llm
# User-owned (positive ID) BYOK configs — always free.
result = await session.execute( result = await session.execute(
select(VisionLLMConfig).where( select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id, VisionLLMConfig.id == config_id,

View file

@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"] return output_mods == ["text"]
def _is_image_output_model(model: dict) -> bool:
"""Return True if the model can produce image output.
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
``["image"]`` for pure image generators, ``["text", "image"]`` for
multi-modal generators that also emit captions). We accept any model
that can output images; the call site decides whether to use the
image-generation API or chat completion.
"""
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
return "image" in output_mods
def _is_vision_input_model(model: dict) -> bool:
"""Return True if the model can ingest an image AND emit text.
OpenRouter's ``architecture.input_modalities`` lists what the model
accepts; ``output_modalities`` lists what it produces. A vision LLM
is a model that takes images in and produces text out i.e. it can
answer questions about a screenshot or extract content from an
image. Pure image-to-image models (e.g. style transfer) and
text-only models are excluded.
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
output_mods = arch.get("output_modalities", []) or []
return "image" in input_mods and "text" in output_mods
def _supports_tool_calling(model: dict) -> bool: def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling.""" """Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or [] supported = model.get("supported_parameters") or []
@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
return None return None
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
Pricing values are kept as the raw OpenRouter strings (e.g.
``"0.000003"``); ``pricing_registration`` converts them to floats
when registering with LiteLLM. Models with missing or malformed
pricing are simply omitted; an operator-side risk if any of those are
premium.
"""
pricing: dict[str, dict[str, str]] = {}
for model in raw_models:
model_id = str(model.get("id") or "").strip()
if not model_id:
continue
p = model.get("pricing") or {}
prompt = p.get("prompt")
completion = p.get("completion")
if prompt is None and completion is None:
continue
pricing[model_id] = {
"prompt": str(prompt) if prompt is not None else "",
"completion": str(completion) if completion is not None else "",
}
return pricing
def _generate_configs( def _generate_configs(
raw_models: list[dict], raw_models: list[dict],
settings: dict[str, Any], settings: dict[str, Any],
@ -282,6 +337,162 @@ def _generate_configs(
return configs return configs
# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
def _generate_image_gen_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter image-generation models into global image-gen
config dicts (matches the YAML shape consumed by ``image_generation_routes``).
Filter:
- architecture.output_modalities contains "image"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
``_has_sufficient_context``) because tool calls and context windows are
irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
derived per model the same way as chat (``_openrouter_tier``).
Cost is intentionally *not* registered with LiteLLM at startup
(``pricing_registration`` skips image gen): OpenRouter image-gen
models are not in LiteLLM's native cost map and OpenRouter populates
``response_cost`` directly from the response header. A defensive
branch in ``_extract_cost_usd`` handles the rare case where
``usage.cost`` is missing; see ``token_tracking_service``.
"""
id_offset: int = int(
settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
free_rpm: int = settings.get("free_rpm", 20)
litellm_params: dict = settings.get("litellm_params") or {}
image_models = [
m
for m in raw_models
if _is_image_output_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in image_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (image generation)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
def _generate_vision_llm_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
Filter:
- architecture.input_modalities contains "image"
- architecture.output_modalities contains "text"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Vision-LLM is invoked from the indexer (image extraction during
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
filters do not apply: a small-context vision model that doesn't
advertise tool-calling is still perfectly viable for "describe this
image" prompts.
"""
id_offset: int = int(
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1_000_000)
free_rpm: int = settings.get("free_rpm", 20)
free_tpm: int = settings.get("free_tpm", 100_000)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
litellm_params: dict = settings.get("litellm_params") or {}
vision_models = [
m
for m in raw_models
if _is_vision_input_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in vision_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
pricing = model.get("pricing") or {}
# Capture per-token prices so ``pricing_registration`` can
# register them with LiteLLM at startup (and so the cost
# estimator in ``estimate_call_reserve_micros`` can resolve
# them at reserve time).
try:
input_cost = float(pricing.get("prompt", 0) or 0)
except (TypeError, ValueError):
input_cost = 0.0
try:
output_cost = float(pricing.get("completion", 0) or 0)
except (TypeError, ValueError):
output_cost = 0.0
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (vision)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
"quota_reserve_tokens": quota_reserve_tokens,
"input_cost_per_token": input_cost or None,
"output_cost_per_token": output_cost or None,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
class OpenRouterIntegrationService: class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue.""" """Singleton that manages the dynamic OpenRouter model catalogue."""
@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
# Shape: {model_name: {"gated": bool, "score": float | None}} # Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {} self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None self._enrich_task: asyncio.Task | None = None
# Raw OpenRouter pricing per model_id, captured at the same time
# we generate configs. Consumed by ``pricing_registration`` to
# teach LiteLLM the per-token cost of every dynamic deployment so
# the success-callback can populate ``response_cost`` correctly.
self._raw_pricing: dict[str, dict[str, str]] = {}
# Cached raw catalogue from the most recent fetch. Image / vision
# emitters reuse this to avoid a second network call per surface.
self._raw_models: list[dict] = []
# Image / vision config caches (only populated when the matching
# opt-in flag is true on initialize). Refreshed in lockstep with
# the chat catalogue.
self._image_configs: list[dict] = []
self._vision_configs: list[dict] = []
@classmethod @classmethod
def get_instance(cls) -> "OpenRouterIntegrationService": def get_instance(cls) -> "OpenRouterIntegrationService":
@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
self._initialized = True self._initialized = True
return [] return []
self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings) self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs} self._configs_by_id = {c["id"]: c for c in self._configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
# Populate image / vision caches when their opt-in flag is set.
# Empty otherwise so the accessors return [] without re-running
# filters every refresh.
if settings.get("image_generation_enabled"):
self._image_configs = _generate_image_gen_configs(raw_models, settings)
logger.info(
"OpenRouter integration: image-gen emission ON (%d models)",
len(self._image_configs),
)
else:
self._image_configs = []
if settings.get("vision_enabled"):
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
logger.info(
"OpenRouter integration: vision LLM emission ON (%d models)",
len(self._vision_configs),
)
else:
self._vision_configs = []
self._initialized = True self._initialized = True
tier_counts = self._tier_counts(self._configs) tier_counts = self._tier_counts(self._configs)
@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings) new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs} new_by_id = {c["id"]: c for c in new_configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
self._raw_models = raw_models
from app.config import config as app_config from app.config import config as app_config
@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
self._configs = new_configs self._configs = new_configs
self._configs_by_id = new_by_id self._configs_by_id = new_by_id
# Image / vision lists are atomic-swapped the same way: filter out
# the previous dynamic entries from the live config list and append
# the freshly generated ones. No-ops when the opt-in flag is off.
if self._settings.get("image_generation_enabled"):
new_image = _generate_image_gen_configs(raw_models, self._settings)
static_image = [
c
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
self._image_configs = new_image
if self._settings.get("vision_enabled"):
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
static_vision = [
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
self._vision_configs = new_vision
# Catalogue churn invalidates per-config "recently healthy" credit # Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so # earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs. # the next turn re-probes against the freshly loaded configs.
@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
# so a hand-picked dead OR model is gated like a dynamic one. # so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True) await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
# Re-register LiteLLM pricing for the freshly fetched catalogue
# so newly added OR models bill correctly on their first call.
# Runs before the router rebuild because the router may issue
# cost-table lookups during deployment registration.
try:
from app.services.pricing_registration import (
register_pricing_from_global_configs,
)
register_pricing_from_global_configs()
except Exception as exc:
logger.warning(
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
)
# Rebuild the LiteLLM router so freshly fetched configs flow through # Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay # (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and # out; a refresh also needs to pick up any static-config edits and
@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None: def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id) return self._configs_by_id.get(config_id)
def get_image_generation_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter image-generation configs (empty
list when the ``image_generation_enabled`` flag is off).
Each entry already has ``billing_tier`` derived per-model from
OpenRouter's signals and is shaped to drop directly into
``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
"""
return list(self._image_configs)
def get_vision_llm_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter vision-LLM configs (empty list
when the ``vision_enabled`` flag is off).
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
so ``pricing_registration`` can teach LiteLLM the cost of these
models the same way it does for chat, which keeps the billable
wrapper able to debit accurate micro-USD on a vision call.
"""
return list(self._vision_configs)
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
"""Return the cached raw OpenRouter pricing map.
Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
values are the strings OpenRouter publishes (USD per token),
never converted to floats here so the caller can decide how to
handle malformed or unset entries.
"""
return dict(self._raw_pricing)

View file

@ -0,0 +1,274 @@
"""
Pricing registration with LiteLLM.
Many models reach our LiteLLM callback without LiteLLM knowing their
per-token cost, namely:
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
pricing table).
* Static YAML deployments whose ``base_model`` name is operator-defined
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
not in LiteLLM's table either.
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
and the user gets billed nothing; a fail-safe but wrong answer for a
cost-based credit system. This module runs once at startup, after the
OpenRouter integration has fetched its catalogue, and registers each
known model's pricing with ``litellm.register_model()`` under multiple
plausible alias keys (LiteLLM's cost lookup may use any of them
depending on whether the call went through the Router, ChatLiteLLM,
or a direct ``acompletion``).
Operators who run a custom Azure deployment whose ``base_model`` name
isn't in LiteLLM's table can declare per-token pricing inline in
``global_llm_config.yaml`` via ``input_cost_per_token`` and
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
that declaration the model's calls debit 0 — never overbilled.
"""
from __future__ import annotations
import logging
from typing import Any
import litellm
logger = logging.getLogger(__name__)
def _safe_float(value: Any) -> float:
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
if value is None:
return 0.0
try:
f = float(value)
except (TypeError, ValueError):
return 0.0
return f if f > 0 else 0.0
def _alias_set_for_openrouter(model_id: str) -> list[str]:
"""Return the alias keys to register an OpenRouter model under.
LiteLLM's cost-callback lookup key varies by call path:
- Router with ``model="openrouter/X"`` kwargs["model"] is
typically ``openrouter/X``.
- LiteLLM's own provider routing may strip the prefix and pass the
bare ``X`` to the cost-table lookup.
Registering under both keeps the lookup hermetic regardless of
which path the call took.
"""
aliases = [f"openrouter/{model_id}", model_id]
return list(dict.fromkeys(a for a in aliases if a))
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
"""Return the alias keys to register a static YAML deployment under.
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
bare ``model_name`` because callbacks sometimes see whichever was
configured first.
"""
provider_lower = (provider or "").lower()
aliases: list[str] = []
if base_model:
aliases.append(base_model)
if provider_lower and base_model:
aliases.append(f"{provider_lower}/{base_model}")
if model_name and model_name != base_model:
aliases.append(model_name)
if provider_lower and model_name and model_name != base_model:
aliases.append(f"{provider_lower}/{model_name}")
# Azure deployments often surface as "azure/<name>"; normalise the
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
if provider_lower == "azure_openai":
if base_model:
aliases.append(f"azure/{base_model}")
if model_name and model_name != base_model:
aliases.append(f"azure/{model_name}")
return list(dict.fromkeys(a for a in aliases if a))
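Worked examples, traced from the two builders above:
    _alias_set_for_openrouter("anthropic/claude-3-haiku")
    # -> ["openrouter/anthropic/claude-3-haiku", "anthropic/claude-3-haiku"]

    _alias_set_for_yaml("AZURE_OPENAI", "gpt-5.4", "gpt-5.4")
    # -> ["gpt-5.4", "azure_openai/gpt-5.4", "azure/gpt-5.4"]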
def _register(
aliases: list[str],
*,
input_cost: float,
output_cost: float,
provider: str,
mode: str = "chat",
) -> int:
"""Register a single pricing entry under every alias in ``aliases``.
Returns the count of aliases successfully registered.
"""
payload: dict[str, dict[str, Any]] = {}
for alias in aliases:
payload[alias] = {
"input_cost_per_token": input_cost,
"output_cost_per_token": output_cost,
"litellm_provider": provider,
"mode": mode,
}
if not payload:
return 0
try:
litellm.register_model(payload)
except Exception as exc:
logger.warning(
"[PricingRegistration] register_model failed for aliases=%s: %s",
aliases,
exc,
)
return 0
return len(payload)
def _register_chat_shape_configs(
configs: list[dict],
*,
or_pricing: dict[str, dict[str, str]],
label: str,
) -> tuple[int, int, int, list[str]]:
"""Common loop that registers per-token pricing for a list of "chat-shape"
configs (chat or vision LLM both use ``input_cost_per_token`` /
``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
"""
registered_models = 0
registered_aliases = 0
skipped_no_pricing = 0
sample_keys: list[str] = []
for cfg in configs:
provider = str(cfg.get("provider") or "").upper()
model_name = str(cfg.get("model_name") or "").strip()
litellm_params = cfg.get("litellm_params") or {}
base_model = str(litellm_params.get("base_model") or model_name).strip()
if provider == "OPENROUTER":
entry = or_pricing.get(model_name)
if entry:
input_cost = _safe_float(entry.get("prompt"))
output_cost = _safe_float(entry.get("completion"))
else:
# Vision configs from ``_generate_vision_llm_configs``
# carry their pricing inline because the OpenRouter
# raw-pricing cache is keyed by chat-catalogue model_id;
# vision flows pick up the inline values here.
input_cost = _safe_float(cfg.get("input_cost_per_token"))
output_cost = _safe_float(cfg.get("output_cost_per_token"))
if input_cost == 0.0 and output_cost == 0.0:
skipped_no_pricing += 1
continue
aliases = _alias_set_for_openrouter(model_name)
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider="openrouter",
)
if count > 0:
registered_models += 1
registered_aliases += count
if len(sample_keys) < 6:
sample_keys.extend(aliases[:2])
continue
input_cost = _safe_float(
cfg.get("input_cost_per_token")
or litellm_params.get("input_cost_per_token")
)
output_cost = _safe_float(
cfg.get("output_cost_per_token")
or litellm_params.get("output_cost_per_token")
)
if input_cost == 0.0 and output_cost == 0.0:
skipped_no_pricing += 1
continue
aliases = _alias_set_for_yaml(provider, model_name, base_model)
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider=provider_slug,
)
if count > 0:
registered_models += 1
registered_aliases += count
if len(sample_keys) < 6:
sample_keys.extend(aliases[:2])
logger.info(
"[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
"%d configs had no pricing data; sample registered keys=%s",
label,
registered_models,
registered_aliases,
skipped_no_pricing,
sample_keys,
)
return registered_models, registered_aliases, skipped_no_pricing, sample_keys
def register_pricing_from_global_configs() -> None:
"""Register pricing for every known LLM deployment with LiteLLM.
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
so vision calls (during indexing) can resolve cost the same way chat
calls do, namely:
1. ``OPENROUTER``: pulls the cached raw pricing from
``OpenRouterIntegrationService`` (populated during its own
startup fetch) and converts the per-token strings to floats. For
vision configs that carry pricing inline (``input_cost_per_token`` /
``output_cost_per_token`` set on the cfg itself) we fall back to
those values when the OR cache misses the model.
2. Anything else: looks for operator-declared
``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
config block (top-level or nested under ``litellm_params``).
**Image generation is intentionally NOT registered here.** The cost
shape for image-gen is per-image (``output_cost_per_image``), not
per-token, and LiteLLM's ``register_model`` doesn't accept those
keys via the chat-cost path. OpenRouter image-gen models populate
``response_cost`` directly from their response header instead, and
Azure-native image-gen models are already in LiteLLM's cost map.
Calls without a resolved pair of costs are skipped, not registered
with zeros; operators who forget pricing get a "$0 debit" warning
in ``TokenTrackingCallback`` rather than silently overwriting any
pricing LiteLLM might know natively.
"""
from app.config import config as app_config
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
vision_configs: list[dict] = list(
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
)
if not chat_configs and not vision_configs:
logger.info("[PricingRegistration] no global configs to register")
return
or_pricing: dict[str, dict[str, str]] = {}
try:
from app.services.openrouter_integration_service import (
OpenRouterIntegrationService,
)
if OpenRouterIntegrationService.is_initialized():
or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
except Exception as exc:
logger.debug(
"[PricingRegistration] OpenRouter pricing not available yet: %s", exc
)
if chat_configs:
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
if vision_configs:
_register_chat_shape_configs(
vision_configs, or_pricing=or_pricing, label="vision"
)
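The intended startup ordering, sketched (the hook name is illustrative and
the integration's ``initialize`` signature is abbreviated):
    async def on_startup():
        # 1. Fetch the OpenRouter catalogue so get_raw_pricing() is warm.
        await OpenRouterIntegrationService.get_instance().initialize(...)
        # 2. Then teach LiteLLM every known per-token price. Idempotent,
        #    so the refresh loop re-runs it after each catalogue swap.
        register_pricing_from_global_configs()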

View file

@ -0,0 +1,107 @@
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
LiteLLM falls back to the module-global ``litellm.api_base`` when an
individual call doesn't pass one, which silently inherits provider-agnostic
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
``/chat/completions`` to whatever inherited base it gets, regardless of
provider).
The chat router has had this defense for a while
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
into a tiny standalone helper so vision and image-gen can share the same
source of truth without an inter-service circular import.
"""
from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
Only providers with a well-known, stable public base URL are listed;
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
huggingface, databricks, cloudflare, replicate) are intentionally omitted
so their existing config-driven behaviour is preserved."""
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
"""Canonical provider key (uppercase) → base URL.
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
config's ``provider`` field tells us which API it actually is (DeepSeek,
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
has its own base URL)."""
def resolve_api_base(
*,
provider: str | None,
provider_prefix: str | None,
config_api_base: str | None,
) -> str | None:
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
Cascade (first non-empty wins):
1. The config's own ``api_base`` (whitespace-only treated as missing).
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
4. ``None``; the caller should NOT set ``api_base`` and let the LiteLLM
provider integration apply its own default (e.g. AzureOpenAI's
deployment-derived URL, custom provider's per-deployment URL).
Args:
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
``"DEEPSEEK"``). Case-insensitive.
provider_prefix: The LiteLLM model-string prefix that the same call
site builds for the model id (e.g. ``"openrouter"``,
``"groq"``). Case-insensitive.
config_api_base: ``api_base`` from the global YAML / DB row /
OpenRouter dynamic config. Empty / whitespace-only means
"missing" the resolver still applies the cascade.
Returns:
A URL string, or ``None`` if no default applies for this provider.
"""
if config_api_base and config_api_base.strip():
return config_api_base
if provider:
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
if key_default:
return key_default
if provider_prefix:
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
if prefix_default:
return prefix_default
return None
__all__ = [
"PROVIDER_DEFAULT_API_BASE",
"PROVIDER_KEY_DEFAULT_API_BASE",
"resolve_api_base",
]
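Tracing the cascade with a few representative inputs:
    resolve_api_base(provider="OPENROUTER", provider_prefix="openrouter",
                     config_api_base=None)
    # -> "https://openrouter.ai/api/v1"   (step 3: prefix default)

    resolve_api_base(provider="DEEPSEEK", provider_prefix="openai",
                     config_api_base="  ")
    # -> "https://api.deepseek.com/v1"    (step 2: whitespace-only treated
    #                                      as missing, provider-key default)

    resolve_api_base(provider="OLLAMA", provider_prefix="ollama",
                     config_api_base=None)
    # -> None                             (step 4: leave it to LiteLLM)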

View file

@ -0,0 +1,105 @@
"""
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
indexing pipeline (file processors, connector indexers, etl pipeline) can
keep invoking the LLM exactly the way they do today, ``await llm.ainvoke(...)``,
without threading ``user_id`` through every parser. The wrapper looks like
a chat model from the outside; on the inside it routes each call through
``billable_call`` so the user's premium credit pool is reserved → finalized
or released, and a ``TokenUsage`` audit row is written.
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
need quota enforcement) so this class only ever wraps premium configs.
Why a wrapper instead of plumbing ``user_id`` through every caller:
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
Dropbox, local-folder, file-processor, ETL pipeline) each calling
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
invasive, error-prone, and easy for a future indexer to forget.
* Per the design (issue M), we always debit the *search-space owner*, not
the triggering user, so ``user_id`` is fully derivable from the search
space the caller is already operating on. The wrapper captures it once
at construction time.
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
call, run this coroutine"; subclassing isn't safe across versions because
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
Composition via attribute proxying (``__getattr__``) is robust to
upstream changes; every method other than ``ainvoke`` falls through to
the inner LLM unchanged.
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from app.services.billable_calls import QuotaInsufficientError, billable_call
logger = logging.getLogger(__name__)
class QuotaCheckedVisionLLM:
"""Composition wrapper around a langchain chat model that enforces
premium credit quota on every ``ainvoke``.
Anything other than ``ainvoke`` is forwarded to the inner model so
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
still work; they simply bypass quota enforcement, which is fine
because the indexing pipeline only ever calls ``ainvoke`` today.
"""
def __init__(
self,
inner_llm: Any,
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None,
usage_type: str = "vision_extraction",
) -> None:
self._inner = inner_llm
self._user_id = user_id
self._search_space_id = search_space_id
self._billing_tier = billing_tier
self._base_model = base_model
self._quota_reserve_tokens = quota_reserve_tokens
self._usage_type = usage_type
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
"""Proxied async invoke that runs the underlying call inside
``billable_call``.
Raises:
QuotaInsufficientError: when the user has exhausted their
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
catches this and falls back to the document parser.
"""
async with billable_call(
user_id=self._user_id,
search_space_id=self._search_space_id,
billing_tier=self._billing_tier,
base_model=self._base_model,
quota_reserve_tokens=self._quota_reserve_tokens,
usage_type=self._usage_type,
call_details={"model": self._base_model},
):
return await self._inner.ainvoke(input, *args, **kwargs)
def __getattr__(self, name: str) -> Any:
"""Forward everything else (``invoke``, ``astream``, ``bind``,
``with_structured_output``, ...) to the inner model.
``__getattr__`` is only consulted when the attribute is *not*
already found on the proxy, which is exactly the contract we
want; methods we override stay on the proxy, the rest fall
through.
"""
return getattr(self._inner, name)
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
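A sketch of the graceful-degradation path the docstring describes, roughly
what the indexing caller does (the OCR fallback is a placeholder and the
``get_vision_llm`` argument list is abbreviated):
    llm = await get_vision_llm(session, search_space_id=search_space_id)
    try:
        description = await llm.ainvoke(messages)  # quota-checked if premium
    except QuotaInsufficientError:
        # Premium credit exhausted mid-indexing: fall back to the non-LLM
        # parser instead of failing the whole document.
        description = run_ocr_fallback(image_bytes)  # placeholder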

View file

@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Per-call reservation estimator (USD micro-units)
# ---------------------------------------------------------------------------
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
# request, and so models without registered pricing reserve at least
# something while the call runs (they are debited 0 at finalize anyway
# when their cost can't be resolved).
_QUOTA_MIN_RESERVE_MICROS = 100
def estimate_call_reserve_micros(
*,
base_model: str,
quota_reserve_tokens: int | None,
) -> int:
"""Return the number of micro-USD to reserve for one premium call.
Computes a worst-case upper bound from LiteLLM's per-token pricing
table:
reserve_usd ≈ reserve_tokens × (input_cost + output_cost)
so the math scales with model cost: Claude Opus + 4K reserve_tokens
naturally reserves $0.36, while a cheap model reserves only a few
cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
so a misconfigured "$1000/M" model can't lock the whole balance on
one call.
If ``litellm.get_model_info`` raises (model unknown) we fall back to
the floor of 100 micros ($0.0001), which is enough to gate a sane
request without over-reserving for a model whose pricing the
operator hasn't declared yet.
"""
reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
if reserve_tokens <= 0:
reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
try:
from litellm import get_model_info
info = get_model_info(base_model) if base_model else {}
input_cost = float(info.get("input_cost_per_token") or 0.0)
output_cost = float(info.get("output_cost_per_token") or 0.0)
except Exception as exc:
logger.debug(
"[quota_reserve] cost lookup failed for base_model=%s: %s",
base_model,
exc,
)
input_cost = 0.0
output_cost = 0.0
if input_cost == 0.0 and output_cost == 0.0:
return _QUOTA_MIN_RESERVE_MICROS
reserve_usd = reserve_tokens * (input_cost + output_cost)
reserve_micros = round(reserve_usd * 1_000_000)
if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
reserve_micros = _QUOTA_MIN_RESERVE_MICROS
if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
return reserve_micros
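Plugging in the docstring's Claude Opus example (prices assumed at $15/M
input and $75/M output tokens):
    reserve_usd = 4000 * (0.000015 + 0.000075)       # 0.36 USD
    reserve_micros = round(reserve_usd * 1_000_000)  # 360_000 micros
    # 360_000 <= QUOTA_MAX_RESERVE_MICROS (1_000_000 default): no clamp.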
class QuotaScope(StrEnum): class QuotaScope(StrEnum):
ANONYMOUS = "anonymous" ANONYMOUS = "anonymous"
PREMIUM = "premium" PREMIUM = "premium"
@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession, db_session: AsyncSession,
user_id: Any, user_id: Any,
request_id: str, request_id: str,
reserve_tokens: int, reserve_micros: int,
) -> QuotaResult: ) -> QuotaResult:
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
premium credit balance.
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
all in micro-USD on this code path; callers (chat stream,
token-status route, FE display) convert to dollars by dividing
by 1_000_000.
"""
from app.db import User from app.db import User
user = ( user = (
@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0, limit=0,
) )
limit = user.premium_tokens_limit limit = user.premium_credit_micros_limit
used = user.premium_tokens_used used = user.premium_credit_micros_used
reserved = user.premium_tokens_reserved reserved = user.premium_credit_micros_reserved
effective = used + reserved + reserve_tokens effective = used + reserved + reserve_micros
if effective > limit: if effective > limit:
remaining = max(0, limit - used - reserved) remaining = max(0, limit - used - reserved)
await db_session.rollback() await db_session.rollback()
@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining, remaining=remaining,
) )
user.premium_tokens_reserved = reserved + reserve_tokens user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit() await db_session.commit()
new_reserved = reserved + reserve_tokens new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved) remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8) warning_threshold = int(limit * 0.8)
@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession, db_session: AsyncSession,
user_id: Any, user_id: Any,
request_id: str, request_id: str,
actual_tokens: int, actual_micros: int,
reserved_tokens: int, reserved_micros: int,
) -> QuotaResult: ) -> QuotaResult:
"""Settle the reservation: release ``reserved_micros`` and debit
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
"""
from app.db import User from app.db import User
user = ( user = (
@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
) )
user.premium_tokens_reserved = max( user.premium_credit_micros_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens 0, user.premium_credit_micros_reserved - reserved_micros
)
user.premium_credit_micros_used = (
user.premium_credit_micros_used + actual_micros
) )
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit() await db_session.commit()
limit = user.premium_tokens_limit limit = user.premium_credit_micros_limit
used = user.premium_tokens_used used = user.premium_credit_micros_used
reserved = user.premium_tokens_reserved reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved) remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8) warning_threshold = int(limit * 0.8)
@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release( async def premium_release(
db_session: AsyncSession, db_session: AsyncSession,
user_id: Any, user_id: Any,
reserved_tokens: int, reserved_micros: int,
) -> None: ) -> None:
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
Used when a request fails before finalize (so the reservation
doesn't leak credit).
"""
from app.db import User from app.db import User
user = ( user = (
@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none() .scalar_one_or_none()
) )
if user is not None: if user is not None:
user.premium_tokens_reserved = max( user.premium_credit_micros_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens 0, user.premium_credit_micros_reserved - reserved_micros
) )
await db_session.commit() await db_session.commit()
@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
) )
limit = user.premium_tokens_limit limit = user.premium_credit_micros_limit
used = user.premium_tokens_used used = user.premium_credit_micros_used
reserved = user.premium_tokens_reserved reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved) remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8) warning_threshold = int(limit * 0.8)
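End to end, the micro-USD bookkeeping for one premium call, with
illustrative numbers:
    # limit = 5_000_000 ($5 default), used = 0, reserved = 0
    # premium_reserve(reserve_micros=360_000)    -> reserved = 360_000
    # ...call runs; LiteLLM reports an actual cost of $0.042...
    # premium_finalize(actual_micros=42_000,
    #                  reserved_micros=360_000)  -> reserved = 0,
    #                                               used = 42_000,
    #                                               remaining = 4_958_000
    # On failure instead: premium_release(360_000) -> reserved = 0,
    # nothing debited.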

View file

@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar from contextvars import ContextVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
import litellm
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int prompt_tokens: int
completion_tokens: int completion_tokens: int
total_tokens: int total_tokens: int
cost_micros: int = 0
call_kind: str = "chat"
@dataclass @dataclass
@@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int, prompt_tokens: int,
completion_tokens: int, completion_tokens: int,
total_tokens: int, total_tokens: int,
cost_micros: int = 0,
call_kind: str = "chat",
) -> None: ) -> None:
self.calls.append( self.calls.append(
TokenCallRecord( TokenCallRecord(
@@ -56,20 +63,28 @@
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
) )
) )
def per_message_summary(self) -> dict[str, dict[str, int]]: def per_message_summary(self) -> dict[str, dict[str, int]]:
"""Return token counts grouped by model name.""" """Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {} by_model: dict[str, dict[str, int]] = {}
for c in self.calls: for c in self.calls:
entry = by_model.setdefault( entry = by_model.setdefault(
c.model, c.model,
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost_micros": 0,
},
) )
entry["prompt_tokens"] += c.prompt_tokens entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens entry["total_tokens"] += c.total_tokens
entry["cost_micros"] += c.cost_micros
return by_model return by_model
@property @property
@@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int: def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls) return sum(c.completion_tokens for c in self.calls)
@property
def total_cost_micros(self) -> int:
"""Sum of per-call ``cost_micros`` across the entire turn.
Used by ``stream_new_chat`` to debit a premium turn's actual
provider cost (in micro-USD) from the user's premium credit
balance. ``cost_micros`` per call is captured by
``TokenTrackingCallback.async_log_success_event`` from
``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
with multiple fallback paths so OpenRouter dynamic models and
custom Azure deployments still bill correctly, provided our
``pricing_registration`` ran at startup.
"""
return sum(c.cost_micros for c in self.calls)
def serialized_calls(self) -> list[dict[str, Any]]: def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls] return [dataclasses.asdict(c) for c in self.calls]
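To make the aggregation above concrete, here is a minimal sketch of feeding the accumulator by hand (in production the LiteLLM callback calls ``add``; model names and numbers below are illustrative):

    acc = TurnTokenAccumulator()
    acc.add(model="gpt-5.4", prompt_tokens=900, completion_tokens=300,
            total_tokens=1200, cost_micros=4_100)
    acc.add(model="openrouter/free-model", prompt_tokens=500,
            completion_tokens=200, total_tokens=700, cost_micros=0)

    assert acc.total_cost_micros == 4_100  # only the paid call contributes
    assert acc.per_message_summary()["gpt-5.4"]["cost_micros"] == 4_100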
@@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator: def start_turn() -> TurnTokenAccumulator:
"""Create a fresh accumulator for the current async context and return it.""" """Create a fresh accumulator for the current async context and return it.
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
short-lived per-call billable wrappers (image generation REST endpoint,
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
ContextVar reset token to restore the *previous* accumulator on exit and
avoids leaking call records across reservations (issue B).
"""
acc = TurnTokenAccumulator() acc = TurnTokenAccumulator()
_turn_accumulator.set(acc) _turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc)) logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get() return _turn_accumulator.get()
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
"""Async context manager that scopes a fresh ``TurnTokenAccumulator``
for the duration of the ``async with`` block, then *resets* the
ContextVar to its previous value on exit.
This is the safe primitive for per-call billable operations
(image generation, vision LLM extraction, podcasts) that may run
inside an outer chat turn or be called sequentially from the same
background worker. Using ``ContextVar.set`` without ``reset`` (as
:func:`start_turn` does) would leak the inner accumulator into the
outer scope, causing the outer chat turn to debit cost twice.
Usage::
async with scoped_turn() as acc:
await llm.ainvoke(...)
# acc.total_cost_micros captures cost from the LiteLLM callback
# Outer accumulator (if any) is restored here.
"""
acc = TurnTokenAccumulator()
token = _turn_accumulator.set(acc)
logger.debug(
"[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
id(acc),
token,
)
try:
yield acc
finally:
_turn_accumulator.reset(token)
logger.debug(
"[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
id(acc),
len(acc.calls),
acc.total_cost_micros,
)
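As a quick illustration of the reset semantics (a minimal sketch; ``handle_turn`` is hypothetical), nesting a scoped turn inside a chat turn leaves the outer accumulator in place:

    async def handle_turn() -> None:
        outer = start_turn()                 # long-lived chat-turn accumulator
        async with scoped_turn() as inner:   # e.g. image generation mid-turn
            ...                              # LiteLLM callback records into `inner`
        # `inner` was billed by the per-call wrapper; the chat turn resumes
        # accumulating into `outer` untouched.
        assert get_current_accumulator() is outer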
def _extract_cost_usd(
kwargs: dict[str, Any],
response_obj: Any,
model: str,
prompt_tokens: int,
completion_tokens: int,
is_image: bool = False,
) -> float:
"""Best-effort USD cost extraction for a single LLM/image call.
Tries four sources in priority order and returns the first that
yields a positive number; returns 0.0 if all four fail (the call
will then debit nothing from the user's balance — fail-safe).
Sources:
1. ``kwargs["response_cost"]``: LiteLLM's standard callback
field, populated for ``Router.acompletion`` since PR #12500.
2. ``response_obj._hidden_params["response_cost"]``: the same value
exposed on the response itself.
3. ``litellm.completion_cost(completion_response=response_obj)``:
recompute from the response and LiteLLM's pricing table.
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``:
manual fallback for OpenRouter/custom-Azure models that
only resolve via aliases registered by
``pricing_registration`` at startup. **Skipped for image
responses**: ``cost_per_token`` does not support ``ImageResponse``
and would raise; the cost map for image-gen lives in different
keys (``output_cost_per_image``), handled by ``completion_cost``.
"""
cost = kwargs.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
hidden = getattr(response_obj, "_hidden_params", None) or {}
if isinstance(hidden, dict):
cost = hidden.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
try:
value = float(litellm.completion_cost(completion_response=response_obj))
if value > 0:
return value
except Exception as exc:
if is_image:
# Image-gen path: OpenRouter's image responses can omit
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
# then *raises* (no cost map for OpenRouter image models).
# Bail out with a warning rather than falling through to
# cost_per_token (which is also incompatible with ImageResponse).
logger.warning(
"[TokenTracking] completion_cost failed for image model=%s "
"(provider may have omitted usage.cost). Debiting 0. "
"Cause: %s",
model,
exc,
)
return 0.0
logger.debug(
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
)
if is_image:
# Never call cost_per_token for ImageResponse — keys mismatch and
# the function is documented chat-only.
return 0.0
if model and (prompt_tokens > 0 or completion_tokens > 0):
try:
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
value = float(prompt_cost) + float(completion_cost)
if value > 0:
return value
except Exception as exc:
logger.debug(
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
)
return 0.0
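Callers convert the returned USD float to integer micro-USD; a small sketch of that rounding contract (``to_micros`` is a hypothetical helper mirroring the callback's inline expression):

    def to_micros(cost_usd: float) -> int:
        # 1 micro-USD = $0.000001; zero or negative cost debits nothing.
        return round(cost_usd * 1_000_000) if cost_usd > 0 else 0

    assert to_micros(0.0417) == 41_700  # e.g. a $0.0417 image call
    assert to_micros(0.0) == 0          # all four sources failed: fail-safe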
class TokenTrackingCallback(CustomLogger): class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator.""" """LiteLLM callback that captures token usage into the turn accumulator."""
@@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
) )
return return
# Detect image generation responses — they have a different usage
# shape (ImageUsage with input_tokens/output_tokens) and require a
# different cost-extraction path. We probe by class name to avoid a
# hard import dependency on litellm internals.
response_cls = type(response_obj).__name__
is_image = response_cls == "ImageResponse"
usage = getattr(response_obj, "usage", None) usage = getattr(response_obj, "usage", None)
if not usage: if not usage:
logger.debug( logger.debug(
@@ -129,24 +307,66 @@
) )
return return
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 if is_image:
completion_tokens = getattr(usage, "completion_tokens", 0) or 0 # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
total_tokens = getattr(usage, "total_tokens", 0) or 0 # (not prompt_tokens/completion_tokens). Several providers
# populate only one or neither (e.g. OpenRouter's gpt-image-1
# passes through `input_tokens` from the prompt but no
# completion); fall through gracefully to 0.
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
completion_tokens = getattr(usage, "output_tokens", 0) or 0
total_tokens = (
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
)
call_kind = "image_generation"
else:
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
call_kind = "chat"
model = kwargs.get("model", "unknown") model = kwargs.get("model", "unknown")
cost_usd = _extract_cost_usd(
kwargs=kwargs,
response_obj=response_obj,
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
is_image=is_image,
)
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
logger.warning(
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
"for image generation).",
model,
prompt_tokens,
completion_tokens,
call_kind,
)
acc.add( acc.add(
model=model, model=model,
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
) )
logger.info( logger.info(
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)", "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
model, model,
call_kind,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
cost_usd,
cost_micros,
len(acc.calls), len(acc.calls),
) )
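For orientation, a ``CustomLogger`` subclass like this is registered once at process startup so every routed call reports into the current ContextVar-scoped accumulator. A hedged sketch; the repo's actual wiring point (and module path) is not part of this diff:

    import litellm

    # Assumed import path for illustration only.
    from app.services.token_tracking import TokenTrackingCallback

    litellm.callbacks = [TokenTrackingCallback()]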
@@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0, prompt_tokens: int = 0,
completion_tokens: int = 0, completion_tokens: int = 0,
total_tokens: int = 0, total_tokens: int = 0,
cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None, model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None, call_details: dict[str, Any] | None = None,
thread_id: int | None = None, thread_id: int | None = None,
@@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
cost_micros=cost_micros,
model_breakdown=model_breakdown, model_breakdown=model_breakdown,
call_details=call_details, call_details=call_details,
thread_id=thread_id, thread_id=thread_id,
@@ -194,11 +416,12 @@
) )
session.add(record) session.add(record)
logger.debug( logger.debug(
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d", "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type, usage_type,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
cost_micros,
) )
return record return record
except Exception: except Exception:

View file

@@ -3,6 +3,8 @@ from typing import Any
from litellm import Router from litellm import Router
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0 VISION_AUTO_MODE_ID = 0
@@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"): if not config.get("model_name") or not config.get("api_key"):
return None return None
provider = config.get("provider", "").upper()
if config.get("custom_provider"): if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}" provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else: else:
provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}" model_string = f"{provider_prefix}/{config['model_name']}"
@@ -120,8 +123,13 @@
"api_key": config.get("api_key"), "api_key": config.get("api_key"),
} }
if config.get("api_base"): api_base = resolve_api_base(
litellm_params["api_base"] = config["api_base"] provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
if config.get("api_version"): if config.get("api_version"):
litellm_params["api_version"] = config["api_version"] litellm_params["api_version"] = config["api_version"]

View file

@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app from app.celery_app import celery_app
from app.config import config as app_config
from app.db import Podcast, PodcastStatus from app.db import Podcast, PodcastStatus
from app.services.billable_calls import (
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -96,6 +102,31 @@ async def _generate_content_podcast(
podcast.status = PodcastStatus.GENERATING podcast.status = PodcastStatus.GENERATING
await session.commit() await session.commit()
try:
(
owner_user_id,
billing_tier,
base_model,
) = await _resolve_agent_billing_for_search_space(
session,
search_space_id,
thread_id=podcast.thread_id,
)
except ValueError as resolve_err:
logger.error(
"Podcast %s: cannot resolve billing for search_space=%s: %s",
podcast.id,
search_space_id,
resolve_err,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "billing_resolution_failed",
}
graph_config = { graph_config = {
"configurable": { "configurable": {
"podcast_title": podcast.title, "podcast_title": podcast.title,
@@ -109,9 +140,39 @@
db_session=session, db_session=session,
) )
graph_result = await podcaster_graph.ainvoke( try:
initial_state, config=graph_config async with billable_call(
) user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
usage_type="podcast_generation",
thread_id=podcast.thread_id,
call_details={
"podcast_id": podcast.id,
"title": podcast.title,
},
):
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
)
except QuotaInsufficientError as exc:
logger.info(
"Podcast %s denied: out of premium credits "
"(used=%d/%d remaining=%d)",
podcast.id,
exc.used_micros,
exc.limit_micros,
exc.remaining_micros,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "premium_quota_exhausted",
}
podcast_transcript = graph_result.get("podcast_transcript", []) podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "") file_path = graph_result.get("final_podcast_file_path", "")

View file

@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app from app.celery_app import celery_app
from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus from app.db import VideoPresentation, VideoPresentationStatus
from app.services.billable_calls import (
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -97,6 +103,32 @@ async def _generate_video_presentation(
video_pres.status = VideoPresentationStatus.GENERATING video_pres.status = VideoPresentationStatus.GENERATING
await session.commit() await session.commit()
try:
(
owner_user_id,
billing_tier,
base_model,
) = await _resolve_agent_billing_for_search_space(
session,
search_space_id,
thread_id=video_pres.thread_id,
)
except ValueError as resolve_err:
logger.error(
"VideoPresentation %s: cannot resolve billing for "
"search_space=%s: %s",
video_pres.id,
search_space_id,
resolve_err,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "billing_resolution_failed",
}
graph_config = { graph_config = {
"configurable": { "configurable": {
"video_title": video_pres.title, "video_title": video_pres.title,
@@ -110,9 +142,39 @@
db_session=session, db_session=session,
) )
graph_result = await video_presentation_graph.ainvoke( try:
initial_state, config=graph_config async with billable_call(
) user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
usage_type="video_presentation_generation",
thread_id=video_pres.thread_id,
call_details={
"video_presentation_id": video_pres.id,
"title": video_pres.title,
},
):
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
)
except QuotaInsufficientError as exc:
logger.info(
"VideoPresentation %s denied: out of premium credits "
"(used=%d/%d remaining=%d)",
video_pres.id,
exc.used_micros,
exc.limit_micros,
exc.remaining_micros,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "premium_quota_exhausted",
}
# Serialize slides (parsed content + audio info merged) # Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", []) slides_raw = graph_result.get("slides", [])

View file

@@ -2236,8 +2236,10 @@ async def stream_new_chat(
accumulator = start_turn() accumulator = start_turn()
# Premium quota tracking state # Premium credit (USD micro-units) tracking state. Stores the
_premium_reserved = 0 # amount reserved up front so we can release it on cancellation
# and finalize-debit the actual provider cost reported by LiteLLM.
_premium_reserved_micros = 0
_premium_request_id: str | None = None _premium_request_id: str | None = None
_emit_stream_error = partial( _emit_stream_error = partial(
@@ -2331,23 +2333,28 @@ async def stream_new_chat(
if _needs_premium_quota: if _needs_premium_quota:
import uuid as _uuid import uuid as _uuid
from app.config import config as _app_config from app.services.token_quota_service import (
from app.services.token_quota_service import TokenQuotaService TokenQuotaService,
estimate_call_reserve_micros,
)
_premium_request_id = _uuid.uuid4().hex[:16] _premium_request_id = _uuid.uuid4().hex[:16]
reserve_amount = min( _agent_litellm_params = agent_config.litellm_params or {}
agent_config.quota_reserve_tokens _agent_base_model = (
or _app_config.QUOTA_MAX_RESERVE_PER_CALL, _agent_litellm_params.get("base_model") or agent_config.model_name or ""
_app_config.QUOTA_MAX_RESERVE_PER_CALL, )
reserve_amount_micros = estimate_call_reserve_micros(
base_model=_agent_base_model,
quota_reserve_tokens=agent_config.quota_reserve_tokens,
) )
async with shielded_async_session() as quota_session: async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve( quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
request_id=_premium_request_id, request_id=_premium_request_id,
reserve_tokens=reserve_amount, reserve_micros=reserve_amount_micros,
) )
_premium_reserved = reserve_amount _premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed: if not quota_result.allowed:
if requested_llm_config_id == 0: if requested_llm_config_id == 0:
try: try:
@@ -2382,7 +2389,7 @@ async def stream_new_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
return return
_premium_request_id = None _premium_request_id = None
_premium_reserved = 0 _premium_reserved_micros = 0
_log_chat_stream_error( _log_chat_stream_error(
flow=flow, flow=flow,
error_kind="premium_quota_exhausted", error_kind="premium_quota_exhausted",
@@ -3020,9 +3027,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary() usage_summary = accumulator.per_message_summary()
_perf_log.info( _perf_log.info(
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls), len(accumulator.calls),
accumulator.grand_total, accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary, usage_summary,
) )
if usage_summary: if usage_summary:
@@ -3033,6 +3041,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens, "prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens, "completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total, "total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(), "call_details": accumulator.serialized_calls(),
}, },
) )
@@ -3060,7 +3069,11 @@ async def stream_new_chat(
chat_id, generated_title chat_id, generated_title
) )
# Finalize premium quota with actual tokens. # Finalize premium credit debit with the actual provider cost
# reported by LiteLLM, summed across every call in the turn.
# Mirrors the pre-cost behaviour of "premium turn → all calls
# count" so free sub-agent calls during a premium turn still
# contribute to the bill (they're $0 in practice anyway).
if _premium_request_id and user_id: if _premium_request_id and user_id:
try: try:
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
@@ -3070,11 +3083,11 @@
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
request_id=_premium_request_id, request_id=_premium_request_id,
actual_tokens=accumulator.grand_total, actual_micros=accumulator.total_cost_micros,
reserved_tokens=_premium_reserved, reserved_micros=_premium_reserved_micros,
) )
_premium_request_id = None _premium_request_id = None
_premium_reserved = 0 _premium_reserved_micros = 0
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s", "Failed to finalize premium quota for user %s",
@@ -3084,9 +3097,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary() usage_summary = accumulator.per_message_summary()
_perf_log.info( _perf_log.info(
"[token_usage] normal new_chat: calls=%d total=%d summary=%s", "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls), len(accumulator.calls),
accumulator.grand_total, accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary, usage_summary,
) )
if usage_summary: if usage_summary:
@@ -3097,6 +3111,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens, "prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens, "completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total, "total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(), "call_details": accumulator.serialized_calls(),
}, },
) )
@@ -3190,7 +3205,7 @@ async def stream_new_chat(
end_turn(str(chat_id)) end_turn(str(chat_id))
# Release premium reservation if not finalized # Release premium reservation if not finalized
if _premium_request_id and _premium_reserved > 0 and user_id: if _premium_request_id and _premium_reserved_micros > 0 and user_id:
try: try:
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
@@ -3198,9 +3213,9 @@
await TokenQuotaService.premium_release( await TokenQuotaService.premium_release(
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
reserved_tokens=_premium_reserved, reserved_micros=_premium_reserved_micros,
) )
_premium_reserved = 0 _premium_reserved_micros = 0
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s", user_id "Failed to release premium quota for user %s", user_id
@@ -3369,8 +3384,8 @@ async def stream_resume_chat(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
) )
# Premium quota reservation (same logic as stream_new_chat) # Premium credit reservation (same logic as stream_new_chat).
_resume_premium_reserved = 0 _resume_premium_reserved_micros = 0
_resume_premium_request_id: str | None = None _resume_premium_request_id: str | None = None
_resume_needs_premium = ( _resume_needs_premium = (
agent_config is not None and user_id and agent_config.is_premium agent_config is not None and user_id and agent_config.is_premium
@@ -3378,23 +3393,30 @@ async def stream_resume_chat(
if _resume_needs_premium: if _resume_needs_premium:
import uuid as _uuid import uuid as _uuid
from app.config import config as _app_config from app.services.token_quota_service import (
from app.services.token_quota_service import TokenQuotaService TokenQuotaService,
estimate_call_reserve_micros,
)
_resume_premium_request_id = _uuid.uuid4().hex[:16] _resume_premium_request_id = _uuid.uuid4().hex[:16]
reserve_amount = min( _resume_litellm_params = agent_config.litellm_params or {}
agent_config.quota_reserve_tokens _resume_base_model = (
or _app_config.QUOTA_MAX_RESERVE_PER_CALL, _resume_litellm_params.get("base_model")
_app_config.QUOTA_MAX_RESERVE_PER_CALL, or agent_config.model_name
or ""
)
reserve_amount_micros = estimate_call_reserve_micros(
base_model=_resume_base_model,
quota_reserve_tokens=agent_config.quota_reserve_tokens,
) )
async with shielded_async_session() as quota_session: async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve( quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
request_id=_resume_premium_request_id, request_id=_resume_premium_request_id,
reserve_tokens=reserve_amount, reserve_micros=reserve_amount_micros,
) )
_resume_premium_reserved = reserve_amount _resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed: if not quota_result.allowed:
if requested_llm_config_id == 0: if requested_llm_config_id == 0:
try: try:
@@ -3429,7 +3451,7 @@ async def stream_resume_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
return return
_resume_premium_request_id = None _resume_premium_request_id = None
_resume_premium_reserved = 0 _resume_premium_reserved_micros = 0
_log_chat_stream_error( _log_chat_stream_error(
flow="resume", flow="resume",
error_kind="premium_quota_exhausted", error_kind="premium_quota_exhausted",
@@ -3746,9 +3768,10 @@ async def stream_resume_chat(
if stream_result.is_interrupted: if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary() usage_summary = accumulator.per_message_summary()
_perf_log.info( _perf_log.info(
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s", "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls), len(accumulator.calls),
accumulator.grand_total, accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary, usage_summary,
) )
if usage_summary: if usage_summary:
@@ -3759,6 +3782,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens, "prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens, "completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total, "total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(), "call_details": accumulator.serialized_calls(),
}, },
) )
@@ -3768,7 +3792,9 @@ async def stream_resume_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
return return
# Finalize premium quota for resume path # Finalize premium credit debit for resume path with the actual
# provider cost reported by LiteLLM (sum of cost across all
# calls in the turn).
if _resume_premium_request_id and user_id: if _resume_premium_request_id and user_id:
try: try:
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
@@ -3778,11 +3804,11 @@
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
request_id=_resume_premium_request_id, request_id=_resume_premium_request_id,
actual_tokens=accumulator.grand_total, actual_micros=accumulator.total_cost_micros,
reserved_tokens=_resume_premium_reserved, reserved_micros=_resume_premium_reserved_micros,
) )
_resume_premium_request_id = None _resume_premium_request_id = None
_resume_premium_reserved = 0 _resume_premium_reserved_micros = 0
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)", "Failed to finalize premium quota for user %s (resume)",
@@ -3792,9 +3818,10 @@ async def stream_resume_chat(
usage_summary = accumulator.per_message_summary() usage_summary = accumulator.per_message_summary()
_perf_log.info( _perf_log.info(
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s", "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls), len(accumulator.calls),
accumulator.grand_total, accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary, usage_summary,
) )
if usage_summary: if usage_summary:
@@ -3805,6 +3832,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens, "prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens, "completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total, "total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(), "call_details": accumulator.serialized_calls(),
}, },
) )
@@ -3855,7 +3883,11 @@ async def stream_resume_chat(
end_turn(str(chat_id)) end_turn(str(chat_id))
# Release premium reservation if not finalized # Release premium reservation if not finalized
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: if (
_resume_premium_request_id
and _resume_premium_reserved_micros > 0
and user_id
):
try: try:
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
@@ -3863,9 +3895,9 @@
await TokenQuotaService.premium_release( await TokenQuotaService.premium_release(
db_session=quota_session, db_session=quota_session,
user_id=UUID(user_id), user_id=UUID(user_id),
reserved_tokens=_resume_premium_reserved, reserved_micros=_resume_premium_reserved_micros,
) )
_resume_premium_reserved = 0 _resume_premium_reserved_micros = 0
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s (resume)", user_id "Failed to release premium quota for user %s (resume)", user_id

View file

@@ -0,0 +1,138 @@
"""Unit tests for the image-generation route's billing-resolution helper.
End-to-end "POST /image-generations returns 402" coverage requires the
integration harness (real DB, real auth) and lives in
``tests/integration/document_upload/`` alongside the other quota tests.
This unit test focuses on the new ``_resolve_billing_for_image_gen``
helper which:
* Returns ``free`` for Auto mode, even when premium configs exist
(Auto-mode billing-tier surfacing is a follow-up).
* Returns ``free`` for user-owned BYOK configs (positive IDs).
* Returns the global config's ``billing_tier`` for negative IDs.
* Honours the per-config ``quota_reserve_micros`` override when present.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_resolve_billing_for_auto_mode(monkeypatch):
from app.routes import image_generation_routes
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
search_space = SimpleNamespace(image_generation_config_id=None)
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, # Not consumed on this code path.
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
search_space=search_space,
)
assert tier == "free"
assert model == "auto"
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_for_premium_global_config(monkeypatch):
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(
config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{
"id": -1,
"provider": "OPENAI",
"model_name": "gpt-image-1",
"billing_tier": "premium",
"quota_reserve_micros": 75_000,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash-image",
"billing_tier": "free",
},
],
raising=False,
)
search_space = SimpleNamespace(image_generation_config_id=None)
# Premium with override.
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=-1, search_space=search_space
)
assert tier == "premium"
assert model == "openai/gpt-image-1"
assert reserve == 75_000
# Free, no override → falls back to default.
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=-2, search_space=search_space
)
assert tier == "free"
# Provider-prefixed model string for OpenRouter.
assert "google/gemini-2.5-flash-image" in model
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_for_user_owned_byok_is_free():
"""User-owned BYOK configs (positive IDs) cost the user nothing on
our side; they pay the provider directly. Always free.
"""
from app.routes import image_generation_routes
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
search_space = SimpleNamespace(image_generation_config_id=None)
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=42, search_space=search_space
)
assert tier == "free"
assert model == "user_byok"
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
"""When the request omits ``image_generation_config_id``, the helper
must consult the search space's default — so a search space pinned
to a premium global config still gates new requests by quota.
"""
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(
config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{
"id": -7,
"provider": "OPENAI",
"model_name": "gpt-image-1",
"billing_tier": "premium",
}
],
raising=False,
)
search_space = SimpleNamespace(image_generation_config_id=-7)
(
tier,
model,
_reserve,
) = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=None, search_space=search_space
)
assert tier == "premium"
assert model == "openai/gpt-image-1"

View file

@@ -0,0 +1,436 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
Validates the resolver used by Celery podcast/video tasks to compute
``(owner_user_id, billing_tier, base_model)`` from a search space and its
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
``stream_new_chat.py:2294-2351`` and is the single integration point that
prevents Auto-mode podcast/video from leaking premium credit.
Coverage:
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
global → returns ``("premium", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
global → returns ``("free", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
→ always ``"free"``.
* Auto mode + ``thread_id=None`` → falls back to ``("free", "auto")`` without
hitting the pin service.
* Negative id (no Auto) → uses ``get_global_llm_config``'s
``billing_tier``.
* Positive id (user BYOK) → always ``"free"``.
* Search space not found → raises ``ValueError``.
* ``agent_llm_id`` is None → raises ``ValueError``.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from uuid import UUID, uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
class _FakeSession:
"""Tiny AsyncSession stub.
``responses`` is a list of objects to return from successive
``execute()`` calls (in order). The resolver makes at most two
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
lookup), so two queued responses cover the matrix.
"""
def __init__(self, responses: list):
self._responses = list(responses)
async def execute(self, _stmt):
if not self._responses:
return _FakeExecResult(None)
return _FakeExecResult(self._responses.pop(0))
async def commit(self) -> None:
pass
@dataclass
class _FakePinResolution:
resolved_llm_config_id: int
resolved_tier: str = "premium"
from_existing_pin: bool = False
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
return SimpleNamespace(
id=42,
agent_llm_id=agent_llm_id,
user_id=user_id,
)
def _make_byok_config(
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
) -> SimpleNamespace:
return SimpleNamespace(
id=id_,
model_name=model_name,
litellm_params={"base_model": base_model} if base_model else {},
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
"""Auto + thread → pin service resolves to negative-id premium config →
resolver returns ``("premium", <base_model>)``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
# Mock the pin service to return a concrete premium config id.
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
assert selected_llm_config_id == 0
assert thread_id == 99
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
# Mock global config lookup to return a premium entry.
def _fake_get_global(cfg_id):
if cfg_id == -1:
return {
"id": -1,
"model_name": "gpt-5.4",
"billing_tier": "premium",
"litellm_params": {"base_model": "gpt-5.4"},
}
return None
# Lazy imports inside the resolver — patch the *target* modules so the
# imported names resolve to our fakes.
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "premium"
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
"""Auto + thread → pin returns negative-id free config → resolver
returns ``("free", <base_model>)``. Same path the pin service takes for
out-of-credit users (graceful degradation)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
def _fake_get_global(cfg_id):
if cfg_id == -3:
return {
"id": -3,
"model_name": "openrouter/free-model",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/free-model"},
}
return None
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/free-model"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
"""Auto + thread → pin returns positive-id BYOK config → resolver
returns ``("free", ...)`` (BYOK is always free per
``AgentConfig.from_new_llm_config``)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
byok_cfg = _make_byok_config(
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
)
session = _FakeSession([search_space, byok_cfg])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
import app.services.auto_model_pin_service as pin_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "anthropic/claude-3-haiku"
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
the pin service. Forward-compat fallback for any future direct-API
entrypoint that doesn't have a chat thread."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "auto"
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
"""If the pin service raises ``ValueError`` (thread missing /
mismatched search space), the resolver should log and return free
rather than killing the whole task."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(*args, **kwargs):
raise ValueError("thread missing")
import app.services.auto_model_pin_service as pin_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "auto"
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
return its ``billing_tier``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "gpt-5.4",
"billing_tier": "premium",
"litellm_params": {"base_model": "gpt-5.4"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "premium"
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "openrouter/some-free",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/some-free"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/some-free"
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
"""When the global config has no ``litellm_params.base_model``, the
resolver falls back to ``model_name``, matching chat's behavior."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "fallback-model",
"billing_tier": "premium",
# No litellm_params.
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
_, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert tier == "premium"
assert base_model == "fallback-model"
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
regardless of underlying provider tier."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
session = _FakeSession([search_space, byok_cfg])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert owner == user_id
assert tier == "free"
assert base_model == "anthropic/claude-3.5-sonnet"
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
"""If the BYOK config row is missing/deleted but the search space still
points at it, the resolver still returns free (no debit) with an empty
base_model; billable_call's premium path is skipped, no harm done."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert owner == user_id
assert tier == "free"
assert base_model == ""
@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
session = _FakeSession([None])
with pytest.raises(ValueError, match="Search space"):
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
with pytest.raises(ValueError, match="agent_llm_id"):
await _resolve_agent_billing_for_search_space(session, search_space_id=42)

View file

@@ -0,0 +1,432 @@
"""Unit tests for the ``billable_call`` async context manager.
Covers the per-call premium-credit lifecycle for image generation and
vision LLM extraction:
* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
route layer).
* Successful premium calls reserve, yield the accumulator, then finalize
with the LiteLLM-reported actual cost and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
isolating them from the caller's transaction (issue A).
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
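For orientation before the fakes: a route layer consuming this context manager would map ``QuotaInsufficientError`` to HTTP 402 roughly as follows. A hedged sketch; the FastAPI handler is hypothetical, and the field names come from the assertions below:

    from fastapi import HTTPException

    from app.services.billable_calls import QuotaInsufficientError, billable_call

    async def generate_image(user_id, search_space_id: int) -> None:
        try:
            async with billable_call(
                user_id=user_id,
                search_space_id=search_space_id,
                billing_tier="premium",
                base_model="openai/gpt-image-1",
                usage_type="image_generation",
            ):
                ...  # provider call; the LiteLLM callback captures cost
        except QuotaInsufficientError as exc:
            raise HTTPException(
                status_code=402,  # Payment Required
                detail={
                    "reason": "premium_quota_exhausted",
                    "used_micros": exc.used_micros,
                    "limit_micros": exc.limit_micros,
                    "remaining_micros": exc.remaining_micros,
                },
            ) from exc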
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeQuotaResult:
def __init__(
self,
*,
allowed: bool,
used: int = 0,
limit: int = 5_000_000,
remaining: int = 5_000_000,
) -> None:
self.allowed = allowed
self.used = used
self.limit = limit
self.remaining = remaining
class _FakeSession:
"""Minimal AsyncSession stub — record commits for assertion."""
def __init__(self) -> None:
self.committed = False
self.added: list[Any] = []
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.committed = True
async def close(self) -> None:
pass
@contextlib.asynccontextmanager
async def _fake_shielded_session():
s = _FakeSession()
_SESSIONS_USED.append(s)
yield s
_SESSIONS_USED: list[_FakeSession] = []
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
"""Wire fake reserve/finalize/release/session helpers."""
_SESSIONS_USED.clear()
reserve_calls: list[dict[str, Any]] = []
finalize_calls: list[dict[str, Any]] = []
release_calls: list[dict[str, Any]] = []
async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
reserve_calls.append(
{
"user_id": user_id,
"reserve_micros": reserve_micros,
"request_id": request_id,
}
)
return reserve_result
async def _fake_finalize(
*, db_session, user_id, request_id, actual_micros, reserved_micros
):
finalize_calls.append(
{
"user_id": user_id,
"actual_micros": actual_micros,
"reserved_micros": reserved_micros,
}
)
return finalize_result or _FakeQuotaResult(allowed=True)
async def _fake_release(*, db_session, user_id, reserved_micros):
release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
record_calls: list[dict[str, Any]] = []
async def _fake_record(session, **kwargs):
record_calls.append(kwargs)
return object()
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_reserve",
_fake_reserve,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_finalize",
_fake_finalize,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_release",
_fake_release,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.shielded_async_session",
_fake_shielded_session,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.record_token_usage",
_fake_record,
raising=False,
)
return {
"reserve": reserve_calls,
"finalize": finalize_calls,
"release": release_calls,
"record": record_calls,
}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="free",
base_model="openai/gpt-image-1",
usage_type="image_generation",
) as acc:
# Simulate a captured cost — the accumulator is fed by the LiteLLM
# callback in real life; here we add it manually.
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=37_000,
call_kind="image_generation",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
assert spies["release"] == []
# Free still audits.
assert len(spies["record"]) == 1
assert spies["record"][0]["usage_type"] == "image_generation"
assert spies["record"][0]["cost_micros"] == 37_000
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
from app.services.billable_calls import (
QuotaInsufficientError,
billable_call,
)
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(
allowed=False, used=5_000_000, limit=5_000_000, remaining=0
),
)
user_id = uuid4()
with pytest.raises(QuotaInsufficientError) as exc_info:
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
):
pytest.fail("body should not run when reserve is denied")
err = exc_info.value
assert err.usage_type == "image_generation"
assert err.used_micros == 5_000_000
assert err.limit_micros == 5_000_000
assert err.remaining_micros == 0
# Reserve was attempted, but no finalize/release on a denied reserve
# — the reservation never actually held credit.
assert len(spies["reserve"]) == 1
assert spies["finalize"] == []
assert spies["release"] == []
# Denied premium calls do NOT create an audit row (no work happened).
assert spies["record"] == []
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
) as acc:
# LiteLLM callback would normally fill this — simulate $0.04 image.
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=40_000,
call_kind="image_generation",
)
assert len(spies["reserve"]) == 1
assert spies["reserve"][0]["reserve_micros"] == 50_000
assert len(spies["finalize"]) == 1
assert spies["finalize"][0]["actual_micros"] == 40_000
assert spies["finalize"][0]["reserved_micros"] == 50_000
assert spies["release"] == []
# And audit row written with the actual debited cost.
assert spies["record"][0]["cost_micros"] == 40_000
# Each quota op opened its OWN session — proves session isolation.
assert len(_SESSIONS_USED) >= 3
# finalize/reserve happen via TokenQuotaService.* which we stub — they
# don't actually call commit on our fake session, but the audit session
# does, so at least one session must have committed.
assert any(s.committed for s in _SESSIONS_USED)
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
class _ProviderError(Exception):
pass
with pytest.raises(_ProviderError):
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
):
raise _ProviderError("OpenRouter 503")
assert len(spies["reserve"]) == 1
assert spies["finalize"] == []
# Failure path: release the held reservation.
assert len(spies["release"]) == 1
assert spies["release"][0]["reserved_micros"] == 50_000
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
"""When ``quota_reserve_micros_override`` is None we fall back to
``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
Vision LLM calls take this path (token-priced models).
"""
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
captured_estimator_calls: list[dict[str, Any]] = []
def _fake_estimate(*, base_model, quota_reserve_tokens):
captured_estimator_calls.append(
{"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
)
return 12_345
monkeypatch.setattr(
"app.services.billable_calls.estimate_call_reserve_micros",
_fake_estimate,
raising=False,
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
usage_type="vision_extraction",
):
pass
assert captured_estimator_calls == [
{"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
]
assert spies["reserve"][0]["reserve_micros"] == 12_345
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
"""Free podcast configs must skip reserve/finalize but still emit a
``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
have full audit coverage of free-tier agent runs."""
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="free",
base_model="openrouter/some-free-model",
quota_reserve_micros_override=200_000,
usage_type="podcast_generation",
thread_id=99,
call_details={"podcast_id": 7, "title": "Test Podcast"},
) as acc:
# Two transcript LLM calls aggregated into one accumulator.
acc.add(
model="openrouter/some-free-model",
prompt_tokens=1500,
completion_tokens=8000,
total_tokens=9500,
cost_micros=0,
call_kind="chat",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
assert spies["release"] == []
assert len(spies["record"]) == 1
row = spies["record"][0]
assert row["usage_type"] == "podcast_generation"
assert row["thread_id"] == 99
assert row["search_space_id"] == 42
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
"""Premium video-presentation runs that hit a denied reservation must
raise ``QuotaInsufficientError`` *before* the graph runs and must not
emit an audit row (no work happened)."""
from app.services.billable_calls import (
QuotaInsufficientError,
billable_call,
)
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(
allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
),
)
user_id = uuid4()
with pytest.raises(QuotaInsufficientError) as exc_info:
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="gpt-5.4",
quota_reserve_micros_override=1_000_000,
usage_type="video_presentation_generation",
thread_id=99,
call_details={"video_presentation_id": 12, "title": "Test Video"},
):
pytest.fail("body should not run when reserve is denied")
err = exc_info.value
assert err.usage_type == "video_presentation_generation"
assert err.remaining_micros == 500_000
assert spies["reserve"][0]["reserve_micros"] == 1_000_000
assert spies["finalize"] == []
assert spies["release"] == []
assert spies["record"] == []

View file

@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
assert "openai/gpt-4o" in model_names assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names assert "openai/completion-only" not in model_names
# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# ---------------------------------------------------------------------------
def test_generate_image_gen_configs_filters_by_image_output():
"""Only models with ``output_modalities`` containing ``image`` are emitted.
Tool-calling and context filters are intentionally NOT applied; image
generation has nothing to do with tool calls and context windows.
"""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
raw = [
# Pure image-gen model (small context, no tools — should still emit).
{
"id": "openai/gpt-image-1",
"architecture": {"output_modalities": ["image"]},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Multi-modal: text+image output (should still emit).
{
"id": "google/gemini-2.5-flash-image",
"architecture": {"output_modalities": ["text", "image"]},
"context_length": 1_000_000,
"pricing": {"prompt": "0.000001", "completion": "0.000004"},
},
# Pure text model — must NOT emit.
{
"id": "openai/gpt-4o",
"architecture": {"output_modalities": ["text"]},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
]
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
model_names = {c["model_name"] for c in cfgs}
assert "openai/gpt-image-1" in model_names
assert "google/gemini-2.5-flash-image" in model_names
assert "openai/gpt-4o" not in model_names
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
for c in cfgs:
assert c["billing_tier"] in {"free", "premium"}
assert c["provider"] == "OPENROUTER"
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
def test_generate_image_gen_configs_assigns_image_id_offset():
"""Image configs use a different id_offset (-20000) so their negative
IDs don't collide with chat configs (-10000) or vision configs (-30000).
"""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
raw = [
{
"id": "openai/gpt-image-1",
"architecture": {"output_modalities": ["image"]},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
}
]
# Don't pass image_id_offset → use the module default (-20000).
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
assert all(c["id"] < -20_000 + 1 for c in cfgs)
assert all(c["id"] > -29_000_000 for c in cfgs)
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
"""Vision LLMs must accept image input AND emit text — pure image-gen
(no text out) and text-only (no image in) models are excluded.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
# GPT-4o: vision LLM (image in, text out) — must emit.
{
"id": "openai/gpt-4o",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
# Pure image generator — image *output*, no text out. Must NOT emit.
{
"id": "openai/gpt-image-1",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["image"],
},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Pure text model (no image in). Must NOT emit.
{
"id": "anthropic/claude-3-haiku",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"],
},
"context_length": 200_000,
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
},
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
names = {c["model_name"] for c in cfgs}
assert names == {"openai/gpt-4o"}
cfg = cfgs[0]
assert cfg["billing_tier"] == "premium"
# Pricing carried inline so pricing_registration can register vision
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
# is cleared.
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
def test_generate_vision_llm_configs_drops_chat_only_filters():
"""A small-context vision model that doesn't advertise tool calling is
still a valid vision LLM for "describe this image" prompts. The chat
filters (``supports_tool_calling``, ``has_sufficient_context``) must
NOT be applied to vision emission.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
{
"id": "tiny/vision-mini",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"supported_parameters": [], # no tools
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
}
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
assert len(cfgs) == 1
assert cfgs[0]["model_name"] == "tiny/vision-mini"

View file

@ -0,0 +1,447 @@
"""Pricing registration unit tests.
The pricing-registration module is what makes ``response_cost`` populate
correctly for OpenRouter dynamic models and operator-defined Azure
deployments, both of which LiteLLM doesn't natively know about. The tests
exercise:
* The alias generators emit every shape that LiteLLM's cost-callback might
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
``provider/base_model``, ``provider/model_name``, plus the special
``azure_openai`` → ``azure`` normalisation).
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
with the right alias set and pricing values per provider.
* Configs without a resolvable pair of cost values are skipped, never
registered as zero, since that would override pricing LiteLLM might
already know natively.
"""
from __future__ import annotations
from typing import Any
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Alias generators
# ---------------------------------------------------------------------------
def test_openrouter_alias_set_includes_prefixed_and_bare():
from app.services.pricing_registration import _alias_set_for_openrouter
aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
assert aliases == [
"openrouter/anthropic/claude-3-5-sonnet",
"anthropic/claude-3-5-sonnet",
]
def test_openrouter_alias_set_dedupes():
"""If the model id is already prefixed with ``openrouter/``, the alias
set must not contain duplicates that would re-register the same key
twice.
"""
from app.services.pricing_registration import _alias_set_for_openrouter
aliases = _alias_set_for_openrouter("openrouter/foo")
# The bare and prefixed variants compute to the same string here, so we
# at minimum require uniqueness.
assert len(aliases) == len(set(aliases))
def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
"""``azure_openai`` (our YAML provider slug) must register under
``azure/<name>`` so the LiteLLM Router's deployment-resolution path
(which uses provider ``azure``) finds the pricing too.
"""
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="AZURE_OPENAI",
model_name="gpt-5.4",
base_model="gpt-5.4",
)
assert "gpt-5.4" in aliases
assert "azure_openai/gpt-5.4" in aliases
assert "azure/gpt-5.4" in aliases
def test_yaml_alias_set_distinguishes_model_name_and_base_model():
"""When ``model_name`` differs from ``base_model`` (operator labelled a
deployment), both must appear in the alias set since either may surface
in callbacks depending on the call path.
"""
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="OPENAI",
model_name="my-deployment-label",
base_model="gpt-4o",
)
assert "gpt-4o" in aliases
assert "openai/gpt-4o" in aliases
assert "my-deployment-label" in aliases
assert "openai/my-deployment-label" in aliases
def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="",
model_name="foo",
base_model="bar",
)
assert "bar" in aliases
assert "foo" in aliases
assert all("/" not in a for a in aliases)
# ---------------------------------------------------------------------------
# register_pricing_from_global_configs
# ---------------------------------------------------------------------------
class _RegistrationSpy:
"""Captures the dicts passed to ``litellm.register_model``.
Many calls may go through; we just record them all and let tests assert
against the union.
"""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def __call__(self, payload: dict[str, Any]) -> None:
self.calls.append(payload)
@property
def all_keys(self) -> set[str]:
keys: set[str] = set()
for payload in self.calls:
keys.update(payload.keys())
return keys
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
spy = _RegistrationSpy()
monkeypatch.setattr(
"app.services.pricing_registration.litellm.register_model",
spy,
raising=False,
)
return spy
def _patch_openrouter_pricing(
monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
"""Pretend the OpenRouter integration is initialised with ``mapping``."""
class _Stub:
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
return mapping
class _StubService:
@classmethod
def is_initialized(cls) -> bool:
return True
@classmethod
def get_instance(cls) -> _Stub:
return _Stub()
monkeypatch.setattr(
"app.services.openrouter_integration_service.OpenRouterIntegrationService",
_StubService,
raising=False,
)
def test_openrouter_models_register_under_aliases(monkeypatch):
"""An OpenRouter config whose ``model_name`` is in the cached raw
pricing map is registered under both ``openrouter/X`` and bare ``X``.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{
"anthropic/claude-3-5-sonnet": {
"prompt": "0.000003",
"completion": "0.000015",
}
},
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
)
register_pricing_from_global_configs()
assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
assert "anthropic/claude-3-5-sonnet" in spy.all_keys
# Costs are float-converted from the raw OpenRouter strings.
payload = spy.calls[0]
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
"input_cost_per_token"
] == pytest.approx(3e-6)
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
"output_cost_per_token"
] == pytest.approx(15e-6)
assert (
payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
== "openrouter"
)
def test_yaml_override_registers_under_alias_set(monkeypatch):
"""Operator-declared ``input_cost_per_token`` /
``output_cost_per_token`` on a YAML config registers under every
alias the YAML alias generator produces, including the ``azure/``
normalisation for ``azure_openai`` providers.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5.4",
"litellm_params": {
"base_model": "gpt-5.4",
"input_cost_per_token": 2e-6,
"output_cost_per_token": 8e-6,
},
}
],
)
register_pricing_from_global_configs()
keys = spy.all_keys
assert "gpt-5.4" in keys
assert "azure_openai/gpt-5.4" in keys
assert "azure/gpt-5.4" in keys
payload = spy.calls[0]
entry = payload["gpt-5.4"]
assert entry["input_cost_per_token"] == pytest.approx(2e-6)
assert entry["output_cost_per_token"] == pytest.approx(8e-6)
assert entry["litellm_provider"] == "azure"
def test_no_override_means_no_registration(monkeypatch):
"""A YAML config that *omits* both pricing fields must NOT be registered
registering as zero would override LiteLLM's native pricing for the
``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
bill drop to $0. Fail-safe is "skip and warn", not "register zero".
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENAI",
"model_name": "gpt-4o",
"litellm_params": {"base_model": "gpt-4o"},
}
],
)
register_pricing_from_global_configs()
assert spy.calls == []
def test_openrouter_skipped_when_pricing_missing(monkeypatch):
"""If the OpenRouter raw-pricing cache doesn't carry an entry for a
configured model (network blip during refresh, model added later, etc.),
we skip it rather than registering zero pricing.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
)
register_pricing_from_global_configs()
assert spy.calls == []
def test_register_continues_after_individual_failure(monkeypatch, caplog):
"""A single bad ``register_model`` call (e.g. raising LiteLLM error)
must not abort registration of the remaining configs.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
successful_calls: list[dict[str, Any]] = []
def _maybe_fail(payload: dict[str, Any]) -> None:
if any(k in failing_keys for k in payload):
raise RuntimeError("boom")
successful_calls.append(payload)
monkeypatch.setattr(
"app.services.pricing_registration.litellm.register_model",
_maybe_fail,
raising=False,
)
_patch_openrouter_pricing(
monkeypatch,
{
"anthropic/claude-3-5-sonnet": {
"prompt": "0.000003",
"completion": "0.000015",
}
},
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
},
{
"id": 2,
"provider": "OPENAI",
"model_name": "custom-deployment",
"litellm_params": {
"base_model": "custom-deployment",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 2e-6,
},
},
],
)
register_pricing_from_global_configs()
# The good config still registered.
assert any("custom-deployment" in payload for payload in successful_calls)
def test_vision_configs_registered_with_chat_shape(monkeypatch):
"""``register_pricing_from_global_configs`` walks
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
calls (during indexing) bill correctly. Vision configs use the same
chat-shape token prices, but image-gen pricing is intentionally NOT
registered here (handled via ``response_cost`` in LiteLLM).
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
)
# No chat configs — only vision. Proves the vision walk is a separate
# iteration, not piggy-backed on the chat list.
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"billing_tier": "premium",
"input_cost_per_token": 5e-6,
"output_cost_per_token": 15e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/openai/gpt-4o" in spy.all_keys
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
assert payload_value["mode"] == "chat"
assert payload_value["litellm_provider"] == "openrouter"
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
"""If the OpenRouter pricing cache misses a vision model (different
catalogue surface), the vision walk falls back to inline
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash",
"billing_tier": "premium",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 4e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys

View file

@ -0,0 +1,157 @@
"""Unit tests for ``QuotaCheckedVisionLLM``.
Validates that:
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
enforcement) and forwards the inner LLM's response on success.
* The wrapper proxies non-overridden attributes to the inner LLM
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
still work without quota gating (they're not used in indexing today).
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
bubbles it up; the ETL pipeline catches that and falls back to OCR.
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
class _FakeInnerLLM:
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
def __init__(self, response: Any = "OCR'd content") -> None:
self._response = response
self.ainvoke_calls: list[Any] = []
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
self.ainvoke_calls.append(input)
return self._response
def some_other_method(self, x: int) -> int:
return x * 2
@contextlib.asynccontextmanager
async def _passthrough_billable_call(**_kwargs):
"""Stand-in for billable_call that always allows the call to run."""
class _Acc:
total_cost_micros = 0
total_prompt_tokens = 0
total_completion_tokens = 0
grand_total = 0
calls: list[Any] = []
def per_message_summary(self) -> dict[str, dict[str, int]]:
return {}
yield _Acc()
@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
captured_kwargs: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _spy_billable_call(**kwargs):
captured_kwargs.append(kwargs)
async with _passthrough_billable_call() as acc:
yield acc
monkeypatch.setattr(
"app.services.quota_checked_vision_llm.billable_call",
_spy_billable_call,
raising=False,
)
inner = _FakeInnerLLM(response="A red apple on a white table")
user_id = uuid4()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=user_id,
search_space_id=99,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
result = await wrapper.ainvoke([{"text": "what is this?"}])
assert result == "A red apple on a white table"
assert len(inner.ainvoke_calls) == 1
assert len(captured_kwargs) == 1
bc_kwargs = captured_kwargs[0]
assert bc_kwargs["user_id"] == user_id
assert bc_kwargs["search_space_id"] == 99
assert bc_kwargs["billing_tier"] == "premium"
assert bc_kwargs["base_model"] == "openai/gpt-4o"
assert bc_kwargs["quota_reserve_tokens"] == 4000
assert bc_kwargs["usage_type"] == "vision_extraction"
@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
from app.services.billable_calls import QuotaInsufficientError
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
@contextlib.asynccontextmanager
async def _denying_billable_call(**_kwargs):
raise QuotaInsufficientError(
usage_type="vision_extraction",
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield # unreachable but required for asynccontextmanager type
monkeypatch.setattr(
"app.services.quota_checked_vision_llm.billable_call",
_denying_billable_call,
raising=False,
)
inner = _FakeInnerLLM()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=uuid4(),
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
with pytest.raises(QuotaInsufficientError):
await wrapper.ainvoke([{"text": "x"}])
# Inner LLM never ran on a denied reservation.
assert inner.ainvoke_calls == []
@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
"""``__getattr__`` forwards anything not on the proxy itself, so any
method we didn't explicitly override (``invoke``, ``astream``,
``with_structured_output``, etc.) still works, just without quota
gating, which is fine because the indexer only ever calls ``ainvoke``.
"""
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
inner = _FakeInnerLLM()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=uuid4(),
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
# ``some_other_method`` is on the inner only.
assert wrapper.some_other_method(7) == 14
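# The wrapper shape these tests assume (a sketch, not the real class body;
# kwargs mirror the assertions in test_ainvoke_routes_through_billable_call):
#
#   class QuotaCheckedVisionLLM:
#       def __init__(self, inner, *, user_id, search_space_id, billing_tier,
#                    base_model, quota_reserve_tokens):
#           self._inner = inner
#           self._bc_kwargs = dict(
#               user_id=user_id, search_space_id=search_space_id,
#               billing_tier=billing_tier, base_model=base_model,
#               quota_reserve_tokens=quota_reserve_tokens,
#               usage_type="vision_extraction",
#           )
#
#       async def ainvoke(self, input, *args, **kwargs):
#           async with billable_call(**self._bc_kwargs):
#               return await self._inner.ainvoke(input, *args, **kwargs)
#
#       def __getattr__(self, name):  # everything else proxies to the inner LLM
#           return getattr(self._inner, name)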

View file

@ -0,0 +1,515 @@
"""Cost-based premium quota unit tests.
Covers the USD-micro behaviour added in migration 140:
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
calls in a turn; the sum is used as the debit amount when ``agent_config.is_premium``
is true, regardless of which underlying model produced each call. This
preserves the prior "premium turn → all calls in turn count" rule from the
token-based system.
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
clamps to a sane floor when pricing is unknown, and respects the
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
can't lock the whole balance on one call.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# TurnTokenAccumulator — premium-turn debit semantics
# ---------------------------------------------------------------------------
def test_total_cost_micros_sums_premium_and_free_calls():
"""A premium turn that also called a free sub-agent debits the union.
The plan deliberately preserved the existing "premium turn → all calls
count" behaviour because per-call premium filtering relied on
``LLMRouterService._premium_model_strings`` which only covers router-pool
deployments. ``total_cost_micros`` therefore must include free-model
calls (whose ``cost_micros`` is typically ``0``) as well as the premium
call's actual provider cost.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
# Premium model (e.g. claude-opus): non-zero cost.
acc.add(
model="anthropic/claude-3-5-sonnet",
prompt_tokens=1200,
completion_tokens=400,
total_tokens=1600,
cost_micros=12_345,
)
# Free sub-agent (e.g. title-gen on a free model): zero cost.
acc.add(
model="gpt-4o-mini",
prompt_tokens=120,
completion_tokens=20,
total_tokens=140,
cost_micros=0,
)
# A second premium-priced call within the same turn.
acc.add(
model="anthropic/claude-3-5-sonnet",
prompt_tokens=800,
completion_tokens=200,
total_tokens=1000,
cost_micros=7_500,
)
assert acc.total_cost_micros == 12_345 + 0 + 7_500
# Token totals stay correct so the FE display path still works.
assert acc.grand_total == 1600 + 140 + 1000
def test_total_cost_micros_zero_when_no_calls():
"""An empty accumulator must report zero cost (no division-by-zero, no None)."""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
assert acc.total_cost_micros == 0
assert acc.grand_total == 0
def test_per_message_summary_groups_cost_by_model():
"""``per_message_summary`` must accumulate ``cost_micros`` per model so the
SSE ``model_breakdown`` payload reports actual USD spend per provider.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.add(
model="claude-3-5-sonnet",
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_micros=4_000,
)
acc.add(
model="claude-3-5-sonnet",
prompt_tokens=200,
completion_tokens=100,
total_tokens=300,
cost_micros=8_000,
)
acc.add(
model="gpt-4o-mini",
prompt_tokens=50,
completion_tokens=10,
total_tokens=60,
cost_micros=200,
)
summary = acc.per_message_summary()
assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
assert summary["gpt-4o-mini"]["cost_micros"] == 200
def test_serialized_calls_includes_cost_micros():
"""``serialized_calls`` is what flows into the SSE ``call_details``
payload; cost_micros must be present on each entry so the FE message-info
dropdown can render per-call USD.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.add(
model="m",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost_micros=42,
)
serialized = acc.serialized_calls()
assert serialized == [
{
"model": "m",
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
"cost_micros": 42,
"call_kind": "chat",
}
]
# ---------------------------------------------------------------------------
# estimate_call_reserve_micros — sizing and clamping
# ---------------------------------------------------------------------------
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
"""If LiteLLM doesn't know the model, ``get_model_info`` raises and the
helper falls back to the 100-micro floor: small enough that a user with
$0.0001 left can still send a tiny request, but non-zero so we still gate
against an empty balance.
"""
import litellm
from app.services import token_quota_service
def _raise(_name):
raise KeyError("unknown")
monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="nonexistent-model",
quota_reserve_tokens=4000,
)
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
assert micros == 100
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
"""LiteLLM may *return* a model with both cost-per-token fields at 0
(pricing not yet registered). The helper must not multiply 0 by the token
count and end up reserving 0; it must clamp to the floor.
"""
import litellm
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
raising=False,
)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="some-pending-model",
quota_reserve_tokens=4000,
)
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
def test_reserve_scales_with_model_cost(monkeypatch):
"""Claude-Opus-priced model with 4000 reserve_tokens reserves
~$0.36 = 360_000 micros. Critically, this must NOT be clamped down to
some small artificial cap; that was the bug the plan called out.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 15e-6,
"output_cost_per_token": 75e-6,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="claude-3-opus",
quota_reserve_tokens=4000,
)
# 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
assert micros == 360_000
def test_reserve_clamps_to_max_ceiling(monkeypatch):
"""A misconfigured "$1000 / M" model with 4000 reserve_tokens would
nominally compute to $4 = 4_000_000 micros. The ceiling
``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
can't lock the user's whole balance on one call.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 1e-3,
"output_cost_per_token": 0,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="oops-misconfigured",
quota_reserve_tokens=4000,
)
assert micros == 1_000_000
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
"""Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
so anonymous-style configs still reserve the operator-tunable default.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 1e-6,
"output_cost_per_token": 1e-6,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
# 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
assert (
token_quota_service.estimate_call_reserve_micros(
base_model="cheap", quota_reserve_tokens=None
)
== 4000
)
assert (
token_quota_service.estimate_call_reserve_micros(
base_model="cheap", quota_reserve_tokens=0
)
== 4000
)
# ---------------------------------------------------------------------------
# TokenTrackingCallback — image vs chat usage shape
# ---------------------------------------------------------------------------
class _FakeImageUsage:
"""Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
def __init__(
self,
input_tokens: int = 0,
output_tokens: int = 0,
total_tokens: int | None = None,
) -> None:
self.input_tokens = input_tokens
self.output_tokens = output_tokens
if total_tokens is not None:
self.total_tokens = total_tokens
class _FakeImageResponse:
"""Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
``type(...).__name__`` probe routes to the image branch.
"""
def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
self.usage = usage
if response_cost is not None:
self._hidden_params = {"response_cost": response_cost}
# Re-tag the helper class as ``ImageResponse`` for the type-name probe in
# the callback. We can't simply name the class ``ImageResponse`` because
# the test runner sometimes imports test modules in surprising ways and
# we want to be explicit.
_FakeImageResponse.__name__ = "ImageResponse"
class _FakeChatUsage:
def __init__(self, prompt: int, completion: int):
self.prompt_tokens = prompt
self.completion_tokens = completion
self.total_tokens = prompt + completion
class _FakeChatResponse:
def __init__(self, usage: _FakeChatUsage):
self.usage = usage
@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
"""``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
``prompt_tokens``/``completion_tokens``, which is the chat shape.
"""
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
cb = TokenTrackingCallback()
response = _FakeImageResponse(
usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
response_cost=0.04, # $0.04 per image
)
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
call = acc.calls[0]
assert call.prompt_tokens == 42
assert call.completion_tokens == 8
assert call.total_tokens == 50
# 0.04 USD = 40_000 micros
assert call.cost_micros == 40_000
assert call.call_kind == "image_generation"
@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
"""Chat responses must still read prompt_tokens/completion_tokens."""
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
cb = TokenTrackingCallback()
response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={
"model": "openrouter/anthropic/claude-3-5-sonnet",
"response_cost": 0.0036,
},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
call = acc.calls[0]
assert call.prompt_tokens == 120
assert call.completion_tokens == 30
assert call.total_tokens == 150
assert call.cost_micros == 3_600
assert call.call_kind == "chat"
@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
"""When OpenRouter omits ``usage.cost`` LiteLLM's
``default_image_cost_calculator`` raises. The defensive image branch in
``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
chat-shaped and would raise too); it returns 0 with a WARNING log.
"""
import litellm
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
# Force completion_cost to raise the same way OpenRouter image-gen fails.
def _boom(*_args, **_kwargs):
raise ValueError("model_cost: missing entry for openrouter image model")
monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
# And make sure cost_per_token is NEVER called for the image path —
# if it were, our ``is_image=True`` branch is broken.
cost_per_token_calls: list = []
def _record_cost_per_token(**kwargs):
cost_per_token_calls.append(kwargs)
return (0.0, 0.0)
monkeypatch.setattr(
litellm, "cost_per_token", _record_cost_per_token, raising=False
)
cb = TokenTrackingCallback()
response = _FakeImageResponse(
usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
)
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
assert acc.calls[0].cost_micros == 0
assert acc.calls[0].call_kind == "image_generation"
# The image branch must short-circuit before cost_per_token.
assert cost_per_token_calls == []
# ---------------------------------------------------------------------------
# scoped_turn — ContextVar reset semantics (issue B)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
"""``scoped_turn`` must restore the previous ContextVar value on exit
so a per-call wrapper inside an outer chat turn doesn't leak its
accumulator outward (which would cause double-debit at chat-turn exit).
"""
from app.services.token_tracking_service import (
get_current_accumulator,
scoped_turn,
start_turn,
)
outer = start_turn()
assert get_current_accumulator() is outer
async with scoped_turn() as inner:
assert get_current_accumulator() is inner
assert inner is not outer
inner.add(
model="x",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost_micros=5,
)
# After exit the outer accumulator is restored unchanged.
assert get_current_accumulator() is outer
assert outer.total_cost_micros == 0
assert len(outer.calls) == 0
# The inner accumulator captured the call but didn't bleed into outer.
assert inner.total_cost_micros == 5
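# scoped_turn's assumed shape (a sketch; the real helper lives in
# token_tracking_service, and the second test below pokes the same ContextVar):
#
#   @contextlib.asynccontextmanager
#   async def scoped_turn():
#       acc = TurnTokenAccumulator()
#       token = _turn_accumulator.set(acc)
#       try:
#           yield acc
#       finally:
#           _turn_accumulator.reset(token)  # restores the outer value, or None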
@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
"""Running ``scoped_turn`` outside any chat turn (e.g. a background
indexing job) must leave the ContextVar at ``None`` on exit so the
next *unrelated* request starts clean.
"""
from app.services.token_tracking_service import (
_turn_accumulator,
get_current_accumulator,
scoped_turn,
)
# ContextVar default is None for a fresh, isolated test context. We
# simulate "no outer" explicitly to be robust against test order.
token = _turn_accumulator.set(None)
try:
assert get_current_accumulator() is None
async with scoped_turn() as acc:
assert get_current_accumulator() is acc
assert get_current_accumulator() is None
finally:
_turn_accumulator.reset(token)

View file

@ -0,0 +1,325 @@
"""Unit tests for podcast Celery task billing integration.
Validates ``_generate_content_podcast`` correctly wraps
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
search-space owner's billing decision, and degrades cleanly when the
resolver fails or premium credit is exhausted.
Coverage:
* Happy-path free config: resolver succeeds → ``billable_call`` enters with
``usage_type='podcast_generation'`` and the configured reserve override,
graph runs, podcast row flips to ``READY``.
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
carries ``reason='premium_quota_exhausted'``.
* Resolver failure: ``ValueError`` from the resolver → podcast row flips
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
"""
from __future__ import annotations
import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
def filter(self, *_args, **_kwargs):
return self
class _FakeSession:
def __init__(self, podcast):
self._podcast = podcast
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self._podcast)
async def commit(self):
self.commit_count += 1
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return None
class _FakeSessionMaker:
def __init__(self, session: _FakeSession):
self._session = session
def __call__(self):
return self._session
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
inside helpers keeps this fixture cheap."""
return SimpleNamespace(
id=podcast_id,
title="Test Podcast",
thread_id=thread_id,
status=None,
podcast_transcript=None,
file_location=None,
)
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
"""Stand-in for ``billable_call`` that records its kwargs and yields a
no-op accumulator-shaped object."""
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
_CALL_LOG: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
from app.services.billable_calls import QuotaInsufficientError
_CALL_LOG.append(kwargs)
raise QuotaInsufficientError(
usage_type=kwargs.get("usage_type", "?"),
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield SimpleNamespace() # pragma: no cover — for grammar only
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_call_log():
_CALL_LOG.clear()
yield
_CALL_LOG.clear()
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
"""Happy path: free billing tier still wraps the graph call so the
audit row is recorded. Verifies kwargs threading."""
from app.config import config as app_config
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=7, thread_id=99)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
assert search_space_id == 555
assert thread_id == 99
return user_id, "free", "openrouter/some-free-model"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {
"podcast_transcript": [
SimpleNamespace(speaker_id=0, dialog="Hi"),
SimpleNamespace(speaker_id=1, dialog="Hello"),
],
"final_podcast_file_path": "/tmp/podcast.wav",
}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=7,
source_content="hello world",
search_space_id=555,
user_prompt="make it short",
)
assert result["status"] == "ready"
assert result["podcast_id"] == 7
assert podcast.status == PodcastStatus.READY
assert podcast.file_location == "/tmp/podcast.wav"
assert len(_CALL_LOG) == 1
call = _CALL_LOG[0]
assert call["user_id"] == user_id
assert call["search_space_id"] == 555
assert call["billing_tier"] == "free"
assert call["base_model"] == "openrouter/some-free-model"
assert call["usage_type"] == "podcast_generation"
assert (
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
)
assert call["thread_id"] == 99
assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
"""Premium resolution flows through to ``billable_call`` so the
reserve/finalize path triggers."""
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast()
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return user_id, "premium", "gpt-5.4"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
await podcast_tasks._generate_content_podcast(
podcast_id=7,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert _CALL_LOG[0]["billing_tier"] == "premium"
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
@pytest.mark.asyncio
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
"""When ``billable_call`` denies the reservation, the graph never
runs and the podcast row flips to FAILED with the documented reason
code."""
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=8)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return uuid4(), "premium", "gpt-5.4"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=8,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert result == {
"status": "failed",
"podcast_id": 8,
"reason": "premium_quota_exhausted",
}
assert podcast.status == PodcastStatus.FAILED
assert graph_invoked == [] # Graph never ran on denied reservation.
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
"""If the resolver raises (e.g. search-space deleted), the task fails
cleanly without invoking the graph."""
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=9)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
raise ValueError("Search space 555 not found")
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=9,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert result == {
"status": "failed",
"podcast_id": 9,
"reason": "billing_resolution_failed",
}
assert podcast.status == PodcastStatus.FAILED
assert graph_invoked == []

View file

@ -0,0 +1,330 @@
"""Unit tests for video-presentation Celery task billing integration.
Mirrors ``test_podcast_billing.py`` for the video-presentation task.
Validates the same wrap-graph-in-billable_call pattern and ensures the
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
threaded through.
Coverage:
* Free config: graph runs, ``billable_call`` invoked with the video
reserve override.
* Premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: graph not invoked, row FAILED, reason code surfaced.
* Resolver failure: row FAILED with ``billing_resolution_failed``.
"""
from __future__ import annotations
import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
def filter(self, *_args, **_kwargs):
return self
class _FakeSession:
def __init__(self, video):
self._video = video
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self._video)
async def commit(self):
self.commit_count += 1
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return None
class _FakeSessionMaker:
def __init__(self, session: _FakeSession):
self._session = session
def __call__(self):
return self._session
def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
return SimpleNamespace(
id=video_id,
title="Test Presentation",
thread_id=thread_id,
status=None,
slides=None,
scene_codes=None,
)
_CALL_LOG: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
from app.services.billable_calls import QuotaInsufficientError
_CALL_LOG.append(kwargs)
raise QuotaInsufficientError(
usage_type=kwargs.get("usage_type", "?"),
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield SimpleNamespace() # pragma: no cover
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_call_log():
_CALL_LOG.clear()
yield
_CALL_LOG.clear()
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
from app.config import config as app_config
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=11, thread_id=99)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
assert search_space_id == 777
assert thread_id == 99
return user_id, "free", "openrouter/some-free-model"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=11,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result["status"] == "ready"
assert result["video_presentation_id"] == 11
assert video.status == VideoPresentationStatus.READY
assert len(_CALL_LOG) == 1
call = _CALL_LOG[0]
assert call["user_id"] == user_id
assert call["search_space_id"] == 777
assert call["billing_tier"] == "free"
assert call["base_model"] == "openrouter/some-free-model"
assert call["usage_type"] == "video_presentation_generation"
assert (
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
)
assert call["thread_id"] == 99
assert call["call_details"] == {
"video_presentation_id": 11,
"title": "Test Presentation",
}
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video()
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return user_id, "premium", "gpt-5.4"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
await video_presentation_tasks._generate_video_presentation(
video_presentation_id=11,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert _CALL_LOG[0]["billing_tier"] == "premium"
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
@pytest.mark.asyncio
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=12)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return uuid4(), "premium", "gpt-5.4"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(
video_presentation_tasks, "billable_call", _denying_billable_call
)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=12,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result == {
"status": "failed",
"video_presentation_id": 12,
"reason": "premium_quota_exhausted",
}
assert video.status == VideoPresentationStatus.FAILED
assert graph_invoked == []
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=13)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
raise ValueError("Search space 777 not found")
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_failing_resolver,
)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=13,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result == {
"status": "failed",
"video_presentation_id": 13,
"reason": "billing_resolution_failed",
}
assert video.status == VideoPresentationStatus.FAILED
assert graph_invoked == []

View file

@@ -127,7 +127,7 @@ const FAQ_ITEMS = [
{ {
question: "What happens after I use my free tokens?", question: "What happens after I use my free tokens?",
answer: answer:
"After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.", "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.",
}, },
{ {
question: "Is Claude AI available without login?", question: "Is Claude AI available without login?",
@@ -329,7 +329,7 @@ export default async function FreeHubPage() {
<section className="max-w-3xl mx-auto text-center"> <section className="max-w-3xl mx-auto text-center">
<h2 className="text-2xl font-bold mb-3">Want More Features?</h2> <h2 className="text-2xl font-bold mb-3">Want More Features?</h2>
<p className="text-muted-foreground mb-6 leading-relaxed"> <p className="text-muted-foreground mb-6 leading-relaxed">
Create a free SurfSense account to unlock 3 million tokens, document uploads with Create a free SurfSense account to unlock $5 of premium credit, document uploads with
citations, team collaboration, and integrations with Slack, Google Drive, Notion, and citations, team collaboration, and integrations with Slack, Google Drive, Notion, and
30+ more tools. 30+ more tools.
</p> </p>

View file

@@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav";
export const metadata: Metadata = { export const metadata: Metadata = {
title: "Pricing | SurfSense - Free AI Search Plans", title: "Pricing | SurfSense - Free AI Search Plans",
description: description:
"Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.", "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.",
alternates: { alternates: {
canonical: "https://surfsense.com/pricing", canonical: "https://surfsense.com/pricing",
}, },

View file

@@ -8,7 +8,7 @@ import { cn } from "@/lib/utils";
const TABS = [ const TABS = [
{ id: "pages", label: "Pages" }, { id: "pages", label: "Pages" },
{ id: "tokens", label: "Premium Tokens" }, { id: "tokens", label: "Premium Credit" },
] as const; ] as const;
type TabId = (typeof TABS)[number]["id"]; type TabId = (typeof TABS)[number]["id"];

View file

@@ -28,6 +28,12 @@ type UnifiedPurchase = {
kind: PurchaseKind; kind: PurchaseKind;
created_at: string; created_at: string;
status: PagePurchaseStatus; status: PagePurchaseStatus;
/**
* Granted units. Interpretation depends on ``kind``:
* - ``"pages"``: integer number of indexed pages.
* - ``"tokens"``: integer micro-USD of credit (1_000_000 = $1.00).
* The ``Granted`` column formats accordingly.
*/
granted: number; granted: number;
amount_total: number | null; amount_total: number | null;
currency: string | null; currency: string | null;
@@ -58,7 +64,7 @@ const KIND_META: Record<
iconClass: "text-sky-500", iconClass: "text-sky-500",
}, },
tokens: { tokens: {
label: "Premium Tokens", label: "Premium Credit",
icon: Coins, icon: Coins,
iconClass: "text-amber-500", iconClass: "text-amber-500",
}, },
@@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase {
kind: "tokens", kind: "tokens",
created_at: p.created_at, created_at: p.created_at,
status: p.status, status: p.status,
granted: p.tokens_granted, granted: p.credit_micros_granted,
amount_total: p.amount_total, amount_total: p.amount_total,
currency: p.currency, currency: p.currency,
}; };
} }
function formatGranted(p: UnifiedPurchase): string {
if (p.kind === "tokens") {
const dollars = p.granted / 1_000_000;
// Premium credit packs are always whole dollars at the moment, but
// future fractional grants (refunds, partial top-ups) shouldn't
// silently round to "$0".
if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`;
if (dollars > 0) return `$${dollars.toFixed(3)} of credit`;
return "$0 of credit";
}
return p.granted.toLocaleString();
}
export function PurchaseHistoryContent() { export function PurchaseHistoryContent() {
const results = useQueries({ const results = useQueries({
queries: [ queries: [
@@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {
<ReceiptText className="h-8 w-8 text-muted-foreground" /> <ReceiptText className="h-8 w-8 text-muted-foreground" />
<p className="text-sm font-medium">No purchases yet</p> <p className="text-sm font-medium">No purchases yet</p>
<p className="text-xs text-muted-foreground"> <p className="text-xs text-muted-foreground">
Your page and premium token purchases will appear here after checkout. Your page and premium credit purchases will appear here after checkout.
</p> </p>
</div> </div>
); );
@@ -177,7 +196,7 @@ export function PurchaseHistoryContent() {
</div> </div>
</TableCell> </TableCell>
<TableCell className="text-right tabular-nums text-sm"> <TableCell className="text-right tabular-nums text-sm">
{p.granted.toLocaleString()} {formatGranted(p)}
</TableCell> </TableCell>
<TableCell className="text-right tabular-nums text-sm"> <TableCell className="text-right tabular-nums text-sm">
{formatAmount(p.amount_total, p.currency)} {formatAmount(p.amount_total, p.currency)}

View file

@@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe();
export const currentUserAtom = atomWithQuery(() => { export const currentUserAtom = atomWithQuery(() => {
return { return {
queryKey: USER_QUERY_KEY, queryKey: USER_QUERY_KEY,
// Live-changing numeric fields (pages_*, premium_tokens_*) are now // Live-changing numeric fields (pages_*, premium_credit_micros_*)
// pushed via Zero (queries.user.me()), so /users/me only needs to // are now pushed via Zero (queries.user.me()), so /users/me only
// fire once per session for the static profile fields. // needs to fire once per session for the static profile fields.
staleTime: Infinity, staleTime: Infinity,
enabled: !!getBearerToken(), enabled: !!getBearerToken(),
queryFn: userQueryFn, queryFn: userQueryFn,

View file

@@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string {
}); });
} }
/**
* Format provider USD cost (in micro-USD) for inline display next to a
* token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent
* costs so a real-but-tiny figure doesn't render as ``$0.000``.
*/
function formatTurnCost(micros: number): string {
const dollars = micros / 1_000_000;
if (dollars >= 1) return `$${dollars.toFixed(2)}`;
if (dollars >= 0.001) return `$${dollars.toFixed(3)}`;
if (dollars > 0) return "<$0.001";
return "$0";
}
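
As a sanity check, worked outputs of formatTurnCost at each threshold (inputs are illustrative micro-USD values, not from this commit; note the second branch cuts over at $0.001 so that common fractions of a cent, like a $0.004 turn, render as a real figure rather than the "<$0.001" fallback):

formatTurnCost(2_340_000); // "$2.34"
formatTurnCost(12_345); // "$0.012"
formatTurnCost(5_000); // "$0.005"
formatTurnCost(500); // "<$0.001"
formatTurnCost(0); // "$0"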
const MessageInfoDropdown: FC = () => { const MessageInfoDropdown: FC = () => {
const messageId = useAuiState(({ message }) => message?.id); const messageId = useAuiState(({ message }) => message?.id);
const createdAt = useAuiState(({ message }) => message?.createdAt); const createdAt = useAuiState(({ message }) => message?.createdAt);
@@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => {
{models.length > 0 ? ( {models.length > 0 ? (
models.map(([model, counts]) => { models.map(([model, counts]) => {
const { name, icon } = resolveModel(model); const { name, icon } = resolveModel(model);
const costMicros = counts.cost_micros;
return ( return (
<ActionBarMorePrimitive.Item <ActionBarMorePrimitive.Item
key={model} key={model}
@@ -463,6 +477,9 @@ const MessageInfoDropdown: FC = () => {
</span> </span>
<span className="text-xs text-muted-foreground"> <span className="text-xs text-muted-foreground">
{counts.total_tokens.toLocaleString()} tokens {counts.total_tokens.toLocaleString()} tokens
{costMicros && costMicros > 0
? ` · ${formatTurnCost(costMicros)}`
: ""}
</span> </span>
</ActionBarMorePrimitive.Item> </ActionBarMorePrimitive.Item>
); );
@@ -474,6 +491,9 @@ const MessageInfoDropdown: FC = () => {
> >
<span className="text-xs text-muted-foreground"> <span className="text-xs text-muted-foreground">
{usage.total_tokens.toLocaleString()} tokens {usage.total_tokens.toLocaleString()} tokens
{usage.cost_micros && usage.cost_micros > 0
? ` · ${formatTurnCost(usage.cost_micros)}`
: ""}
</span> </span>
</ActionBarMorePrimitive.Item> </ActionBarMorePrimitive.Item>
)} )}

View file

@@ -13,13 +13,30 @@ export interface TokenUsageData {
prompt_tokens: number; prompt_tokens: number;
completion_tokens: number; completion_tokens: number;
total_tokens: number; total_tokens: number;
/**
* Total provider USD cost for this assistant turn, in micro-USD
* (1_000_000 = $1.00). Populated from LiteLLM's response_cost on
* the backend. Optional because pre-cost-credits messages persisted
* before the migration won't have it.
*/
cost_micros?: number;
usage?: Record< usage?: Record<
string, string,
{ prompt_tokens: number; completion_tokens: number; total_tokens: number } {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cost_micros?: number;
}
>; >;
model_breakdown?: Record< model_breakdown?: Record<
string, string,
{ prompt_tokens: number; completion_tokens: number; total_tokens: number } {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cost_micros?: number;
}
>; >;
} }

View file

@@ -40,7 +40,7 @@ export function QuotaWarningBanner({
</p> </p>
<p className="text-xs text-red-600 dark:text-red-300"> <p className="text-xs text-red-600 dark:text-red-300">
You&apos;ve used all {limit.toLocaleString()} free tokens. Create a free account to You&apos;ve used all {limit.toLocaleString()} free tokens. Create a free account to
get 3 million tokens and access to all models. get $5 of premium credit and access to all models.
</p> </p>
<Link <Link
href="/register" href="/register"
@@ -69,7 +69,7 @@ export function QuotaWarningBanner({
<Link href="/register" className="font-medium underline hover:no-underline"> <Link href="/register" className="font-medium underline hover:no-underline">
Create an account Create an account
</Link>{" "} </Link>{" "}
for 5M free tokens. for $5 of premium credit.
</p> </p>
<button <button
type="button" type="button"

View file

@@ -5,6 +5,14 @@ import { Progress } from "@/components/ui/progress";
import { useIsAnonymous } from "@/contexts/anonymous-mode"; import { useIsAnonymous } from "@/contexts/anonymous-mode";
import { queries } from "@/zero/queries"; import { queries } from "@/zero/queries";
/**
* Premium credit balance shown in the sidebar.
*
* Values come from Zero (live-replicated from Postgres) and are stored as
* integer micro-USD (1_000_000 == $1.00). We render in dollars because
* users top up at $1/pack and the credit gets debited at actual provider
* cost.
*/
export function PremiumTokenUsageDisplay() { export function PremiumTokenUsageDisplay() {
const isAnonymous = useIsAnonymous(); const isAnonymous = useIsAnonymous();
const [me] = useQuery(queries.user.me({})); const [me] = useQuery(queries.user.me({}));
@@ -12,21 +20,26 @@ export function PremiumTokenUsageDisplay() {
if (isAnonymous || !me) return null; if (isAnonymous || !me) return null;
const usagePercentage = Math.min( const usagePercentage = Math.min(
(me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100, (me.premiumCreditMicrosUsed / Math.max(me.premiumCreditMicrosLimit, 1)) * 100,
100 100
); );
const formatTokens = (n: number) => { const formatUsd = (micros: number) => {
if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(1)}M`; const dollars = micros / 1_000_000;
if (n >= 1_000) return `${(n / 1_000).toFixed(0)}K`; if (dollars >= 100) return `$${dollars.toFixed(0)}`;
return n.toLocaleString(); if (dollars >= 1) return `$${dollars.toFixed(2)}`;
// Sub-dollar balances need extra precision so the bar still tells the
// user what's left ("$0.04 of credit") instead of rounding to "$0".
if (dollars > 0) return `$${dollars.toFixed(3)}`;
return "$0";
}; };
return ( return (
<div className="space-y-1.5"> <div className="space-y-1.5">
<div className="flex justify-between items-center text-xs"> <div className="flex justify-between items-center text-xs">
<span className="text-muted-foreground"> <span className="text-muted-foreground">
{formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens {formatUsd(me.premiumCreditMicrosUsed)} / {formatUsd(me.premiumCreditMicrosLimit)} of
credit
</span> </span>
<span className="font-medium">{usagePercentage.toFixed(0)}%</span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span>
</div> </div>

View file

@@ -12,11 +12,11 @@ const demoPlans = [
price: "0", price: "0",
yearlyPrice: "0", yearlyPrice: "0",
period: "", period: "",
billingText: "500 pages + 3M premium tokens included", billingText: "500 pages + $5 of premium credit included",
features: [ features: [
"Self Hostable", "Self Hostable",
"500 pages included to start", "500 pages included to start",
"3 million premium tokens to start", "$5 of premium credit to start, billed at provider cost",
"Includes access to OpenAI text, audio and image models", "Includes access to OpenAI text, audio and image models",
"Realtime Collaborative Group Chats with teammates", "Realtime Collaborative Group Chats with teammates",
"Community support on Discord", "Community support on Discord",
@@ -35,7 +35,7 @@ const demoPlans = [
features: [ features: [
"Everything in Free", "Everything in Free",
"Buy 1,000-page packs at $1 each", "Buy 1,000-page packs at $1 each",
"Buy 1M premium token packs at $1 each", "Top up premium credit at $1 per $1 of credit, billed at provider cost",
"Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter", "Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter",
"Priority support on Discord", "Priority support on Discord",
], ],
@@ -129,27 +129,27 @@ const faqData: FAQSection[] = [
], ],
}, },
{ {
title: "Premium Tokens", title: "Premium Credit",
items: [ items: [
{ {
question: 'What are "premium tokens"?', question: 'What is "premium credit"?',
answer: answer:
"Premium tokens are the billing unit for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request consumes tokens based on the length of your conversation. Non-premium models (such as free-tier models available without login) do not consume premium tokens.", "Premium credit is your USD balance for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request debits the actual USD cost the provider charges, so cheap and expensive models bill proportionally. Non-premium models (such as the free-tier models available without login) don't touch your premium credit.",
}, },
{ {
question: "How many premium tokens do I get for free?", question: "How much premium credit do I get for free?",
answer: answer:
"Every registered SurfSense account starts with 3 million premium tokens at no cost. Anonymous users (no login) get 500,000 free tokens across all models. Once your free tokens are used up, you can purchase more at any time.", "Every registered SurfSense account starts with $5 of premium credit at no cost. Anonymous users (no login) get 500,000 free tokens across all free models. Once your free credit runs out, you can top up at any time.",
}, },
{ {
question: "How does purchasing premium tokens work?", question: "How does buying premium credit work?",
answer: answer:
"Just like pages, there's no subscription. You buy 1-million-token packs at $1 each whenever you need more. Purchased tokens are added to your account immediately. You can buy up to 100 packs at a time.", "Just like pages, there's no subscription. Top-ups buy $1 of credit for $1 — every cent you pay is spent at provider cost, no markup. Purchased credit is added to your account immediately. You can buy up to $100 at a time.",
}, },
{ {
question: "What happens if I run out of premium tokens?", question: "What happens if I run out of premium credit?",
answer: answer:
"When your premium token balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you purchase more tokens. You can always switch to non-premium models which don't consume premium tokens.", "When your premium credit balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you top up. You can always switch to non-premium models, which don't touch your premium credit.",
}, },
], ],
}, },
@@ -157,9 +157,9 @@ const faqData: FAQSection[] = [
title: "Self-Hosting", title: "Self-Hosting",
items: [ items: [
{ {
question: "Can I self-host SurfSense with unlimited pages and tokens?", question: "Can I self-host SurfSense with unlimited pages and credit?",
answer: answer:
"Yes! When self-hosting, you have full control over your page and token limits. The default self-hosted setup gives you effectively unlimited pages and tokens, so you can index as much data and use as many AI queries as your infrastructure supports.", "Yes! When self-hosting, you have full control over your page and premium-credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credit, so you can index as much data and use as many AI queries as your infrastructure supports.",
}, },
], ],
}, },
@@ -250,8 +250,8 @@ function PricingFAQ() {
Frequently Asked Questions Frequently Asked Questions
</h2> </h2>
<p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground"> <p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground">
Everything you need to know about SurfSense pages, premium tokens, and billing. Can&apos;t Everything you need to know about SurfSense pages, premium credit, and billing.
find what you need? Reach out at{" "} Can&apos;t find what you need? Reach out at{" "}
<a href="mailto:rohan@surfsense.com" className="text-blue-500 underline"> <a href="mailto:rohan@surfsense.com" className="text-blue-500 underline">
rohan@surfsense.com rohan@surfsense.com
</a> </a>
@@ -335,7 +335,7 @@ function PricingBasic() {
<Pricing <Pricing
plans={demoPlans} plans={demoPlans}
title="SurfSense Pricing" title="SurfSense Pricing"
description="Start free with 500 pages & 3M premium tokens. Pay as you go." description="Start free with 500 pages & $5 of premium credit. Pay as you go, billed at provider cost."
/> />
<PricingFAQ /> <PricingFAQ />
</> </>

View file

@@ -14,10 +14,23 @@ import { AppError } from "@/lib/error";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { queries } from "@/zero/queries"; import { queries } from "@/zero/queries";
const TOKEN_PACK_SIZE = 1_000_000; // One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the
// backend. Premium turns are debited at the actual provider cost
// reported by LiteLLM, so $1 of credit always buys $1 of provider
// usage at cost.
const CREDIT_PER_PACK_MICROS = 1_000_000;
const PRICE_PER_PACK_USD = 1; const PRICE_PER_PACK_USD = 1;
const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const; const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const;
const formatUsd = (micros: number, options?: { compact?: boolean }) => {
const dollars = micros / 1_000_000;
if (options?.compact && dollars >= 1) return `$${dollars.toFixed(2)}`;
if (dollars >= 100) return `$${dollars.toFixed(0)}`;
if (dollars >= 1) return `$${dollars.toFixed(2)}`;
if (dollars > 0) return `$${dollars.toFixed(3)}`;
return "$0";
};
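
The compact flag only changes behavior at $100 and above, where it keeps the cents that the default path drops. Worked outputs (hypothetical inputs, not from this commit):

formatUsd(150_000_000); // "$150"
formatUsd(150_000_000, { compact: true }); // "$150.00"
formatUsd(2_500_000); // "$2.50" (same with or without compact)
formatUsd(50_000); // "$0.050"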
export function BuyTokensContent() { export function BuyTokensContent() {
const params = useParams(); const params = useParams();
const searchSpaceId = Number(params?.search_space_id); const searchSpaceId = Number(params?.search_space_id);
@@ -29,7 +42,7 @@ export function BuyTokensContent() {
queryFn: () => stripeApiService.getTokenStatus(), queryFn: () => stripeApiService.getTokenStatus(),
}); });
// Live per-user usage via Zero. // Live per-user balance via Zero.
const [me] = useZeroQuery(queries.user.me({})); const [me] = useZeroQuery(queries.user.me({}));
const purchaseMutation = useMutation({ const purchaseMutation = useMutation({
@@ -46,44 +59,46 @@ export function BuyTokensContent() {
}, },
}); });
const totalTokens = quantity * TOKEN_PACK_SIZE; const totalCreditMicros = quantity * CREDIT_PER_PACK_MICROS;
const totalPrice = quantity * PRICE_PER_PACK_USD; const totalPrice = quantity * PRICE_PER_PACK_USD;
if (tokenStatus && !tokenStatus.token_buying_enabled) { if (tokenStatus && !tokenStatus.token_buying_enabled) {
return ( return (
<div className="w-full space-y-3 text-center"> <div className="w-full space-y-3 text-center">
<h2 className="text-xl font-bold tracking-tight">Buy Premium Tokens</h2> <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2>
<p className="text-sm text-muted-foreground"> <p className="text-sm text-muted-foreground">
Token purchases are temporarily unavailable. Credit purchases are temporarily unavailable.
</p> </p>
</div> </div>
); );
} }
const used = me?.premiumTokensUsed ?? 0; const used = me?.premiumCreditMicrosUsed ?? 0;
const limit = me?.premiumTokensLimit ?? 0; const limit = me?.premiumCreditMicrosLimit ?? 0;
// Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)). // Mirrors the backend formula in stripe_routes.py (max(0, limit - used)).
const remaining = Math.max(0, limit - used); const remaining = Math.max(0, limit - used);
const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0; const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0;
return ( return (
<div className="w-full space-y-5"> <div className="w-full space-y-5">
<div className="text-center"> <div className="text-center">
<h2 className="text-xl font-bold tracking-tight">Buy Premium Tokens</h2> <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2>
<p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p> <p className="mt-1 text-sm text-muted-foreground">
$1 buys $1 of credit, billed at provider cost
</p>
</div> </div>
{me && ( {me && (
<div className="rounded-lg border bg-muted/20 p-3 space-y-1.5"> <div className="rounded-lg border bg-muted/20 p-3 space-y-1.5">
<div className="flex justify-between items-center text-xs"> <div className="flex justify-between items-center text-xs">
<span className="text-muted-foreground"> <span className="text-muted-foreground">
{used.toLocaleString()} / {limit.toLocaleString()} premium tokens {formatUsd(used)} / {formatUsd(limit)} of credit
</span> </span>
<span className="font-medium">{usagePercentage.toFixed(0)}%</span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span>
</div> </div>
<Progress value={usagePercentage} className="h-1.5" /> <Progress value={usagePercentage} className="h-1.5" />
<p className="text-[11px] text-muted-foreground"> <p className="text-[11px] text-muted-foreground">
{remaining.toLocaleString()} tokens remaining {formatUsd(remaining)} of credit remaining
</p> </p>
</div> </div>
)} )}
@@ -99,7 +114,7 @@ export function BuyTokensContent() {
<Minus className="h-3.5 w-3.5" /> <Minus className="h-3.5 w-3.5" />
</button> </button>
<span className="min-w-32 text-center text-lg font-semibold tabular-nums"> <span className="min-w-32 text-center text-lg font-semibold tabular-nums">
{(totalTokens / 1_000_000).toFixed(0)}M tokens ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
</span> </span>
<button <button
type="button" type="button"
@@ -125,14 +140,14 @@ export function BuyTokensContent() {
: "border-border hover:border-purple-500/40 hover:bg-muted/40" : "border-border hover:border-purple-500/40 hover:bg-muted/40"
)} )}
> >
{m}M ${m}
</button> </button>
))} ))}
</div> </div>
<div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2"> <div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2">
<span className="text-sm font-medium tabular-nums"> <span className="text-sm font-medium tabular-nums">
{(totalTokens / 1_000_000).toFixed(0)}M premium tokens ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
</span> </span>
<span className="text-sm font-semibold tabular-nums">${totalPrice}</span> <span className="text-sm font-semibold tabular-nums">${totalPrice}</span>
</div> </div>
@@ -149,7 +164,7 @@ export function BuyTokensContent() {
</> </>
) : ( ) : (
<> <>
Buy {(totalTokens / 1_000_000).toFixed(0)}M Tokens for ${totalPrice} Buy ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for ${totalPrice}
</> </>
)} )}
</Button> </Button>

View file

@@ -190,7 +190,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
? "model" ? "model"
: "models"} : "models"}
</span>{" "} </span>{" "}
available from your administrator. available from your administrator.{" "}
{(() => {
const nonAuto = globalConfigs.filter(
(g) => !("is_auto_mode" in g && g.is_auto_mode)
);
const premium = nonAuto.filter(
(g) =>
"billing_tier" in g &&
(g as { billing_tier?: string }).billing_tier === "premium"
).length;
const free = nonAuto.length - premium;
if (premium > 0 && free > 0) {
return `${premium} premium, ${free} free.`;
}
if (premium > 0) {
return `All ${premium} premium — debits your shared credit pool.`;
}
return `All ${free} free.`;
})()}
</p> </p>
</AlertDescription> </AlertDescription>
</Alert> </Alert>

View file

@@ -371,6 +371,17 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
</SelectLabel> </SelectLabel>
{roleGlobalConfigs.map((config) => { {roleGlobalConfigs.map((config) => {
const isAuto = "is_auto_mode" in config && config.is_auto_mode; const isAuto = "is_auto_mode" in config && config.is_auto_mode;
// Read billing_tier from the global config; default to "free"
// for legacy YAMLs / Auto stub. Premium gets a purple badge,
// free gets an emerald one — same palette as the chat
// model selector so the meaning is consistent across
// surfaces (issues E, H).
const billingTier =
("billing_tier" in config &&
typeof config.billing_tier === "string" &&
config.billing_tier) ||
"free";
const isPremium = billingTier === "premium";
return ( return (
<SelectItem <SelectItem
key={config.id} key={config.id}
@@ -382,13 +393,27 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
<span className="truncate text-xs md:text-sm"> <span className="truncate text-xs md:text-sm">
{config.name} {config.name}
</span> </span>
{isAuto && ( {isAuto ? (
<Badge <Badge
variant="secondary" variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 [[data-slot=select-trigger]_&]:hidden" className="text-[8px] md:text-[9px] shrink-0 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 [[data-slot=select-trigger]_&]:hidden"
> >
Recommended Recommended
</Badge> </Badge>
) : isPremium ? (
<Badge
variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0 [[data-slot=select-trigger]_&]:hidden"
>
Premium
</Badge>
) : (
<Badge
variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0 [[data-slot=select-trigger]_&]:hidden"
>
Free
</Badge>
)} )}
</div> </div>
</SelectItem> </SelectItem>

View file

@@ -191,7 +191,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
? "model" ? "model"
: "models"} : "models"}
</span>{" "} </span>{" "}
available from your administrator. available from your administrator.{" "}
{(() => {
const nonAuto = globalConfigs.filter(
(g) => !("is_auto_mode" in g && g.is_auto_mode)
);
const premium = nonAuto.filter(
(g) =>
"billing_tier" in g &&
(g as { billing_tier?: string }).billing_tier === "premium"
).length;
const free = nonAuto.length - premium;
if (premium > 0 && free > 0) {
return `${premium} premium, ${free} free.`;
}
if (premium > 0) {
return `All ${premium} premium — debits your shared credit pool.`;
}
return `All ${free} free.`;
})()}
</p> </p>
</AlertDescription> </AlertDescription>
</Alert> </Alert>
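
The tier-counting IIFE above is duplicated verbatim from ImageModelManager. A shared helper would keep the two copies from drifting; a minimal sketch follows (the helper name and the GlobalConfigLike shape are hypothetical, not part of this commit):

type GlobalConfigLike = { is_auto_mode?: boolean; billing_tier?: string };

// Same logic as the inline IIFE in ImageModelManager / VisionModelManager:
// count non-Auto configs and summarize the premium/free split.
export function summarizeBillingTiers(globalConfigs: GlobalConfigLike[]): string {
  const nonAuto = globalConfigs.filter((g) => !g.is_auto_mode);
  const premium = nonAuto.filter((g) => g.billing_tier === "premium").length;
  const free = nonAuto.length - premium;
  if (premium > 0 && free > 0) return `${premium} premium, ${free} free.`;
  if (premium > 0) return `All ${premium} premium — debits your shared credit pool.`;
  return `All ${free} free.`;
}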

View file

@@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
<DialogHeader> <DialogHeader>
<DialogTitle>Create a free account to {feature}</DialogTitle> <DialogTitle>Create a free account to {feature}</DialogTitle>
<DialogDescription> <DialogDescription>
Get 3 million tokens, save chat history, upload documents, use all AI tools, and Get $5 of premium credit, save chat history, upload documents, use all AI tools,
connect 30+ integrations. and connect 30+ integrations.
</DialogDescription> </DialogDescription>
</DialogHeader> </DialogHeader>
<DialogFooter className="flex flex-col gap-2 sm:flex-row"> <DialogFooter className="flex flex-col gap-2 sm:flex-row">

View file

@@ -258,6 +258,8 @@ export const globalImageGenConfig = z.object({
litellm_params: z.record(z.string(), z.any()).nullable().optional(), litellm_params: z.record(z.string(), z.any()).nullable().optional(),
is_global: z.literal(true), is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false), is_auto_mode: z.boolean().optional().default(false),
billing_tier: z.string().default("free"),
quota_reserve_micros: z.number().nullable().optional(),
}); });
export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig); export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
@@ -338,6 +340,10 @@ export const globalVisionLLMConfig = z.object({
litellm_params: z.record(z.string(), z.any()).nullable().optional(), litellm_params: z.record(z.string(), z.any()).nullable().optional(),
is_global: z.literal(true), is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false), is_auto_mode: z.boolean().optional().default(false),
billing_tier: z.string().default("free"),
quota_reserve_tokens: z.number().nullable().optional(),
input_cost_per_token: z.number().nullable().optional(),
output_cost_per_token: z.number().nullable().optional(),
}); });
export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig); export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig);

View file

@@ -32,7 +32,7 @@ export const getPagePurchasesResponse = z.object({
purchases: z.array(pagePurchase), purchases: z.array(pagePurchase),
}); });
// Premium token purchases // Premium credit purchases
export const createTokenCheckoutSessionRequest = z.object({ export const createTokenCheckoutSessionRequest = z.object({
quantity: z.number().int().min(1).max(100), quantity: z.number().int().min(1).max(100),
search_space_id: z.number().int().min(1), search_space_id: z.number().int().min(1),
@@ -42,11 +42,16 @@ export const createTokenCheckoutSessionResponse = z.object({
checkout_url: z.string(), checkout_url: z.string(),
}); });
// Premium credit balance + purchase records.
//
// The unit is integer micro-USD (1_000_000 == $1.00). The schema names
// kept the ``Token`` prefix for API back-compat with pinned clients;
// the field names below are authoritative.
export const tokenStripeStatusResponse = z.object({ export const tokenStripeStatusResponse = z.object({
token_buying_enabled: z.boolean(), token_buying_enabled: z.boolean(),
premium_tokens_used: z.number().default(0), premium_credit_micros_used: z.number().default(0),
premium_tokens_limit: z.number().default(0), premium_credit_micros_limit: z.number().default(0),
premium_tokens_remaining: z.number().default(0), premium_credit_micros_remaining: z.number().default(0),
}); });
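
To make the micro-USD convention concrete, here is a hypothetical status payload run through the schema above (values invented for illustration):

const status = tokenStripeStatusResponse.parse({
  token_buying_enabled: true,
  premium_credit_micros_used: 1_250_000, // $1.25 debited so far
  premium_credit_micros_limit: 5_000_000, // $5.00 granted
  premium_credit_micros_remaining: 3_750_000, // max(0, limit - used)
});
const remainingUsd = status.premium_credit_micros_remaining / 1_000_000; // 3.75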
export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum; export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum;
@@ -56,7 +61,7 @@ export const tokenPurchase = z.object({
stripe_checkout_session_id: z.string(), stripe_checkout_session_id: z.string(),
stripe_payment_intent_id: z.string().nullable(), stripe_payment_intent_id: z.string().nullable(),
quantity: z.number(), quantity: z.number(),
tokens_granted: z.number(), credit_micros_granted: z.number(),
amount_total: z.number().nullable(), amount_total: z.number().nullable(),
currency: z.string().nullable(), currency: z.string().nullable(),
status: tokenPurchaseStatusEnum, status: tokenPurchaseStatusEnum,

View file

@@ -41,7 +41,7 @@ export interface RawChatErrorInput {
} }
export const PREMIUM_QUOTA_ASSISTANT_MESSAGE = export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
"I can't continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; "I can't continue with the current premium model because your premium credit is exhausted. Switch to a free model or top up your credit to continue.";
function getErrorMessage(error: unknown): string { function getErrorMessage(error: unknown): string {
if (error instanceof Error) return error.message; if (error instanceof Error) return error.message;

View file

@@ -541,16 +541,23 @@ export type SSEEvent =
data: { data: {
usage: Record< usage: Record<
string, string,
{ prompt_tokens: number; completion_tokens: number; total_tokens: number } {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cost_micros?: number;
}
>; >;
prompt_tokens: number; prompt_tokens: number;
completion_tokens: number; completion_tokens: number;
total_tokens: number; total_tokens: number;
cost_micros?: number;
call_details: Array<{ call_details: Array<{
model: string; model: string;
prompt_tokens: number; prompt_tokens: number;
completion_tokens: number; completion_tokens: number;
total_tokens: number; total_tokens: number;
cost_micros?: number;
}>; }>;
}; };
} }
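
A sketch of a consumer for the widened event payload (the type and function names here are hypothetical; the optional-field handling mirrors the MessageInfoDropdown code above):

type TurnUsage = { total_tokens: number; cost_micros?: number };

// cost_micros is absent on turns persisted before the cost-credits
// migration, so treat "missing" and "zero" the same way.
function describeUsage(u: TurnUsage): string {
  const tokens = `${u.total_tokens.toLocaleString()} tokens`;
  return u.cost_micros && u.cost_micros > 0
    ? `${tokens} · $${(u.cost_micros / 1_000_000).toFixed(3)}`
    : tokens;
}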

View file

@@ -30,9 +30,20 @@ export interface TokenUsageSummary {
prompt_tokens: number; prompt_tokens: number;
completion_tokens: number; completion_tokens: number;
total_tokens: number; total_tokens: number;
/**
* Total provider USD cost for this assistant turn, in micro-USD
* (1_000_000 = $1.00). Optional because rows persisted before the
* cost-credits migration won't have it.
*/
cost_micros?: number;
model_breakdown?: Record< model_breakdown?: Record<
string, string,
{ prompt_tokens: number; completion_tokens: number; total_tokens: number } {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cost_micros?: number;
}
> | null; > | null;
} }

View file

@@ -1,11 +1,20 @@
import { number, string, table } from "@rocicorp/zero"; import { number, string, table } from "@rocicorp/zero";
/**
* Live-meter slice of the ``user`` table replicated through Zero.
*
* ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored
* as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M
* when displaying. Sensitive fields (email, hashed_password, oauth, etc.)
* are intentionally omitted via the Postgres column-list publication so
* they never enter WAL replication.
*/
export const userTable = table("user") export const userTable = table("user")
.columns({ .columns({
id: string(), id: string(),
pagesLimit: number().from("pages_limit"), pagesLimit: number().from("pages_limit"),
pagesUsed: number().from("pages_used"), pagesUsed: number().from("pages_used"),
premiumTokensLimit: number().from("premium_tokens_limit"), premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"),
premiumTokensUsed: number().from("premium_tokens_used"), premiumCreditMicrosUsed: number().from("premium_credit_micros_used"),
}) })
.primaryKey("id"); .primaryKey("id");