mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 13:52:40 +02:00
feat: unified credits and its cost calculations
This commit is contained in:
parent
451a98936e
commit
ae9d36d77f
61 changed files with 5835 additions and 272 deletions
|
|
@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
|
|||
# Set FALSE to disable new checkout session creation temporarily
|
||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||
|
||||
# Premium token purchases via Stripe (for premium-tier model usage)
|
||||
# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
|
||||
# Premium credit purchases via Stripe (for premium-tier model usage).
|
||||
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
|
||||
# per-call provider cost reported by LiteLLM.
|
||||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
STRIPE_TOKENS_PER_UNIT=1000000
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
|
||||
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||
|
||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
|
|
@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
|||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||
PAGES_LIMIT=500
|
||||
|
||||
# Premium token quota per registered user (default: 3,000,000)
|
||||
# Applies only to models with billing_tier=premium in global_llm_config.yaml
|
||||
PREMIUM_TOKEN_LIMIT=3000000
|
||||
# Premium credit quota per registered user, in micro-USD
|
||||
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
|
||||
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
|
||||
# models bill proportionally. Applies only to models with
|
||||
# billing_tier=premium in global_llm_config.yaml.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
|
||||
# PREMIUM_TOKEN_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||
# stream_new_chat estimates an upper-bound cost from the model's
|
||||
# litellm-published per-token rates × the config's quota_reserve_tokens
|
||||
# and clamps to this value so a misconfigured model can't lock the
|
||||
# user's whole balance on one call. Default $1.00.
|
||||
QUOTA_MAX_RESERVE_MICROS=1000000
|
||||
|
||||
# Per-image reservation (in micro-USD) for the POST /image-generations
|
||||
# endpoint. Bypassed for free configs. Default $0.05.
|
||||
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||
|
||||
# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
|
||||
# Single envelope covers one transcript-generation LLM call. Default $0.20.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||
|
||||
# Per-video-presentation reservation (in micro-USD) used by the video
|
||||
# presentation Celery task. Covers worst-case fan-out of N slide-scene
|
||||
# generations + refines. Default $1.00. NOTE: tasks using the override
|
||||
# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||
|
||||
# No-login (anonymous) mode — allows public users to chat without an account
|
||||
# Set TRUE to enable /free pages and anonymous chat API
|
||||
|
|
|
|||
|
|
@ -0,0 +1,291 @@
|
|||
"""rename premium token columns to credit micros and add cost_micros to token_usage
|
||||
|
||||
Migrates the premium quota system from a flat token counter to a USD-cost
|
||||
based credit system, where 1 credit = 1 micro-USD ($0.000001).
|
||||
|
||||
Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
|
||||
price means every existing value is already correct in the new unit, no
|
||||
data transformation needed):
|
||||
|
||||
user.premium_tokens_limit -> premium_credit_micros_limit
|
||||
user.premium_tokens_used -> premium_credit_micros_used
|
||||
user.premium_tokens_reserved -> premium_credit_micros_reserved
|
||||
|
||||
premium_token_purchases.tokens_granted -> credit_micros_granted
|
||||
|
||||
New column for cost auditing per turn:
|
||||
|
||||
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
|
||||
|
||||
The "user" table is in zero_publication's column list (added in 139), so
|
||||
this migration must drop and recreate the publication with the renamed
|
||||
column names, otherwise zero-cache will replicate stale column names and
|
||||
the FE Zero schema will fail to bind.
|
||||
|
||||
IMPORTANT - before AND after running this migration:
|
||||
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||
2. Run: alembic upgrade head
|
||||
3. Delete / reset the zero-cache data volume
|
||||
4. Restart zero-cache (it will do a fresh initial sync)
|
||||
|
||||
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
|
||||
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
|
||||
column-not-found errors from a stale catalog snapshot.
|
||||
|
||||
Revision ID: 140
|
||||
Revises: 139
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "140"
|
||||
down_revision: str | None = "139"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Name of the logical-replication publication consumed by zero-cache.
# Dropped and recreated by upgrade()/downgrade() because Postgres refuses
# to rename columns referenced by a column-list publication.
PUBLICATION_NAME = "zero_publication"

# Replicates 139's document column list verbatim — must stay in sync.
# Order matters: it is joined as-is into the CREATE PUBLICATION DDL.
DOCUMENT_COLS = [
    "id",
    "title",
    "document_type",
    "search_space_id",
    "folder_id",
    "created_by_id",
    "status",
    "created_at",
    "updated_at",
]

# Same five live-meter fields as 139, with the renamed column names.
# upgrade() publishes this list; downgrade() rebuilds the legacy list inline.
USER_COLS = [
    "id",
    "pages_limit",
    "pages_used",
    "premium_credit_micros_limit",
    "premium_credit_micros_used",
]
|
||||
|
||||
|
||||
def _terminate_blocked_pids(conn, table: str) -> None:
    """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
    # Any other backend holding a lock on the relation is terminated
    # outright — the migration runbook requires exclusive access and a
    # queued DDL would otherwise deadlock behind zero-cache's replication
    # slot locks.
    kill_stmt = sa.text(
        "SELECT pg_terminate_backend(l.pid) "
        "FROM pg_locks l "
        "JOIN pg_class c ON c.oid = l.relation "
        "WHERE c.relname = :tbl "
        " AND l.pid != pg_backend_pid()"
    )
    conn.execute(kill_stmt, {"tbl": table})
|
||||
|
||||
|
||||
def _has_zero_version(conn, table: str) -> bool:
    """Return True when *table* carries zero-cache's ``_0_version`` column."""
    probe = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = '_0_version'"
    )
    row = conn.execute(probe, {"tbl": table}).fetchone()
    return row is not None
|
||||
|
||||
|
||||
def _column_exists(conn, table: str, column: str) -> bool:
    """Return True when *column* exists on *table* (information_schema probe)."""
    probe = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = :col"
    )
    row = conn.execute(probe, {"tbl": table, "col": column}).fetchone()
    return row is not None
|
||||
|
||||
|
||||
def _build_publication_ddl(
    user_cols: list[str],
    *,
    documents_has_zero_ver: bool,
    user_has_zero_ver: bool,
) -> str:
    """Render the CREATE PUBLICATION statement for zero-cache.

    ``user_cols`` is a parameter so upgrade/downgrade can emit either the
    renamed or the legacy "user" column list. The ``"_0_version"`` metadata
    column is appended only when the corresponding table actually carries
    it (checked by the caller via ``_has_zero_version``).
    """
    doc_cols = list(DOCUMENT_COLS)
    if documents_has_zero_ver:
        doc_cols.append('"_0_version"')

    published_user_cols = list(user_cols)
    if user_has_zero_ver:
        published_user_cols.append('"_0_version"')

    doc_col_list = ", ".join(doc_cols)
    user_col_list = ", ".join(published_user_cols)
    return (
        f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
        f"notifications, "
        f"documents ({doc_col_list}), "
        f"folders, "
        f"search_source_connectors, "
        f"new_chat_messages, "
        f"chat_comments, "
        f"chat_session_state, "
        f'"user" ({user_col_list})'
    )
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the token -> credit-micros rename and add ``cost_micros``.

    See the module docstring for the mandatory zero-cache stop/reset
    runbook — zero-cache must NOT be running while this executes.
    """
    conn = op.get_bind()

    # ------------------------------------------------------------------
    # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
    # dev environments are safe.
    # ------------------------------------------------------------------
    if not _column_exists(conn, "token_usage", "cost_micros"):
        op.add_column(
            "token_usage",
            sa.Column(
                "cost_micros",
                sa.BigInteger(),
                nullable=False,
                # server_default so existing rows backfill to 0 without a
                # separate UPDATE pass.
                server_default="0",
            ),
        )

    # ------------------------------------------------------------------
    # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
    # ------------------------------------------------------------------
    # Guard on both old-present AND new-absent so a partially applied run
    # can be retried safely.
    if _column_exists(
        conn, "premium_token_purchases", "tokens_granted"
    ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
        op.alter_column(
            "premium_token_purchases",
            "tokens_granted",
            new_column_name="credit_micros_granted",
        )

    # ------------------------------------------------------------------
    # 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
    #
    # We must drop the publication first (it references the old column
    # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
    # in a transaction block; alembic's outer transaction already holds
    # one, but a SAVEPOINT keeps the LOCK + DDL atomic.
    # ------------------------------------------------------------------
    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
    with tx:
        # Fail fast instead of queueing forever behind a stubborn lock holder.
        conn.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(conn, "user")
        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
        # publications require at least the PK to be in the column list,
        # which is true for both the old and new shape.
        conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        # Drop the publication BEFORE renaming columns, otherwise Postgres
        # rejects the rename: "cannot drop column ... referenced by
        # publication".
        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        # Same old-present / new-absent idempotency guard as step 2.
        for old, new in (
            ("premium_tokens_limit", "premium_credit_micros_limit"),
            ("premium_tokens_used", "premium_credit_micros_used"),
            ("premium_tokens_reserved", "premium_credit_micros_reserved"),
        ):
            if _column_exists(conn, "user", old) and not _column_exists(
                conn, "user", new
            ):
                op.alter_column("user", old, new_column_name=new)

        # Update the server_default on the renamed limit column so newly
        # inserted users get $5 of credit (== 5_000_000 micros) by
        # default. Existing rows are unaffected.
        op.alter_column(
            "user",
            "premium_credit_micros_limit",
            server_default="5000000",
        )

        # Recreate the publication with the new column names.
        documents_has_zero_ver = _has_zero_version(conn, "documents")
        user_has_zero_ver = _has_zero_version(conn, "user")
        conn.execute(
            sa.text(
                _build_publication_ddl(
                    USER_COLS,
                    documents_has_zero_ver=documents_has_zero_ver,
                    user_has_zero_ver=user_has_zero_ver,
                )
            )
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the rename and drop ``cost_micros``.

    Mirrors ``upgrade``: drop the publication, rename columns back, drop
    the new column, recreate the publication with the old column list.
    Same zero-cache stop/reset runbook applies in reverse.
    """
    conn = op.get_bind()

    # SAVEPOINT keeps the LOCK + DDL atomic (see upgrade() for rationale).
    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
    with tx:
        conn.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(conn, "user")
        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Publication must go first — Postgres rejects renaming columns a
        # column-list publication still references.
        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        # new-present / old-absent guard so partial re-runs are safe.
        for new, old in (
            ("premium_credit_micros_limit", "premium_tokens_limit"),
            ("premium_credit_micros_used", "premium_tokens_used"),
            ("premium_credit_micros_reserved", "premium_tokens_reserved"),
        ):
            if _column_exists(conn, "user", new) and not _column_exists(
                conn, "user", old
            ):
                op.alter_column("user", new, new_column_name=old)

        # Restore the PRE-140 server_default. Before this migration the
        # token-denominated limit defaulted to PREMIUM_TOKEN_LIMIT's
        # documented default of 3,000,000 tokens; upgrade() moved the
        # default to 5_000_000 micros, so a faithful revert must restore
        # the original 3,000,000 — not re-apply the new value.
        op.alter_column(
            "user",
            "premium_tokens_limit",
            server_default="3000000",
        )

        # Legacy (pre-140) "user" column list for the publication.
        legacy_user_cols = [
            "id",
            "pages_limit",
            "pages_used",
            "premium_tokens_limit",
            "premium_tokens_used",
        ]
        documents_has_zero_ver = _has_zero_version(conn, "documents")
        user_has_zero_ver = _has_zero_version(conn, "user")
        conn.execute(
            sa.text(
                _build_publication_ddl(
                    legacy_user_cols,
                    documents_has_zero_ver=documents_has_zero_ver,
                    user_has_zero_ver=user_has_zero_ver,
                )
            )
        )

    # The remaining reverts don't touch publication-referenced columns, so
    # they run outside the lock envelope — mirroring upgrade()'s steps 1-2.
    if _column_exists(
        conn, "premium_token_purchases", "credit_micros_granted"
    ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
        op.alter_column(
            "premium_token_purchases",
            "credit_micros_granted",
            new_column_name="tokens_granted",
        )

    if _column_exists(conn, "token_usage", "cost_micros"):
        op.drop_column("token_usage", "cost_micros")
|
||||
|
|
@ -31,6 +31,7 @@ from app.config import (
|
|||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
|
|
@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
|
|||
await setup_checkpointer_tables()
|
||||
initialize_openrouter_integration()
|
||||
_start_openrouter_background_refresh()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
|
|
|||
|
|
@ -22,10 +22,12 @@ def init_worker(**kwargs):
|
|||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
|
|
|||
|
|
@ -138,7 +138,11 @@ def load_global_image_gen_configs():
|
|||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_image_generation_configs", [])
|
||||
configs = data.get("global_image_generation_configs", []) or []
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global image generation configs: {e}")
|
||||
return []
|
||||
|
|
@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
|
|||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_vision_llm_configs", [])
|
||||
configs = data.get("global_vision_llm_configs", []) or []
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
|
@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
|
|||
"anonymous_enabled_free", settings["anonymous_enabled"]
|
||||
)
|
||||
|
||||
# Image generation + vision LLM emission are opt-in (issue L).
|
||||
# OpenRouter's catalogue contains hundreds of image / vision
|
||||
# capable models; auto-injecting all of them into every
|
||||
# deployment would explode the model selector and surprise
|
||||
# operators upgrading from prior versions. Default to False so
|
||||
# admins must explicitly turn them on.
|
||||
settings.setdefault("image_generation_enabled", False)
|
||||
settings.setdefault("vision_enabled", False)
|
||||
|
||||
return settings
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||
|
|
@ -296,10 +313,60 @@ def initialize_openrouter_integration():
|
|||
)
|
||||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
||||
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||
# Both reuse the catalogue already cached by ``service.initialize``
|
||||
# so we don't make additional network calls here.
|
||||
if settings.get("image_generation_enabled"):
|
||||
try:
|
||||
image_configs = service.get_image_generation_configs()
|
||||
if image_configs:
|
||||
config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(image_configs)} "
|
||||
f"image-generation models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
try:
|
||||
vision_configs = service.get_vision_llm_configs()
|
||||
if vision_configs:
|
||||
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||
f"vision LLM models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
||||
|
||||
def initialize_pricing_registration():
|
||||
"""
|
||||
Teach LiteLLM the per-token cost of every deployment in
|
||||
``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
|
||||
from the OpenRouter catalogue + any operator-declared YAML pricing).
|
||||
|
||||
Must run AFTER ``initialize_openrouter_integration()`` so the
|
||||
OpenRouter catalogue is populated and BEFORE the first LLM call so
|
||||
``response_cost`` is available in ``TokenTrackingCallback``.
|
||||
|
||||
Failures are logged but never raised — startup must not be blocked
|
||||
by a missing pricing entry; the worst-case is the model debits 0.
|
||||
"""
|
||||
try:
|
||||
from app.services.pricing_registration import (
|
||||
register_pricing_from_global_configs,
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to register LiteLLM pricing: {e}")
|
||||
|
||||
|
||||
def initialize_llm_router():
|
||||
"""
|
||||
Initialize the LLM Router service for Auto mode.
|
||||
|
|
@ -444,14 +511,54 @@ class Config:
|
|||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||
)
|
||||
|
||||
# Premium token quota settings
|
||||
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
|
||||
# Premium credit (micro-USD) quota settings.
|
||||
#
|
||||
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
|
||||
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
|
||||
# still honoured for one release as fall-back values — the prior
|
||||
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
|
||||
# to micros, so operators upgrading without changing their .env still
|
||||
# get correct behaviour. A startup deprecation warning fires below if
|
||||
# they're set.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT = int(
|
||||
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
|
||||
)
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT = int(
|
||||
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
|
||||
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
|
||||
)
|
||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
|
||||
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
|
||||
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
|
||||
# config's ``quota_reserve_tokens`` and clamps the result to this value
|
||||
# so a misconfigured "$1000/M" model can't lock the user's whole balance
|
||||
# on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
|
||||
# reserve_tokens ≈ $0.36) with headroom.
|
||||
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
|
||||
|
||||
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT"
|
||||
):
|
||||
print(
|
||||
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
|
||||
"current Stripe price). The old key will be removed in a "
|
||||
"future release."
|
||||
)
|
||||
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
|
||||
"STRIPE_CREDIT_MICROS_PER_UNIT"
|
||||
):
|
||||
print(
|
||||
"Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
|
||||
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
|
||||
"The old key will be removed in a future release."
|
||||
)
|
||||
|
||||
# Anonymous / no-login mode settings
|
||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
|
||||
|
|
@ -464,6 +571,35 @@ class Config:
|
|||
# Default quota reserve tokens when not specified per-model
|
||||
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
|
||||
|
||||
# Per-image reservation (in micro-USD) used by ``billable_call`` for the
|
||||
# ``POST /image-generations`` endpoint when the global config does not
|
||||
# override it. $0.05 covers realistic worst-cases for current OpenAI /
|
||||
# OpenRouter image-gen pricing. Bypassed entirely for free configs.
|
||||
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
|
||||
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
|
||||
)
|
||||
|
||||
# Per-podcast reservation (in micro-USD). One agent LLM call generating
|
||||
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
|
||||
# premium-model run. Tune via env.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
|
||||
os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
|
||||
)
|
||||
|
||||
# Per-video-presentation reservation (in micro-USD). Fan-out of N
|
||||
# slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
|
||||
# plus refine retries; can produce many premium completions. $1.00
|
||||
# covers worst-case. Tune via env.
|
||||
#
|
||||
# NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
|
||||
# 1_000_000. The override path in ``billable_call`` bypasses the
|
||||
# per-call clamp in ``estimate_call_reserve_micros``, so this is the
|
||||
# *actual* hold — raising it via env is fine but means a single video
|
||||
# task can lock $1+ of credit.
|
||||
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
|
||||
os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
|
||||
)
|
||||
|
||||
# Abuse prevention: concurrent stream cap and CAPTCHA
|
||||
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
|
||||
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,24 @@
|
|||
# Structure matches NewLLMConfig:
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
#
|
||||
# COST-BASED PREMIUM CREDITS:
|
||||
# Each premium config bills the user's USD-credit balance based on the
|
||||
# actual provider cost reported by LiteLLM. For models LiteLLM already
|
||||
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
|
||||
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
|
||||
# or any model LiteLLM doesn't have in its built-in pricing table, declare
|
||||
# per-token costs inline so they bill correctly:
|
||||
#
|
||||
# litellm_params:
|
||||
# base_model: "my-custom-azure-deploy"
|
||||
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
|
||||
# input_cost_per_token: 0.000003
|
||||
# output_cost_per_token: 0.000015
|
||||
#
|
||||
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
|
||||
# API — no inline declaration needed. Models without resolvable pricing
|
||||
# debit $0 from the user's balance and log a WARNING.
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
# These settings control how the LiteLLM Router distributes requests across models
|
||||
|
|
@ -292,6 +310,17 @@ openrouter_integration:
|
|||
free_rpm: 20
|
||||
free_tpm: 100000
|
||||
|
||||
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
|
||||
# contains hundreds of image- and vision-capable models; turning these on
|
||||
# injects them into the global Image-Generation / Vision-LLM model
|
||||
# selectors alongside any static configs. Tier (free/premium) is derived
|
||||
# per model the same way it is for chat (`:free` suffix or zero pricing).
|
||||
# When a user picks a premium image/vision model the call debits the
|
||||
# shared $5 USD-cost-based premium credit pool — so leaving these off
|
||||
# avoids surprise quota burn on existing deployments. Default: false.
|
||||
image_generation_enabled: false
|
||||
vision_enabled: false
|
||||
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
system_instructions: ""
|
||||
|
|
|
|||
|
|
@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
|
|||
prompt_tokens = Column(Integer, nullable=False, default=0)
|
||||
completion_tokens = Column(Integer, nullable=False, default=0)
|
||||
total_tokens = Column(Integer, nullable=False, default=0)
|
||||
cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
|
||||
model_breakdown = Column(JSONB, nullable=True)
|
||||
call_details = Column(JSONB, nullable=True)
|
||||
|
||||
|
|
@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
|
|||
|
||||
|
||||
class PremiumTokenPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium token credits."""
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
|
||||
|
||||
Note: the table name is preserved (``premium_token_purchases``) for
|
||||
operational continuity even though the unit is now USD micro-credits
|
||||
instead of raw tokens. The ``credit_micros_granted`` column replaced
|
||||
the legacy ``tokens_granted`` in migration 140; the stored values
|
||||
were not transformed because the prior $1 = 1M tokens Stripe price
|
||||
makes the unit conversion 1:1 numerically.
|
||||
"""
|
||||
|
||||
__tablename__ = "premium_token_purchases"
|
||||
__allow_unmapped__ = True
|
||||
|
|
@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
|
|||
)
|
||||
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
|
||||
quantity = Column(Integer, nullable=False)
|
||||
tokens_granted = Column(BigInteger, nullable=False)
|
||||
credit_micros_granted = Column(BigInteger, nullable=False)
|
||||
amount_total = Column(Integer, nullable=True)
|
||||
currency = Column(String(10), nullable=True)
|
||||
status = Column(
|
||||
|
|
@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_tokens_limit = Column(
|
||||
premium_credit_micros_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_TOKEN_LIMIT,
|
||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
)
|
||||
premium_tokens_used = Column(
|
||||
premium_credit_micros_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_tokens_reserved = Column(
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
|
|
@ -2241,16 +2250,16 @@ else:
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_tokens_limit = Column(
|
||||
premium_credit_micros_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_TOKEN_LIMIT,
|
||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
)
|
||||
premium_tokens_used = Column(
|
||||
premium_credit_micros_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_tokens_reserved = Column(
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,12 +68,25 @@ class EtlPipelineService:
|
|||
etl_service="VISION_LLM",
|
||||
content_type="image",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning(
|
||||
"Vision LLM failed for %s, falling back to document parser",
|
||||
request.filename,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
# Special-case quota exhaustion so we log a clearer message
|
||||
# — the vision LLM didn't "fail", the user just ran out of
|
||||
# premium credit. Falling through to the document parser
|
||||
# is a graceful degradation: OCR/Unstructured still
|
||||
# extracts text from the image without burning credit.
|
||||
from app.services.billable_calls import QuotaInsufficientError
|
||||
|
||||
if isinstance(exc, QuotaInsufficientError):
|
||||
logging.info(
|
||||
"Vision LLM quota exhausted for %s; falling back to document parser",
|
||||
request.filename,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"Vision LLM failed for %s, falling back to document parser",
|
||||
request.filename,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
"No vision LLM provided, falling back to document parser for %s",
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ from app.schemas import (
|
|||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
billable_call,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
|
|
@ -92,6 +97,50 @@ def _build_model_string(
|
|||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
session: AsyncSession,
|
||||
config_id: int | None,
|
||||
search_space: SearchSpace,
|
||||
) -> tuple[str, str, int]:
|
||||
"""Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
|
||||
|
||||
The resolution mirrors ``_execute_image_generation``'s lookup tree but
|
||||
only extracts the fields needed for billing — we do this *before*
|
||||
``billable_call`` so the reservation is correctly sized for the
|
||||
config that will actually run, and so we don't open an
|
||||
``ImageGeneration`` row for a request that's about to 402.
|
||||
|
||||
User-owned (positive ID) BYOK configs are always free — they cost
|
||||
the user nothing on our side. Auto mode currently treats as free
|
||||
because the underlying router can dispatch to either premium or
|
||||
free YAML configs and we don't surface the resolved deployment up
|
||||
here yet. Bringing Auto under premium billing would require
|
||||
threading the chosen deployment back from ``ImageGenRouterService``.
|
||||
"""
|
||||
resolved_id = config_id
|
||||
if resolved_id is None:
|
||||
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
||||
if is_image_gen_auto_mode(resolved_id):
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
if resolved_id < 0:
|
||||
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
base_model = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg.get("model_name", ""),
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
reserve_micros = int(
|
||||
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
)
|
||||
return (billing_tier, base_model, reserve_micros)
|
||||
|
||||
# Positive ID = user-owned BYOK image-gen config — always free.
|
||||
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
session: AsyncSession,
|
||||
image_gen: ImageGeneration,
|
||||
|
|
@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -454,7 +508,26 @@ async def create_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
"""Create and execute an image generation request.
|
||||
|
||||
Premium configs are gated by the user's shared premium credit pool.
|
||||
The flow is:
|
||||
|
||||
1. Permission check + load the search space (cheap, no provider call).
|
||||
2. Resolve which config will run so we know its billing tier and the
|
||||
worst-case reservation size *before* opening any DB rows.
|
||||
3. Wrap the entire ImageGeneration row insert + provider call in
|
||||
``billable_call``. If quota is denied, ``billable_call`` raises
|
||||
``QuotaInsufficientError`` *before* we flush a row, which we
|
||||
translate to HTTP 402 (no orphaned rows on the user's account,
|
||||
no inserted error rows for "you ran out of credit").
|
||||
4. On success, the actual ``response_cost`` flows through the
|
||||
LiteLLM callback into the accumulator, and ``billable_call``
|
||||
finalizes the debit at exit. Inner ``try/except`` still catches
|
||||
provider errors and stores them on ``error_message`` (HTTP 200
|
||||
with ``error_message`` set is preserved for failed-but-not-quota
|
||||
scenarios — clients already know how to surface those).
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -471,33 +544,70 @@ async def create_image_generation(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
# propagates to the outer ``except QuotaInsufficientError`` handler
|
||||
# below as HTTP 402 — it is intentionally NOT swallowed into
|
||||
# ``error_message`` because that would (1) imply a successful row
|
||||
# exists when none does, and (2) return HTTP 200 to a client
|
||||
# whose request was actively *denied* (issue K).
|
||||
async with billable_call(
|
||||
user_id=search_space.user_id,
|
||||
search_space_id=data.search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=reserve_micros,
|
||||
usage_type="image_generation",
|
||||
call_details={"model": base_model, "prompt": data.prompt[:100]},
|
||||
):
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except QuotaInsufficientError as exc:
|
||||
# The user's premium credit pool is empty. No DB row is created
|
||||
# because ``billable_call`` denies before yielding (issue K).
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error_code": "premium_quota_exhausted",
|
||||
"usage_type": exc.usage_type,
|
||||
"used_micros": exc.used_micros,
|
||||
"limit_micros": exc.limit_micros,
|
||||
"remaining_micros": exc.remaining_micros,
|
||||
"message": (
|
||||
"Out of premium credits for image generation. "
|
||||
"Purchase additional credits or switch to a free model."
|
||||
),
|
||||
},
|
||||
) from exc
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -1366,7 +1366,11 @@ async def append_message(
|
|||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
# Persist token usage if provided (for assistant messages).
|
||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||
# forwarded by the FE through the appendMessage round-trip so
|
||||
# the historical TokenUsage row matches the credit debit applied
|
||||
# at finalize time.
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
await record_token_usage(
|
||||
|
|
@ -1377,6 +1381,7 @@ async def append_message(
|
|||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
|
|
|
|||
|
|
@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
|
|||
metadata = _get_metadata(checkout_session)
|
||||
user_id = metadata.get("user_id")
|
||||
quantity = int(metadata.get("quantity", "0"))
|
||||
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
|
||||
# Read the new metadata key first, fall back to the legacy one so
|
||||
# in-flight checkout sessions created before the cost-credits
|
||||
# release still fulfil correctly (the unit is numerically the
|
||||
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
|
||||
credit_micros_per_unit = int(
|
||||
metadata.get("credit_micros_per_unit")
|
||||
or metadata.get("tokens_per_unit", "0")
|
||||
)
|
||||
|
||||
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
|
||||
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
|
||||
logger.error(
|
||||
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
||||
checkout_session_id,
|
||||
|
|
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=quantity,
|
||||
tokens_granted=quantity * tokens_per_unit,
|
||||
credit_micros_granted=quantity * credit_micros_per_unit,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
|
|||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
)
|
||||
user.premium_tokens_limit = (
|
||||
max(user.premium_tokens_used, user.premium_tokens_limit)
|
||||
+ purchase.tokens_granted
|
||||
# Top up the user's credit balance by the granted micro-USD amount.
|
||||
# ``max(used, limit)`` clamps the case where the legacy code wrote a
|
||||
# used value above the limit (e.g. underbilling rounding) so adding
|
||||
# ``credit_micros_granted`` always lifts the limit by the full pack
|
||||
# size rather than disappearing into past overuse.
|
||||
user.premium_credit_micros_limit = (
|
||||
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
|
||||
+ purchase.credit_micros_granted
|
||||
)
|
||||
|
||||
await db_session.commit()
|
||||
|
|
@ -532,12 +544,18 @@ async def create_token_checkout_session(
|
|||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a Stripe Checkout Session for buying premium token packs."""
|
||||
"""Create a Stripe Checkout Session for buying premium credit packs.
|
||||
|
||||
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
|
||||
credit (default 1_000_000 = $1.00). The user's balance is debited
|
||||
at the actual provider cost reported by LiteLLM at finalize time,
|
||||
so $1 of credit always buys $1 worth of provider usage at cost.
|
||||
"""
|
||||
_ensure_token_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_token_price_id()
|
||||
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
||||
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
|
||||
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||
|
|
@ -556,8 +574,8 @@ async def create_token_checkout_session(
|
|||
"metadata": {
|
||||
"user_id": str(user.id),
|
||||
"quantity": str(body.quantity),
|
||||
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
|
||||
"purchase_type": "premium_tokens",
|
||||
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
|
||||
"purchase_type": "premium_credit",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -583,7 +601,7 @@ async def create_token_checkout_session(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=body.quantity,
|
||||
tokens_granted=tokens_granted,
|
||||
credit_micros_granted=credit_micros_granted,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -598,14 +616,19 @@ async def create_token_checkout_session(
|
|||
async def get_token_status(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return token-buying availability and current premium quota for frontend."""
|
||||
used = user.premium_tokens_used
|
||||
limit = user.premium_tokens_limit
|
||||
"""Return token-buying availability and current premium credit quota for frontend.
|
||||
|
||||
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
|
||||
when displaying. The route name is preserved for back-compat with
|
||||
pinned client deployments.
|
||||
"""
|
||||
used = user.premium_credit_micros_used
|
||||
limit = user.premium_credit_micros_limit
|
||||
return TokenStripeStatusResponse(
|
||||
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
||||
premium_tokens_used=used,
|
||||
premium_tokens_limit=limit,
|
||||
premium_tokens_remaining=max(0, limit - used),
|
||||
premium_credit_micros_used=used,
|
||||
premium_credit_micros_limit=limit,
|
||||
premium_credit_micros_remaining=max(0, limit - used),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
|
|||
Schema for reading global image generation configs from YAML.
|
||||
Global configs have negative IDs. API key is hidden.
|
||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
|
|
@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel):
|
|||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
quota_reserve_micros: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the reservation amount (in micro-USD) used when "
|
||||
"this image generation is premium. Falls back to "
|
||||
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
|
|||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
cost_micros: int = 0
|
||||
model_breakdown: dict | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
|
|||
|
||||
|
||||
class TokenPurchaseRead(BaseModel):
|
||||
"""Serialized premium token purchase record."""
|
||||
"""Serialized premium credit purchase record.
|
||||
|
||||
``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
|
||||
schema name kept ``Token`` for API back-compat with pinned clients.
|
||||
"""
|
||||
|
||||
id: uuid.UUID
|
||||
stripe_checkout_session_id: str
|
||||
stripe_payment_intent_id: str | None = None
|
||||
quantity: int
|
||||
tokens_granted: int
|
||||
credit_micros_granted: int
|
||||
amount_total: int | None = None
|
||||
currency: str | None = None
|
||||
status: str
|
||||
|
|
@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
|
|||
|
||||
|
||||
class TokenPurchaseHistoryResponse(BaseModel):
|
||||
"""Response containing the user's premium token purchases."""
|
||||
"""Response containing the user's premium credit purchases."""
|
||||
|
||||
purchases: list[TokenPurchaseRead]
|
||||
|
||||
|
||||
class TokenStripeStatusResponse(BaseModel):
|
||||
"""Response describing token-buying availability and current quota."""
|
||||
"""Response describing premium-credit-buying availability and balance.
|
||||
|
||||
All ``premium_credit_micros_*`` fields are in micro-USD; the FE
|
||||
divides by 1_000_000 to display USD.
|
||||
"""
|
||||
|
||||
token_buying_enabled: bool
|
||||
premium_tokens_used: int = 0
|
||||
premium_tokens_limit: int = 0
|
||||
premium_tokens_remaining: int = 0
|
||||
premium_credit_micros_used: int = 0
|
||||
premium_credit_micros_limit: int = 0
|
||||
premium_credit_micros_remaining: int = 0
|
||||
|
|
|
|||
|
|
@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
|
|||
|
||||
|
||||
class GlobalVisionLLMConfigRead(BaseModel):
|
||||
"""Schema for reading global vision LLM configs from YAML.
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(...)
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
|
@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
|
|||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
quota_reserve_tokens: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the per-call reservation in *tokens* — "
|
||||
"converted to micro-USD via the model's input/output prices at "
|
||||
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
|
||||
),
|
||||
)
|
||||
input_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional input price in USD/token. Used by pricing_registration to "
|
||||
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
|
||||
),
|
||||
)
|
||||
output_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description="Optional output price in USD/token. Pair with input_cost_per_token.",
|
||||
)
|
||||
|
|
|
|||
430
surfsense_backend/app/services/billable_calls.py
Normal file
430
surfsense_backend/app/services/billable_calls.py
Normal file
|
|
@ -0,0 +1,430 @@
|
|||
"""
|
||||
Per-call billable wrapper for image generation, vision LLM extraction, and
|
||||
any other short-lived premium operation that must charge against the user's
|
||||
shared premium credit pool.
|
||||
|
||||
The ``billable_call`` async context manager encapsulates the standard
|
||||
"reserve → execute → finalize / release → record audit row" lifecycle in a
|
||||
single primitive so callers (the image-generation REST route and the
|
||||
vision-LLM wrapper used during indexing) don't have to re-implement it.
|
||||
|
||||
KEY DESIGN POINTS (issue A, B):
|
||||
|
||||
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
|
||||
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
|
||||
insert each run inside their own ``shielded_async_session()``. This
|
||||
guarantees that a quota commit/rollback can never accidentally flush or
|
||||
roll back rows the caller has staged in the request's main session
|
||||
(e.g. a freshly-created ``ImageGeneration`` row).
|
||||
|
||||
2. **ContextVar safety.** The accumulator is scoped via
|
||||
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
||||
nested ``billable_call`` inside an outer chat turn cannot corrupt the
|
||||
chat turn's accumulator.
|
||||
|
||||
3. **Free configs are still audited.** Free calls bypass the reserve /
|
||||
finalize dance entirely but still record a ``TokenUsage`` audit row with
|
||||
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
|
||||
pipeline complete for analytics even when nothing is debited.
|
||||
|
||||
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
|
||||
responsible for translating that into HTTP 402. We *do not* catch the
|
||||
denial inside ``billable_call`` — letting it propagate also prevents
|
||||
the image-generation route from creating an ``ImageGeneration`` row
|
||||
for a request that never actually ran.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import shielded_async_session
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
from app.services.token_tracking_service import (
|
||||
TurnTokenAccumulator,
|
||||
record_token_usage,
|
||||
scoped_turn,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotaInsufficientError(Exception):
|
||||
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
|
||||
call because the user has exhausted their premium credit pool.
|
||||
|
||||
The route handler should catch this and return HTTP 402 Payment
|
||||
Required (or the equivalent for the surface area). Outside of the HTTP
|
||||
layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
|
||||
callers may catch this and degrade gracefully — e.g. fall back to OCR
|
||||
when vision is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
usage_type: str,
|
||||
used_micros: int,
|
||||
limit_micros: int,
|
||||
remaining_micros: int,
|
||||
) -> None:
|
||||
self.usage_type = usage_type
|
||||
self.used_micros = used_micros
|
||||
self.limit_micros = limit_micros
|
||||
self.remaining_micros = remaining_micros
|
||||
super().__init__(
|
||||
f"Premium credit exhausted for {usage_type}: "
|
||||
f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def billable_call(
|
||||
*,
|
||||
user_id: UUID,
|
||||
search_space_id: int,
|
||||
billing_tier: str,
|
||||
base_model: str,
|
||||
quota_reserve_tokens: int | None = None,
|
||||
quota_reserve_micros_override: int | None = None,
|
||||
usage_type: str,
|
||||
thread_id: int | None = None,
|
||||
message_id: int | None = None,
|
||||
call_details: dict[str, Any] | None = None,
|
||||
) -> AsyncIterator[TurnTokenAccumulator]:
|
||||
"""Wrap a single billable LLM/image call.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the credit pool to debit. For vision-LLM during
|
||||
indexing this is the *search-space owner* (issue M), not the
|
||||
triggering user.
|
||||
search_space_id: Required — recorded on the ``TokenUsage`` audit row.
|
||||
billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
|
||||
the reserve/finalize dance but still records an audit row with
|
||||
the captured cost.
|
||||
base_model: Used by :func:`estimate_call_reserve_micros` to compute
|
||||
a worst-case reservation from LiteLLM's pricing table.
|
||||
quota_reserve_tokens: Optional per-config override for the chat-style
|
||||
reserve estimator (vision LLM uses this).
|
||||
quota_reserve_micros_override: Optional flat micro-USD reservation
|
||||
(image generation uses this — its cost shape is per-image, not
|
||||
per-token).
|
||||
usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
|
||||
Recorded on the ``TokenUsage`` row.
|
||||
thread_id, message_id: Optional FK columns on ``TokenUsage``.
|
||||
call_details: Optional per-call metadata (model name, parameters)
|
||||
forwarded to ``record_token_usage``.
|
||||
|
||||
Yields:
|
||||
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
|
||||
the underlying LLM/image API while inside the ``async with``; the
|
||||
``TokenTrackingCallback`` populates the accumulator automatically.
|
||||
|
||||
Raises:
|
||||
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
|
||||
"""
|
||||
is_premium = billing_tier == "premium"
|
||||
|
||||
async with scoped_turn() as acc:
|
||||
# ---------- Free path: just audit -------------------------------
|
||||
if not is_premium:
|
||||
try:
|
||||
yield acc
|
||||
finally:
|
||||
# Always audit, even on exception, so we capture cost when
|
||||
# provider returns successfully but the caller raises later.
|
||||
try:
|
||||
async with shielded_async_session() as audit_session:
|
||||
await record_token_usage(
|
||||
audit_session,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=acc.total_prompt_tokens,
|
||||
completion_tokens=acc.total_completion_tokens,
|
||||
total_tokens=acc.grand_total,
|
||||
cost_micros=acc.total_cost_micros,
|
||||
model_breakdown=acc.per_message_summary(),
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
await audit_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] free-path audit insert failed for "
|
||||
"usage_type=%s user_id=%s",
|
||||
usage_type,
|
||||
user_id,
|
||||
)
|
||||
return
|
||||
|
||||
# ---------- Premium path: reserve → execute → finalize ----------
|
||||
if quota_reserve_micros_override is not None:
|
||||
reserve_micros = max(1, int(quota_reserve_micros_override))
|
||||
else:
|
||||
reserve_micros = estimate_call_reserve_micros(
|
||||
base_model=base_model or "",
|
||||
quota_reserve_tokens=quota_reserve_tokens,
|
||||
)
|
||||
|
||||
request_id = str(uuid4())
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
reserve_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
reserve_micros=reserve_micros,
|
||||
)
|
||||
|
||||
if not reserve_result.allowed:
|
||||
logger.info(
|
||||
"[billable_call] reserve DENIED user=%s usage_type=%s "
|
||||
"reserve=%d used=%d limit=%d remaining=%d",
|
||||
user_id,
|
||||
usage_type,
|
||||
reserve_micros,
|
||||
reserve_result.used,
|
||||
reserve_result.limit,
|
||||
reserve_result.remaining,
|
||||
)
|
||||
raise QuotaInsufficientError(
|
||||
usage_type=usage_type,
|
||||
used_micros=reserve_result.used,
|
||||
limit_micros=reserve_result.limit,
|
||||
remaining_micros=reserve_result.remaining,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
|
||||
"(remaining=%d)",
|
||||
user_id,
|
||||
usage_type,
|
||||
reserve_micros,
|
||||
reserve_result.remaining,
|
||||
)
|
||||
|
||||
try:
|
||||
yield acc
|
||||
except BaseException:
|
||||
# Release on any failure (including QuotaInsufficientError raised
|
||||
# from a downstream call, asyncio cancellation, etc.). We use
|
||||
# BaseException so cancellation also releases.
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
reserved_micros=reserve_micros,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] premium_release failed for user=%s "
|
||||
"reserve_micros=%d (reservation will be GC'd by quota "
|
||||
"reconciliation if/when implemented)",
|
||||
user_id,
|
||||
reserve_micros,
|
||||
)
|
||||
raise
|
||||
|
||||
# ---------- Success: finalize + audit ----------------------------
|
||||
actual_micros = acc.total_cost_micros
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
final_result = await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
actual_micros=actual_micros,
|
||||
reserved_micros=reserve_micros,
|
||||
)
|
||||
logger.info(
|
||||
"[billable_call] finalize user=%s usage_type=%s actual=%d "
|
||||
"reserved=%d → used=%d/%d (remaining=%d)",
|
||||
user_id,
|
||||
usage_type,
|
||||
actual_micros,
|
||||
reserve_micros,
|
||||
final_result.used,
|
||||
final_result.limit,
|
||||
final_result.remaining,
|
||||
)
|
||||
except Exception:
|
||||
# Last-ditch: if finalize itself fails, we must at least release
|
||||
# so the reservation doesn't leak.
|
||||
logger.exception(
|
||||
"[billable_call] premium_finalize failed for user=%s; "
|
||||
"attempting release",
|
||||
user_id,
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
reserved_micros=reserve_micros,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] release after finalize failure ALSO failed "
|
||||
"for user=%s",
|
||||
user_id,
|
||||
)
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as audit_session:
|
||||
await record_token_usage(
|
||||
audit_session,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=acc.total_prompt_tokens,
|
||||
completion_tokens=acc.total_completion_tokens,
|
||||
total_tokens=acc.grand_total,
|
||||
cost_micros=actual_micros,
|
||||
model_breakdown=acc.per_message_summary(),
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
await audit_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] premium-path audit insert failed for "
|
||||
"usage_type=%s user_id=%s (debit was applied)",
|
||||
usage_type,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_agent_billing_for_search_space(
    session: AsyncSession,
    search_space_id: int,
    *,
    thread_id: int | None = None,
) -> tuple[UUID, str, str]:
    """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
    agent LLM.

    Used by Celery tasks (podcast generation, video presentation) to bill the
    search-space owner's premium credit pool when the agent LLM is premium.

    Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:

    - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
    - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
      * ``thread_id`` is set: delegate to
        ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
        recurse into the resolved id. Reuses chat's existing pin if present
        so the same model bills for chat + downstream podcast/video. If the
        user is not premium-eligible, the pin service auto-restricts to free
        deployments — denial only happens later in
        ``billable_call.premium_reserve`` if the pin really is premium and
        credit ran out mid-flow.
      * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
        for any future direct-API path; today both Celery tasks always pass
        ``thread_id``.
    - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
      (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
      ``base_model = litellm_params.get("base_model") or model_name`` —
      NOT provider-prefixed, matching chat's cost-map lookup convention.
    - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
      ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
      ``base_model`` from ``litellm_params`` or ``model_name``.

    Note on imports: ``llm_service``, ``auto_model_pin_service``, and
    ``llm_router_service`` are imported lazily inside the function body to
    avoid hoisting litellm side-effects (``litellm.callbacks =
    [token_tracker]``, ``litellm.drop_params``, etc.) into
    ``billable_calls.py``'s module load path.
    """
    # Function-scope imports: see the docstring note — keeps litellm
    # side-effects out of this module's import path.
    from sqlalchemy import select

    from app.db import NewLLMConfig, SearchSpace

    result = await session.execute(
        select(SearchSpace).where(SearchSpace.id == search_space_id)
    )
    search_space = result.scalars().first()
    if search_space is None:
        raise ValueError(f"Search space {search_space_id} not found")

    agent_llm_id = search_space.agent_llm_id
    if agent_llm_id is None:
        raise ValueError(
            f"Search space {search_space_id} has no agent_llm_id configured"
        )

    # The search-space OWNER is always the billed party, regardless of who
    # triggered the Celery task.
    owner_user_id: UUID = search_space.user_id

    from app.services.auto_model_pin_service import (
        AUTO_FASTEST_ID,
        resolve_or_get_pinned_llm_config_id,
    )

    if agent_llm_id == AUTO_FASTEST_ID:
        # Auto mode: reuse (or create) the thread's model pin so chat and
        # downstream podcast/video bill against the same resolved model.
        if thread_id is None:
            return owner_user_id, "free", "auto"
        try:
            resolution = await resolve_or_get_pinned_llm_config_id(
                session,
                thread_id=thread_id,
                search_space_id=search_space_id,
                user_id=str(owner_user_id),
                selected_llm_config_id=AUTO_FASTEST_ID,
            )
        except ValueError:
            # Pin resolution failing must not block the task — degrade to
            # the free tier rather than raising.
            logger.warning(
                "[agent_billing] Auto-mode pin resolution failed for "
                "search_space=%s thread=%s; falling back to free",
                search_space_id,
                thread_id,
                exc_info=True,
            )
            return owner_user_id, "free", "auto"
        agent_llm_id = resolution.resolved_llm_config_id

    if agent_llm_id < 0:
        # Negative id → global YAML / dynamic OpenRouter config.
        from app.services.llm_service import get_global_llm_config

        cfg = get_global_llm_config(agent_llm_id) or {}
        billing_tier = str(cfg.get("billing_tier", "free")).lower()
        litellm_params = cfg.get("litellm_params") or {}
        base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
        return owner_user_id, billing_tier, base_model

    # Positive id → user BYOK config scoped to this search space; always
    # billed as free (see docstring).
    nlc_result = await session.execute(
        select(NewLLMConfig).where(
            NewLLMConfig.id == agent_llm_id,
            NewLLMConfig.search_space_id == search_space_id,
        )
    )
    nlc = nlc_result.scalars().first()
    base_model = ""
    if nlc is not None:
        litellm_params = nlc.litellm_params or {}
        base_model = litellm_params.get("base_model") or nlc.model_name or ""
    return owner_user_id, "free", base_model
|
||||
|
||||
|
||||
# Explicit public API of this module.
__all__ = [
    "QuotaInsufficientError",
    "_resolve_agent_billing_for_search_space",
    "billable_call",
]
|
||||
|
||||
|
||||
# Re-export the config knob so callers don't have to import config just for
# the default image reserve. (Presumably micro-USD, per the ``_MICROS``
# suffix used throughout the quota code — confirm against app config.)
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
|
@ -134,42 +134,16 @@ PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
|
||||
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
|
||||
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
|
||||
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
|
||||
# request to an Azure endpoint, which then 404s with ``Resource not found``.
|
||||
# Only providers with a well-known, stable public base URL are listed here —
|
||||
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
# so their existing config-driven behaviour is preserved.
|
||||
PROVIDER_DEFAULT_API_BASE = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
|
||||
|
||||
# Canonical provider → base URL when a config uses a generic ``openai``-style
|
||||
# prefix but the ``provider`` field tells us which API it really is
|
||||
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
|
||||
# each has its own base URL).
|
||||
PROVIDER_KEY_DEFAULT_API_BASE = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
|
||||
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
|
||||
# call sites can share the exact same defense (OpenRouter / Groq / etc.
|
||||
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
||||
# backward compatibility with any external import.
|
||||
from app.services.provider_api_base import ( # noqa: E402
|
||||
PROVIDER_DEFAULT_API_BASE,
|
||||
PROVIDER_KEY_DEFAULT_API_BASE,
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
|
|
@ -466,14 +440,14 @@ class LLMRouterService:
|
|||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||
# provider-aware default so the deployment does not silently
|
||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
|
||||
# requests to the wrong endpoint. See ``provider_api_base``
|
||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||
# against an Azure endpoint).
|
||||
api_base = config.get("api_base")
|
||||
if not api_base:
|
||||
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
|
||||
if not api_base:
|
||||
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
|
|
|
|||
|
|
@ -496,8 +496,14 @@ async def get_vision_llm(
|
|||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Global (negative ID): YAML configs
|
||||
- DB (positive ID): VisionLLMConfig table
|
||||
|
||||
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
|
||||
so each ``ainvoke`` debits the search-space owner's premium credit
|
||||
pool. User-owned BYOK configs and free global configs are returned
|
||||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.db import VisionLLMConfig
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import (
|
||||
VISION_PROVIDER_MAP,
|
||||
VisionLLMRouterService,
|
||||
|
|
@ -519,6 +525,8 @@ async def get_vision_llm(
|
|||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
|
|
@ -526,6 +534,13 @@ async def get_vision_llm(
|
|||
)
|
||||
return None
|
||||
try:
|
||||
# Auto mode is currently treated as free at the wrapper
|
||||
# level — the underlying router can dispatch to either
|
||||
# premium or free YAML configs but routing decisions are
|
||||
# opaque. If/when we want to bill Auto-routed vision
|
||||
# calls we'd need to thread the resolved deployment's
|
||||
# billing_tier back from the router. For now we keep
|
||||
# parity with chat Auto, which also doesn't pre-classify.
|
||||
return ChatLiteLLMRouter(
|
||||
router=VisionLLMRouterService.get_router(),
|
||||
streaming=True,
|
||||
|
|
@ -562,8 +577,21 @@ async def get_vision_llm(
|
|||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
|
||||
if billing_tier == "premium":
|
||||
return QuotaCheckedVisionLLM(
|
||||
inner_llm,
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=model_string,
|
||||
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
|
||||
)
|
||||
return inner_llm
|
||||
|
||||
# User-owned (positive ID) BYOK configs — always free.
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
VisionLLMConfig.id == config_id,
|
||||
|
|
|
|||
|
|
@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
|
|||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _is_image_output_model(model: dict) -> bool:
|
||||
"""Return True if the model can produce image output.
|
||||
|
||||
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
|
||||
``["image"]`` for pure image generators, ``["text", "image"]`` for
|
||||
multi-modal generators that also emit captions). We accept any model
|
||||
that can output images; the call site decides whether to use the
|
||||
image-generation API or chat completion.
|
||||
"""
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
|
||||
return "image" in output_mods
|
||||
|
||||
|
||||
def _is_vision_input_model(model: dict) -> bool:
|
||||
"""Return True if the model can ingest an image AND emit text.
|
||||
|
||||
OpenRouter's ``architecture.input_modalities`` lists what the model
|
||||
accepts; ``output_modalities`` lists what it produces. A vision LLM
|
||||
is a model that takes images in and produces text out — i.e. it can
|
||||
answer questions about a screenshot or extract content from an
|
||||
image. Pure image-to-image models (e.g. style transfer) and
|
||||
text-only models are excluded.
|
||||
"""
|
||||
arch = model.get("architecture", {}) or {}
|
||||
input_mods = arch.get("input_modalities", []) or []
|
||||
output_mods = arch.get("output_modalities", []) or []
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
|
|
@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
|
|||
return None
|
||||
|
||||
|
||||
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
|
||||
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
|
||||
|
||||
Pricing values are kept as the raw OpenRouter strings (e.g.
|
||||
``"0.000003"``); ``pricing_registration`` converts them to floats
|
||||
when registering with LiteLLM. Models with missing or malformed
|
||||
pricing are simply omitted — operator-side risk if any of those are
|
||||
premium.
|
||||
"""
|
||||
pricing: dict[str, dict[str, str]] = {}
|
||||
for model in raw_models:
|
||||
model_id = str(model.get("id") or "").strip()
|
||||
if not model_id:
|
||||
continue
|
||||
p = model.get("pricing") or {}
|
||||
prompt = p.get("prompt")
|
||||
completion = p.get("completion")
|
||||
if prompt is None and completion is None:
|
||||
continue
|
||||
pricing[model_id] = {
|
||||
"prompt": str(prompt) if prompt is not None else "",
|
||||
"completion": str(completion) if completion is not None else "",
|
||||
}
|
||||
return pricing
|
||||
|
||||
|
||||
def _generate_configs(
|
||||
raw_models: list[dict],
|
||||
settings: dict[str, Any],
|
||||
|
|
@ -282,6 +337,162 @@ def _generate_configs(
|
|||
return configs
|
||||
|
||||
|
||||
# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
# Overridable via the ``image_id_offset`` / ``vision_id_offset`` settings
# keys read by the generator functions below.
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
|
||||
|
||||
|
||||
def _generate_image_gen_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Shape OpenRouter image-generation models into global image-gen
    config dicts (the YAML shape consumed by ``image_generation_routes``).

    A model is emitted when:
    - ``architecture.output_modalities`` contains ``"image"``,
    - its provider is compatible (excluded slugs blocked),
    - its model id is allowed (excluded list blocked).

    The chat-only filters (``_supports_tool_calling`` and
    ``_has_sufficient_context``) are deliberately skipped: tool calls and
    context windows are irrelevant for the ``aimage_generation`` API.
    ``billing_tier`` is derived per model the same way as chat
    (``_openrouter_tier``).

    Cost is intentionally *not* registered with LiteLLM at startup
    (``pricing_registration`` skips image gen): OpenRouter image-gen
    models are not in LiteLLM's native cost map and OpenRouter populates
    ``response_cost`` directly from the response header. A defensive
    branch in ``_extract_cost_usd`` handles the rare case where
    ``usage.cost`` is missing — see ``token_tracking_service``.
    """
    offset = int(
        settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
    )
    key: str = settings.get("api_key", "")
    premium_rpm: int = settings.get("rpm", 200)
    free_tier_rpm: int = settings.get("free_rpm", 20)
    shared_litellm_params: dict = settings.get("litellm_params") or {}

    emitted: list[dict] = []
    used_ids: set[int] = set()
    for entry in raw_models:
        # Same predicate order as the chat generator: capability,
        # provider, allow-list, then a sane "org/model" id shape.
        if not (
            _is_image_output_model(entry)
            and _is_compatible_provider(entry)
            and _is_allowed_model(entry)
            and "/" in entry.get("id", "")
        ):
            continue

        mid: str = entry["id"]
        display: str = entry.get("name", mid)
        tier = _openrouter_tier(entry)

        emitted.append(
            {
                "id": _stable_config_id(mid, offset, used_ids),
                "name": display,
                "description": f"{display} via OpenRouter (image generation)",
                "provider": "OPENROUTER",
                "model_name": mid,
                "api_key": key,
                "api_base": "",
                "api_version": None,
                "rpm": free_tier_rpm if tier == "free" else premium_rpm,
                "litellm_params": dict(shared_litellm_params),
                "billing_tier": tier,
                _OPENROUTER_DYNAMIC_MARKER: True,
            }
        )

    return emitted
|
||||
|
||||
|
||||
def _generate_vision_llm_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Shape OpenRouter vision-capable LLMs into global vision-LLM config
    dicts (the YAML shape consumed by ``vision_llm_routes``).

    A model is emitted when:
    - ``architecture.input_modalities`` contains ``"image"``,
    - ``architecture.output_modalities`` contains ``"text"``,
    - its provider is compatible (excluded slugs blocked),
    - its model id is allowed (excluded list blocked).

    Vision-LLM is invoked from the indexer (image extraction during
    document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
    the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
    filters do not apply: a small-context vision model that doesn't
    advertise tool-calling is still perfectly viable for "describe this
    image" prompts.
    """
    offset = int(
        settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
    )
    key: str = settings.get("api_key", "")
    premium_rpm: int = settings.get("rpm", 200)
    premium_tpm: int = settings.get("tpm", 1_000_000)
    free_tier_rpm: int = settings.get("free_rpm", 20)
    free_tier_tpm: int = settings.get("free_tpm", 100_000)
    reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
    shared_litellm_params: dict = settings.get("litellm_params") or {}

    def _as_price(raw: Any) -> float:
        # Malformed / missing OpenRouter price strings degrade to 0.0
        # (never raise) so one bad catalogue entry can't break emission.
        try:
            return float(raw or 0)
        except (TypeError, ValueError):
            return 0.0

    emitted: list[dict] = []
    used_ids: set[int] = set()
    for entry in raw_models:
        if not (
            _is_vision_input_model(entry)
            and _is_compatible_provider(entry)
            and _is_allowed_model(entry)
            and "/" in entry.get("id", "")
        ):
            continue

        mid: str = entry["id"]
        display: str = entry.get("name", mid)
        tier = _openrouter_tier(entry)
        prices = entry.get("pricing") or {}

        # Per-token prices feed ``pricing_registration`` (so LiteLLM can
        # bill these models at startup) and the cost estimator in
        # ``estimate_call_reserve_micros`` at reserve time.
        prompt_cost = _as_price(prices.get("prompt", 0))
        completion_cost = _as_price(prices.get("completion", 0))

        emitted.append(
            {
                "id": _stable_config_id(mid, offset, used_ids),
                "name": display,
                "description": f"{display} via OpenRouter (vision)",
                "provider": "OPENROUTER",
                "model_name": mid,
                "api_key": key,
                "api_base": "",
                "api_version": None,
                "rpm": free_tier_rpm if tier == "free" else premium_rpm,
                "tpm": free_tier_tpm if tier == "free" else premium_tpm,
                "litellm_params": dict(shared_litellm_params),
                "billing_tier": tier,
                "quota_reserve_tokens": reserve_tokens,
                "input_cost_per_token": prompt_cost or None,
                "output_cost_per_token": completion_cost or None,
                _OPENROUTER_DYNAMIC_MARKER: True,
            }
        )

    return emitted
|
||||
|
||||
|
||||
class OpenRouterIntegrationService:
|
||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||
|
||||
|
|
@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
|
|||
# Shape: {model_name: {"gated": bool, "score": float | None}}
|
||||
self._health_cache: dict[str, dict[str, Any]] = {}
|
||||
self._enrich_task: asyncio.Task | None = None
|
||||
# Raw OpenRouter pricing per model_id, captured at the same time
|
||||
# we generate configs. Consumed by ``pricing_registration`` to
|
||||
# teach LiteLLM the per-token cost of every dynamic deployment so
|
||||
# the success-callback can populate ``response_cost`` correctly.
|
||||
self._raw_pricing: dict[str, dict[str, str]] = {}
|
||||
# Cached raw catalogue from the most recent fetch. Image / vision
|
||||
# emitters reuse this to avoid a second network call per surface.
|
||||
self._raw_models: list[dict] = []
|
||||
# Image / vision config caches (only populated when the matching
|
||||
# opt-in flag is true on initialize). Refreshed in lockstep with
|
||||
# the chat catalogue.
|
||||
self._image_configs: list[dict] = []
|
||||
self._vision_configs: list[dict] = []
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
|
|
@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
|
|||
self._initialized = True
|
||||
return []
|
||||
|
||||
self._raw_models = raw_models
|
||||
self._configs = _generate_configs(raw_models, settings)
|
||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
|
||||
# Populate image / vision caches when their opt-in flag is set.
|
||||
# Empty otherwise so the accessors return [] without re-running
|
||||
# filters every refresh.
|
||||
if settings.get("image_generation_enabled"):
|
||||
self._image_configs = _generate_image_gen_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: image-gen emission ON (%d models)",
|
||||
len(self._image_configs),
|
||||
)
|
||||
else:
|
||||
self._image_configs = []
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: vision LLM emission ON (%d models)",
|
||||
len(self._vision_configs),
|
||||
)
|
||||
else:
|
||||
self._vision_configs = []
|
||||
|
||||
self._initialized = True
|
||||
|
||||
tier_counts = self._tier_counts(self._configs)
|
||||
|
|
@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
|
|||
|
||||
new_configs = _generate_configs(raw_models, self._settings)
|
||||
new_by_id = {c["id"]: c for c in new_configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
self._raw_models = raw_models
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
|
|
@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
|
|||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
# Image / vision lists are atomic-swapped the same way: filter out
|
||||
# the previous dynamic entries from the live config list and append
|
||||
# the freshly generated ones. No-ops when the opt-in flag is off.
|
||||
if self._settings.get("image_generation_enabled"):
|
||||
new_image = _generate_image_gen_configs(raw_models, self._settings)
|
||||
static_image = [
|
||||
c
|
||||
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
|
||||
self._image_configs = new_image
|
||||
|
||||
if self._settings.get("vision_enabled"):
|
||||
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
|
||||
static_vision = [
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
|
||||
self._vision_configs = new_vision
|
||||
|
||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||
# earned by the previous turn's preflight. Drop the whole table so
|
||||
# the next turn re-probes against the freshly loaded configs.
|
||||
|
|
@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
|
|||
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||
|
||||
# Re-register LiteLLM pricing for the freshly fetched catalogue
|
||||
# so newly added OR models bill correctly on their first call.
|
||||
# Runs before the router rebuild because the router may issue
|
||||
# cost-table lookups during deployment registration.
|
||||
try:
|
||||
from app.services.pricing_registration import (
|
||||
register_pricing_from_global_configs,
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
|
||||
)
|
||||
|
||||
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
||||
# (dynamic OR premium entries now opt into the pool, free ones stay
|
||||
# out; a refresh also needs to pick up any static-config edits and
|
||||
|
|
@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
|
|||
|
||||
def get_config_by_id(self, config_id: int) -> dict | None:
    """Return the dynamic OpenRouter chat config with ``config_id``, or
    ``None`` when the id is unknown to the current catalogue."""
    return self._configs_by_id.get(config_id)
|
||||
|
||||
def get_image_generation_configs(self) -> list[dict]:
    """Return the dynamic OpenRouter image-generation configs (empty
    list when the ``image_generation_enabled`` flag is off).

    Each entry already has ``billing_tier`` derived per-model from
    OpenRouter's signals and is shaped to drop directly into
    ``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
    """
    # ``list(...)`` hands back a shallow copy so callers cannot mutate
    # the cached list (the per-config dicts are still shared).
    return list(self._image_configs)
|
||||
|
||||
def get_vision_llm_configs(self) -> list[dict]:
    """Return the dynamic OpenRouter vision-LLM configs (empty list
    when the ``vision_enabled`` flag is off).

    Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
    so ``pricing_registration`` can teach LiteLLM the cost of these
    models the same way it does for chat — which keeps the billable
    wrapper able to debit accurate micro-USD on a vision call.
    """
    # Shallow copy: protects the cached list from caller mutation; the
    # per-config dicts themselves are still shared.
    return list(self._vision_configs)
|
||||
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
    """Return the cached raw OpenRouter pricing map.

    Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
    values are the strings OpenRouter publishes (USD per token),
    never converted to floats here so the caller can decide how to
    handle malformed or unset entries.
    """
    # ``dict(...)`` copies only the top level; the inner per-model
    # dicts are shared with the cache.
    return dict(self._raw_pricing)
|
||||
|
|
|
|||
274
surfsense_backend/app/services/pricing_registration.py
Normal file
274
surfsense_backend/app/services/pricing_registration.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
"""
|
||||
Pricing registration with LiteLLM.
|
||||
|
||||
Many models reach our LiteLLM callback without LiteLLM knowing their
|
||||
per-token cost — namely:
|
||||
|
||||
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
|
||||
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
|
||||
pricing table).
|
||||
* Static YAML deployments whose ``base_model`` name is operator-defined
|
||||
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
|
||||
not in LiteLLM's table either.
|
||||
|
||||
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
|
||||
and the user gets billed nothing — a fail-safe but wrong answer for a
|
||||
cost-based credit system. This module runs once at startup, after the
|
||||
OpenRouter integration has fetched its catalogue, and registers each
|
||||
known model's pricing with ``litellm.register_model()`` under multiple
|
||||
plausible alias keys (LiteLLM's cost lookup may use any of them
|
||||
depending on whether the call went through the Router, ChatLiteLLM,
|
||||
or a direct ``acompletion``).
|
||||
|
||||
Operators who run a custom Azure deployment whose ``base_model`` name
|
||||
isn't in LiteLLM's table can declare per-token pricing inline in
|
||||
``global_llm_config.yaml`` via ``input_cost_per_token`` and
|
||||
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
|
||||
that declaration the model's calls debit 0 — never overbilled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_float(value: Any) -> float:
|
||||
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
|
||||
if value is None:
|
||||
return 0.0
|
||||
try:
|
||||
f = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
return f if f > 0 else 0.0
|
||||
|
||||
|
||||
def _alias_set_for_openrouter(model_id: str) -> list[str]:
|
||||
"""Return the alias keys to register an OpenRouter model under.
|
||||
|
||||
LiteLLM's cost-callback lookup key varies by call path:
|
||||
- Router with ``model="openrouter/X"`` → kwargs["model"] is
|
||||
typically ``openrouter/X``.
|
||||
- LiteLLM's own provider routing may strip the prefix and pass the
|
||||
bare ``X`` to the cost-table lookup.
|
||||
Registering under both keeps the lookup hermetic regardless of
|
||||
which path the call took.
|
||||
"""
|
||||
aliases = [f"openrouter/{model_id}", model_id]
|
||||
return list(dict.fromkeys(a for a in aliases if a))
|
||||
|
||||
|
||||
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
|
||||
"""Return the alias keys to register a static YAML deployment under.
|
||||
|
||||
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
|
||||
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
|
||||
bare ``model_name`` because callbacks sometimes see whichever was
|
||||
configured first.
|
||||
"""
|
||||
provider_lower = (provider or "").lower()
|
||||
aliases: list[str] = []
|
||||
if base_model:
|
||||
aliases.append(base_model)
|
||||
if provider_lower and base_model:
|
||||
aliases.append(f"{provider_lower}/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(model_name)
|
||||
if provider_lower and model_name and model_name != base_model:
|
||||
aliases.append(f"{provider_lower}/{model_name}")
|
||||
# Azure deployments often surface as "azure/<name>"; normalise the
|
||||
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
|
||||
if provider_lower == "azure_openai":
|
||||
if base_model:
|
||||
aliases.append(f"azure/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(f"azure/{model_name}")
|
||||
return list(dict.fromkeys(a for a in aliases if a))
|
||||
|
||||
|
||||
def _register(
|
||||
aliases: list[str],
|
||||
*,
|
||||
input_cost: float,
|
||||
output_cost: float,
|
||||
provider: str,
|
||||
mode: str = "chat",
|
||||
) -> int:
|
||||
"""Register a single pricing entry under every alias in ``aliases``.
|
||||
|
||||
Returns the count of aliases successfully registered.
|
||||
"""
|
||||
payload: dict[str, dict[str, Any]] = {}
|
||||
for alias in aliases:
|
||||
payload[alias] = {
|
||||
"input_cost_per_token": input_cost,
|
||||
"output_cost_per_token": output_cost,
|
||||
"litellm_provider": provider,
|
||||
"mode": mode,
|
||||
}
|
||||
if not payload:
|
||||
return 0
|
||||
try:
|
||||
litellm.register_model(payload)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[PricingRegistration] register_model failed for aliases=%s: %s",
|
||||
aliases,
|
||||
exc,
|
||||
)
|
||||
return 0
|
||||
return len(payload)
|
||||
|
||||
|
||||
def _register_chat_shape_configs(
    configs: list[dict],
    *,
    or_pricing: dict[str, dict[str, str]],
    label: str,
) -> tuple[int, int, int, list[str]]:
    """Common loop that registers per-token pricing for a list of "chat-shape"
    configs (chat or vision LLM — both use ``input_cost_per_token`` /
    ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).

    Args:
        configs: Config dicts with ``provider``, ``model_name`` and
            optionally ``litellm_params`` / inline per-token cost fields.
        or_pricing: OpenRouter raw pricing cache keyed by model_id, with
            string ``prompt`` / ``completion`` per-token prices.
        label: Tag for the summary log line (e.g. ``"chat"``, ``"vision"``).

    Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
    """
    registered_models = 0
    registered_aliases = 0
    skipped_no_pricing = 0
    # Up to ~6 sample alias keys, collected purely for the summary log.
    sample_keys: list[str] = []

    for cfg in configs:
        provider = str(cfg.get("provider") or "").upper()
        model_name = str(cfg.get("model_name") or "").strip()
        litellm_params = cfg.get("litellm_params") or {}
        # base_model falls back to model_name when not explicitly declared.
        base_model = str(litellm_params.get("base_model") or model_name).strip()

        if provider == "OPENROUTER":
            # Prefer the cached OpenRouter catalogue pricing for this model.
            entry = or_pricing.get(model_name)
            if entry:
                input_cost = _safe_float(entry.get("prompt"))
                output_cost = _safe_float(entry.get("completion"))
            else:
                # Vision configs from ``_generate_vision_llm_configs``
                # carry their pricing inline because the OpenRouter
                # raw-pricing cache is keyed by chat-catalogue model_id;
                # vision flows pick up the inline values here.
                input_cost = _safe_float(cfg.get("input_cost_per_token"))
                output_cost = _safe_float(cfg.get("output_cost_per_token"))
            if input_cost == 0.0 and output_cost == 0.0:
                # No resolvable pricing — skip rather than register zeros.
                skipped_no_pricing += 1
                continue
            aliases = _alias_set_for_openrouter(model_name)
            count = _register(
                aliases,
                input_cost=input_cost,
                output_cost=output_cost,
                provider="openrouter",
            )
            if count > 0:
                registered_models += 1
                registered_aliases += count
                if len(sample_keys) < 6:
                    sample_keys.extend(aliases[:2])
            continue

        # Non-OpenRouter (static YAML) path: operator-declared pricing may
        # live at the top level of the config or nested in litellm_params.
        input_cost = _safe_float(
            cfg.get("input_cost_per_token")
            or litellm_params.get("input_cost_per_token")
        )
        output_cost = _safe_float(
            cfg.get("output_cost_per_token")
            or litellm_params.get("output_cost_per_token")
        )
        if input_cost == 0.0 and output_cost == 0.0:
            skipped_no_pricing += 1
            continue
        aliases = _alias_set_for_yaml(provider, model_name, base_model)
        # Normalise the azure_openai slug to LiteLLM's canonical "azure".
        provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
        count = _register(
            aliases,
            input_cost=input_cost,
            output_cost=output_cost,
            provider=provider_slug,
        )
        if count > 0:
            registered_models += 1
            registered_aliases += count
            if len(sample_keys) < 6:
                sample_keys.extend(aliases[:2])

    logger.info(
        "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
        "%d configs had no pricing data; sample registered keys=%s",
        label,
        registered_models,
        registered_aliases,
        skipped_no_pricing,
        sample_keys,
    )
    return registered_models, registered_aliases, skipped_no_pricing, sample_keys
|
||||
|
||||
|
||||
def register_pricing_from_global_configs() -> None:
    """Register per-token pricing for every known LLM deployment with LiteLLM.

    Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
    so vision calls (during indexing) resolve cost the same way chat calls do:

    1. ``OPENROUTER`` configs: pricing comes from the cached raw pricing in
       ``OpenRouterIntegrationService`` (populated during its own startup
       fetch); vision configs that carry inline ``input_cost_per_token`` /
       ``output_cost_per_token`` fall back to those values on a cache miss.
    2. Everything else: operator-declared per-token costs on the YAML config
       block (top-level or nested under ``litellm_params``).

    Image generation is intentionally NOT registered here — its cost shape is
    per-image (``output_cost_per_image``), which ``litellm.register_model``
    doesn't accept via the chat-cost path. OpenRouter image-gen models get
    ``response_cost`` from their response header, and Azure-native image-gen
    models are already in LiteLLM's cost map.

    Configs without a resolved cost pair are skipped, not registered with
    zeros, so operators who forget pricing see a "$0 debit" warning in
    ``TokenTrackingCallback`` instead of silently clobbering any pricing
    LiteLLM knows natively.
    """
    from app.config import config as app_config

    chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
    vision_configs: list[dict] = list(
        getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
    )
    if not (chat_configs or vision_configs):
        logger.info("[PricingRegistration] no global configs to register")
        return

    # OpenRouter pricing cache is optional — it may not be initialized yet
    # (startup ordering); in that case OR configs rely on inline pricing.
    or_pricing: dict[str, dict[str, str]] = {}
    try:
        from app.services.openrouter_integration_service import (
            OpenRouterIntegrationService,
        )

        if OpenRouterIntegrationService.is_initialized():
            or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
    except Exception as exc:
        logger.debug(
            "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
        )

    for label, configs in (("chat", chat_configs), ("vision", vision_configs)):
        if configs:
            _register_chat_shape_configs(configs, or_pricing=or_pricing, label=label)
|
||||
107
surfsense_backend/app/services/provider_api_base.py
Normal file
107
surfsense_backend/app/services/provider_api_base.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
|
||||
|
||||
LiteLLM falls back to the module-global ``litellm.api_base`` when an
|
||||
individual call doesn't pass one, which silently inherits provider-agnostic
|
||||
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
|
||||
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
|
||||
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
|
||||
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
|
||||
``/chat/completions`` to whatever inherited base it gets, regardless of
|
||||
provider).
|
||||
|
||||
The chat router has had this defense for a while
|
||||
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
|
||||
into a tiny standalone helper so vision and image-gen can share the same
|
||||
source of truth without an inter-service circular import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
|
||||
|
||||
Only providers with a well-known, stable public base URL are listed —
|
||||
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
so their existing config-driven behaviour is preserved."""
|
||||
|
||||
|
||||
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
"""Canonical provider key (uppercase) → base URL.
|
||||
|
||||
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
|
||||
config's ``provider`` field tells us which API it actually is (DeepSeek,
|
||||
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
|
||||
has its own base URL)."""
|
||||
|
||||
|
||||
def resolve_api_base(
|
||||
*,
|
||||
provider: str | None,
|
||||
provider_prefix: str | None,
|
||||
config_api_base: str | None,
|
||||
) -> str | None:
|
||||
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
|
||||
|
||||
Cascade (first non-empty wins):
|
||||
1. The config's own ``api_base`` (whitespace-only treated as missing).
|
||||
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
|
||||
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
|
||||
4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
|
||||
provider integration apply its own default (e.g. AzureOpenAI's
|
||||
deployment-derived URL, custom provider's per-deployment URL).
|
||||
|
||||
Args:
|
||||
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
|
||||
``"DEEPSEEK"``). Case-insensitive.
|
||||
provider_prefix: The LiteLLM model-string prefix the same call
|
||||
site builds for the model id (e.g. ``"openrouter"``,
|
||||
``"groq"``). Case-insensitive.
|
||||
config_api_base: ``api_base`` from the global YAML / DB row /
|
||||
OpenRouter dynamic config. Empty / whitespace-only means
|
||||
"missing" — the resolver still applies the cascade.
|
||||
|
||||
Returns:
|
||||
A URL string, or ``None`` if no default applies for this provider.
|
||||
"""
|
||||
if config_api_base and config_api_base.strip():
|
||||
return config_api_base
|
||||
|
||||
if provider:
|
||||
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
|
||||
if key_default:
|
||||
return key_default
|
||||
|
||||
if provider_prefix:
|
||||
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
|
||||
if prefix_default:
|
||||
return prefix_default
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Explicit public API: star-imports and generated docs surface only these.
__all__ = [
    "PROVIDER_DEFAULT_API_BASE",
    "PROVIDER_KEY_DEFAULT_API_BASE",
    "resolve_api_base",
]
|
||||
105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""
|
||||
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
|
||||
|
||||
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
|
||||
indexing pipeline (file processors, connector indexers, etl pipeline) can
|
||||
keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
|
||||
— without threading ``user_id`` through every parser. The wrapper looks like
|
||||
a chat model from the outside; on the inside it routes each call through
|
||||
``billable_call`` so the user's premium credit pool is reserved → finalized
|
||||
or released, and a ``TokenUsage`` audit row is written.
|
||||
|
||||
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
|
||||
need quota enforcement) so this class only ever wraps premium configs.
|
||||
|
||||
Why a wrapper instead of plumbing ``user_id`` through every caller:
|
||||
|
||||
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
|
||||
Dropbox, local-folder, file-processor, ETL pipeline) each calling
|
||||
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
|
||||
invasive, error-prone, and easy for a future indexer to forget.
|
||||
* Per the design (issue M), we always debit the *search-space owner*, not
|
||||
the triggering user, so ``user_id`` is fully derivable from the search
|
||||
space the caller is already operating on. The wrapper captures it once
|
||||
at construction time.
|
||||
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
|
||||
call run this coroutine"; subclassing isn't safe across versions because
|
||||
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
|
||||
Composition via attribute proxying (``__getattr__``) is robust to
|
||||
upstream changes — every method other than ``ainvoke`` falls through to
|
||||
the inner LLM unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.billable_calls import QuotaInsufficientError, billable_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotaCheckedVisionLLM:
    """Composition wrapper around a langchain chat model that enforces
    premium credit quota on every ``ainvoke``.

    Anything other than ``ainvoke`` is forwarded to the inner model so
    ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
    still work — they simply bypass quota enforcement, which is fine
    because the indexing pipeline only ever calls ``ainvoke`` today.
    """

    def __init__(
        self,
        inner_llm: Any,
        *,
        user_id: UUID,
        search_space_id: int,
        billing_tier: str,
        base_model: str,
        quota_reserve_tokens: int | None,
        usage_type: str = "vision_extraction",
    ) -> None:
        # Captured once at construction; every ainvoke bills against the
        # same (user, search space) pair — per design, the search-space
        # owner, not the triggering user.
        self._inner = inner_llm
        self._user_id = user_id
        self._search_space_id = search_space_id
        self._billing_tier = billing_tier
        self._base_model = base_model
        # None → billable_call falls back to its own default reserve size.
        self._quota_reserve_tokens = quota_reserve_tokens
        self._usage_type = usage_type

    async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
        """Proxied async invoke that runs the underlying call inside
        ``billable_call``.

        Raises:
            QuotaInsufficientError: when the user has exhausted their
                premium credit pool. Caller (``etl_pipeline_service._extract_image``)
                catches this and falls back to the document parser.
        """
        # billable_call reserves credit before the call and finalizes or
        # releases it afterwards, writing a TokenUsage audit row.
        async with billable_call(
            user_id=self._user_id,
            search_space_id=self._search_space_id,
            billing_tier=self._billing_tier,
            base_model=self._base_model,
            quota_reserve_tokens=self._quota_reserve_tokens,
            usage_type=self._usage_type,
            call_details={"model": self._base_model},
        ):
            return await self._inner.ainvoke(input, *args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Forward everything else (``invoke``, ``astream``, ``bind``,
        ``with_structured_output``, …) to the inner model.

        ``__getattr__`` is only consulted when the attribute is *not*
        already found on the proxy, which is exactly the contract we
        want — methods we override stay on the proxy, the rest fall
        through.
        """
        return getattr(self._inner, name)
|
||||
|
||||
|
||||
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
|
||||
|
|
@ -22,6 +22,71 @@ from app.config import config
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call reservation estimator (USD micro-units)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
|
||||
# request, and so models without registered pricing reserve at least
|
||||
# something while the call runs (debited 0 at finalize anyway when their
|
||||
# cost can't be resolved).
|
||||
_QUOTA_MIN_RESERVE_MICROS = 100
|
||||
|
||||
|
||||
def estimate_call_reserve_micros(
    *,
    base_model: str,
    quota_reserve_tokens: int | None,
) -> int:
    """Return the number of micro-USD to reserve for one premium call.

    Derives a worst-case upper bound from LiteLLM's per-token pricing
    table::

        reserve_usd ≈ reserve_tokens x (input_cost + output_cost)

    so the reservation scales with model cost — an expensive model with a
    4K token reserve holds tens of cents, a cheap one only a few. The
    result is clamped to ``[_QUOTA_MIN_RESERVE_MICROS,
    config.QUOTA_MAX_RESERVE_MICROS]`` so a misconfigured "$1000/M" model
    can't lock the whole balance on one call.

    When ``litellm.get_model_info`` raises (unknown model) or reports no
    pricing, the floor (100 micros / $0.0001) is returned — enough to gate
    a sane request without over-reserving for a model whose pricing the
    operator hasn't declared yet.
    """
    # Missing / non-positive reserve sizes fall back to the global cap.
    if quota_reserve_tokens and quota_reserve_tokens > 0:
        reserve_tokens = quota_reserve_tokens
    else:
        reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL

    input_cost = 0.0
    output_cost = 0.0
    try:
        from litellm import get_model_info

        info = get_model_info(base_model) if base_model else {}
        input_cost = float(info.get("input_cost_per_token") or 0.0)
        output_cost = float(info.get("output_cost_per_token") or 0.0)
    except Exception as exc:
        logger.debug(
            "[quota_reserve] cost lookup failed for base_model=%s: %s",
            base_model,
            exc,
        )
        input_cost = 0.0
        output_cost = 0.0

    if input_cost == 0.0 and output_cost == 0.0:
        return _QUOTA_MIN_RESERVE_MICROS

    micros = round(reserve_tokens * (input_cost + output_cost) * 1_000_000)
    # Floor first, then ceiling — matches the documented clamp semantics.
    return min(max(micros, _QUOTA_MIN_RESERVE_MICROS), config.QUOTA_MAX_RESERVE_MICROS)
|
||||
|
||||
|
||||
class QuotaScope(StrEnum):
|
||||
ANONYMOUS = "anonymous"
|
||||
PREMIUM = "premium"
|
||||
|
|
@ -444,8 +509,16 @@ class TokenQuotaService:
|
|||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
request_id: str,
|
||||
reserve_tokens: int,
|
||||
reserve_micros: int,
|
||||
) -> QuotaResult:
|
||||
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
|
||||
premium credit balance.
|
||||
|
||||
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
|
||||
all in micro-USD on this code path; callers (chat stream,
|
||||
token-status route, FE display) convert to dollars by dividing
|
||||
by 1_000_000.
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
|
|
@ -465,11 +538,11 @@ class TokenQuotaService:
|
|||
limit=0,
|
||||
)
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
limit = user.premium_credit_micros_limit
|
||||
used = user.premium_credit_micros_used
|
||||
reserved = user.premium_credit_micros_reserved
|
||||
|
||||
effective = used + reserved + reserve_tokens
|
||||
effective = used + reserved + reserve_micros
|
||||
if effective > limit:
|
||||
remaining = max(0, limit - used - reserved)
|
||||
await db_session.rollback()
|
||||
|
|
@ -482,10 +555,10 @@ class TokenQuotaService:
|
|||
remaining=remaining,
|
||||
)
|
||||
|
||||
user.premium_tokens_reserved = reserved + reserve_tokens
|
||||
user.premium_credit_micros_reserved = reserved + reserve_micros
|
||||
await db_session.commit()
|
||||
|
||||
new_reserved = reserved + reserve_tokens
|
||||
new_reserved = reserved + reserve_micros
|
||||
remaining = max(0, limit - used - new_reserved)
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
||||
|
|
@ -510,9 +583,12 @@ class TokenQuotaService:
|
|||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
request_id: str,
|
||||
actual_tokens: int,
|
||||
reserved_tokens: int,
|
||||
actual_micros: int,
|
||||
reserved_micros: int,
|
||||
) -> QuotaResult:
|
||||
"""Settle the reservation: release ``reserved_micros`` and debit
|
||||
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
|
|
@ -529,16 +605,18 @@ class TokenQuotaService:
|
|||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||
)
|
||||
|
||||
user.premium_tokens_reserved = max(
|
||||
0, user.premium_tokens_reserved - reserved_tokens
|
||||
user.premium_credit_micros_reserved = max(
|
||||
0, user.premium_credit_micros_reserved - reserved_micros
|
||||
)
|
||||
user.premium_credit_micros_used = (
|
||||
user.premium_credit_micros_used + actual_micros
|
||||
)
|
||||
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
limit = user.premium_credit_micros_limit
|
||||
used = user.premium_credit_micros_used
|
||||
reserved = user.premium_credit_micros_reserved
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
|
@ -562,8 +640,13 @@ class TokenQuotaService:
|
|||
async def premium_release(
|
||||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
reserved_tokens: int,
|
||||
reserved_micros: int,
|
||||
) -> None:
|
||||
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
|
||||
|
||||
Used when a request fails before finalize (so the reservation
|
||||
doesn't leak credit).
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
|
|
@ -576,8 +659,8 @@ class TokenQuotaService:
|
|||
.scalar_one_or_none()
|
||||
)
|
||||
if user is not None:
|
||||
user.premium_tokens_reserved = max(
|
||||
0, user.premium_tokens_reserved - reserved_tokens
|
||||
user.premium_credit_micros_reserved = max(
|
||||
0, user.premium_credit_micros_reserved - reserved_micros
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
|
|
@ -598,9 +681,9 @@ class TokenQuotaService:
|
|||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||
)
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
limit = user.premium_credit_micros_limit
|
||||
used = user.premium_credit_micros_used
|
||||
reserved = user.premium_credit_micros_reserved
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
|
|
|||
|
|
@ -16,11 +16,14 @@ from __future__ import annotations
|
|||
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -35,6 +38,8 @@ class TokenCallRecord:
|
|||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
cost_micros: int = 0
|
||||
call_kind: str = "chat"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -49,6 +54,8 @@ class TurnTokenAccumulator:
|
|||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
cost_micros: int = 0,
|
||||
call_kind: str = "chat",
|
||||
) -> None:
|
||||
self.calls.append(
|
||||
TokenCallRecord(
|
||||
|
|
@ -56,20 +63,28 @@ class TurnTokenAccumulator:
|
|||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
call_kind=call_kind,
|
||||
)
|
||||
)
|
||||
|
||||
    def per_message_summary(self) -> dict[str, dict[str, int]]:
        """Return token counts (and cost) grouped by model name."""
        by_model: dict[str, dict[str, int]] = {}
        for c in self.calls:
            # setdefault keeps one running entry per model; cost_micros is
            # accumulated alongside the token counters so callers can show
            # per-model spend without re-deriving it from pricing tables.
            entry = by_model.setdefault(
                c.model,
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost_micros": 0,
                },
            )
            entry["prompt_tokens"] += c.prompt_tokens
            entry["completion_tokens"] += c.completion_tokens
            entry["total_tokens"] += c.total_tokens
            entry["cost_micros"] += c.cost_micros
        return by_model
|
||||
|
||||
@property
|
||||
|
|
@ -84,6 +99,21 @@ class TurnTokenAccumulator:
|
|||
def total_completion_tokens(self) -> int:
|
||||
return sum(c.completion_tokens for c in self.calls)
|
||||
|
||||
@property
|
||||
def total_cost_micros(self) -> int:
|
||||
"""Sum of per-call ``cost_micros`` across the entire turn.
|
||||
|
||||
Used by ``stream_new_chat`` to debit a premium turn's actual
|
||||
provider cost (in micro-USD) from the user's premium credit
|
||||
balance. ``cost_micros`` per call is captured by
|
||||
``TokenTrackingCallback.async_log_success_event`` from
|
||||
``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
|
||||
with multiple fallback paths so OpenRouter dynamic models and
|
||||
custom Azure deployments still bill correctly when our
|
||||
``pricing_registration`` ran at startup.
|
||||
"""
|
||||
return sum(c.cost_micros for c in self.calls)
|
||||
|
||||
def serialized_calls(self) -> list[dict[str, Any]]:
|
||||
return [dataclasses.asdict(c) for c in self.calls]
|
||||
|
||||
|
|
@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
|
|||
|
||||
|
||||
def start_turn() -> TurnTokenAccumulator:
|
||||
"""Create a fresh accumulator for the current async context and return it."""
|
||||
"""Create a fresh accumulator for the current async context and return it.
|
||||
|
||||
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
|
||||
short-lived per-call billable wrappers (image generation REST endpoint,
|
||||
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
|
||||
ContextVar reset token to restore the *previous* accumulator on exit and
|
||||
avoids leaking call records across reservations (issue B).
|
||||
"""
|
||||
acc = TurnTokenAccumulator()
|
||||
_turn_accumulator.set(acc)
|
||||
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
|
||||
|
|
@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
|
|||
return _turn_accumulator.get()
|
||||
|
||||
|
||||
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
    """Scope a fresh ``TurnTokenAccumulator`` for the duration of the
    ``async with`` block, restoring the previous accumulator on exit.

    This is the safe primitive for per-call billable operations (image
    generation, vision LLM extraction, podcasts) that may run inside an
    outer chat turn or be called sequentially from the same background
    worker. Unlike :func:`start_turn`, which calls ``ContextVar.set``
    without ``reset``, the exit path here restores the *previous*
    accumulator — so an inner billable call can never leak its records
    into the outer scope and make the chat turn debit cost twice.

    Usage::

        async with scoped_turn() as acc:
            await llm.ainvoke(...)
            # acc.total_cost_micros captures cost from the LiteLLM callback
        # Outer accumulator (if any) is restored here.
    """
    fresh = TurnTokenAccumulator()
    reset_token = _turn_accumulator.set(fresh)
    logger.debug(
        "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
        id(fresh),
        reset_token,
    )
    try:
        yield fresh
    finally:
        _turn_accumulator.reset(reset_token)
        logger.debug(
            "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
            id(fresh),
            len(fresh.calls),
            fresh.total_cost_micros,
        )
|
||||
|
||||
|
||||
def _extract_cost_usd(
|
||||
kwargs: dict[str, Any],
|
||||
response_obj: Any,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
is_image: bool = False,
|
||||
) -> float:
|
||||
"""Best-effort USD cost extraction for a single LLM/image call.
|
||||
|
||||
Tries four sources in priority order and returns the first that
|
||||
yields a positive number; returns 0.0 if all four fail (the call
|
||||
will then debit nothing from the user's balance — fail-safe).
|
||||
|
||||
Sources:
|
||||
1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
|
||||
field, populated for ``Router.acompletion`` since PR #12500.
|
||||
2. ``response_obj._hidden_params["response_cost"]`` — same value
|
||||
exposed on the response itself.
|
||||
3. ``litellm.completion_cost(completion_response=response_obj)``
|
||||
— recompute from the response and LiteLLM's pricing table.
|
||||
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
|
||||
— manual fallback for OpenRouter/custom-Azure models that
|
||||
only resolve via aliases registered by
|
||||
``pricing_registration`` at startup. **Skipped for image
|
||||
responses** — ``cost_per_token`` does not support ``ImageResponse``
|
||||
and would raise; the cost map for image-gen lives in different
|
||||
keys (``output_cost_per_image``) handled by ``completion_cost``.
|
||||
"""
|
||||
cost = kwargs.get("response_cost")
|
||||
if cost is not None:
|
||||
try:
|
||||
value = float(cost)
|
||||
except (TypeError, ValueError):
|
||||
value = 0.0
|
||||
if value > 0:
|
||||
return value
|
||||
|
||||
hidden = getattr(response_obj, "_hidden_params", None) or {}
|
||||
if isinstance(hidden, dict):
|
||||
cost = hidden.get("response_cost")
|
||||
if cost is not None:
|
||||
try:
|
||||
value = float(cost)
|
||||
except (TypeError, ValueError):
|
||||
value = 0.0
|
||||
if value > 0:
|
||||
return value
|
||||
|
||||
try:
|
||||
value = float(litellm.completion_cost(completion_response=response_obj))
|
||||
if value > 0:
|
||||
return value
|
||||
except Exception as exc:
|
||||
if is_image:
|
||||
# Image-gen path: OpenRouter's image responses can omit
|
||||
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
|
||||
# then *raises* (no cost map for OpenRouter image models).
|
||||
# Bail out with a warning rather than falling through to
|
||||
# cost_per_token (which is also incompatible with ImageResponse).
|
||||
logger.warning(
|
||||
"[TokenTracking] completion_cost failed for image model=%s "
|
||||
"(provider may have omitted usage.cost). Debiting 0. "
|
||||
"Cause: %s",
|
||||
model,
|
||||
exc,
|
||||
)
|
||||
return 0.0
|
||||
logger.debug(
|
||||
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
|
||||
)
|
||||
|
||||
if is_image:
|
||||
# Never call cost_per_token for ImageResponse — keys mismatch and
|
||||
# the function is documented chat-only.
|
||||
return 0.0
|
||||
|
||||
if model and (prompt_tokens > 0 or completion_tokens > 0):
|
||||
try:
|
||||
prompt_cost, completion_cost = litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
value = float(prompt_cost) + float(completion_cost)
|
||||
if value > 0:
|
||||
return value
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
|
||||
)
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
class TokenTrackingCallback(CustomLogger):
|
||||
"""LiteLLM callback that captures token usage into the turn accumulator."""
|
||||
|
||||
|
|
@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
|
|||
)
|
||||
return
|
||||
|
||||
# Detect image generation responses — they have a different usage
|
||||
# shape (ImageUsage with input_tokens/output_tokens) and require a
|
||||
# different cost-extraction path. We probe by class name to avoid a
|
||||
# hard import dependency on litellm internals.
|
||||
response_cls = type(response_obj).__name__
|
||||
is_image = response_cls == "ImageResponse"
|
||||
|
||||
usage = getattr(response_obj, "usage", None)
|
||||
if not usage:
|
||||
logger.debug(
|
||||
|
|
@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
|
|||
)
|
||||
return
|
||||
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||
if is_image:
|
||||
# ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
|
||||
# (not prompt_tokens/completion_tokens). Several providers
|
||||
# populate only one or neither (e.g. OpenRouter's gpt-image-1
|
||||
# passes through `input_tokens` from the prompt but no
|
||||
# completion); fall through gracefully to 0.
|
||||
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "output_tokens", 0) or 0
|
||||
total_tokens = (
|
||||
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
|
||||
)
|
||||
call_kind = "image_generation"
|
||||
else:
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||
call_kind = "chat"
|
||||
|
||||
model = kwargs.get("model", "unknown")
|
||||
|
||||
cost_usd = _extract_cost_usd(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
is_image=is_image,
|
||||
)
|
||||
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
|
||||
|
||||
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
|
||||
logger.warning(
|
||||
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
|
||||
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
|
||||
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
|
||||
"for image generation).",
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
call_kind,
|
||||
)
|
||||
|
||||
acc.add(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
call_kind=call_kind,
|
||||
)
|
||||
logger.info(
|
||||
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
|
||||
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
||||
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
|
||||
model,
|
||||
call_kind,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cost_usd,
|
||||
cost_micros,
|
||||
len(acc.calls),
|
||||
)
|
||||
|
||||
|
|
@ -168,6 +388,7 @@ async def record_token_usage(
|
|||
prompt_tokens: int = 0,
|
||||
completion_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
cost_micros: int = 0,
|
||||
model_breakdown: dict[str, Any] | None = None,
|
||||
call_details: dict[str, Any] | None = None,
|
||||
thread_id: int | None = None,
|
||||
|
|
@ -185,6 +406,7 @@ async def record_token_usage(
|
|||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
model_breakdown=model_breakdown,
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
|
|
@ -194,11 +416,12 @@ async def record_token_usage(
|
|||
)
|
||||
session.add(record)
|
||||
logger.debug(
|
||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
|
||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
|
||||
usage_type,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cost_micros,
|
||||
)
|
||||
return record
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Any
|
|||
|
||||
from litellm import Router
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
|
@ -108,10 +110,11 @@ class VisionLLMRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
|
|
@ -120,8 +123,13 @@ class VisionLLMRouterService:
|
|||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,13 @@ from sqlalchemy import select
|
|||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.services.billable_calls import (
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -96,6 +102,31 @@ async def _generate_content_podcast(
|
|||
podcast.status = PodcastStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
(
|
||||
owner_user_id,
|
||||
billing_tier,
|
||||
base_model,
|
||||
) = await _resolve_agent_billing_for_search_space(
|
||||
session,
|
||||
search_space_id,
|
||||
thread_id=podcast.thread_id,
|
||||
)
|
||||
except ValueError as resolve_err:
|
||||
logger.error(
|
||||
"Podcast %s: cannot resolve billing for search_space=%s: %s",
|
||||
podcast.id,
|
||||
search_space_id,
|
||||
resolve_err,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"podcast_title": podcast.title,
|
||||
|
|
@ -109,9 +140,39 @@ async def _generate_content_podcast(
|
|||
db_session=session,
|
||||
)
|
||||
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||
usage_type="podcast_generation",
|
||||
thread_id=podcast.thread_id,
|
||||
call_details={
|
||||
"podcast_id": podcast.id,
|
||||
"title": podcast.title,
|
||||
},
|
||||
):
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
except QuotaInsufficientError as exc:
|
||||
logger.info(
|
||||
"Podcast %s denied: out of premium credits "
|
||||
"(used=%d/%d remaining=%d)",
|
||||
podcast.id,
|
||||
exc.used_micros,
|
||||
exc.limit_micros,
|
||||
exc.remaining_micros,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
|
||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||
file_path = graph_result.get("final_podcast_file_path", "")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,13 @@ from sqlalchemy import select
|
|||
from app.agents.video_presentation.graph import graph as video_presentation_graph
|
||||
from app.agents.video_presentation.state import State as VideoPresentationState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
from app.services.billable_calls import (
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -97,6 +103,32 @@ async def _generate_video_presentation(
|
|||
video_pres.status = VideoPresentationStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
(
|
||||
owner_user_id,
|
||||
billing_tier,
|
||||
base_model,
|
||||
) = await _resolve_agent_billing_for_search_space(
|
||||
session,
|
||||
search_space_id,
|
||||
thread_id=video_pres.thread_id,
|
||||
)
|
||||
except ValueError as resolve_err:
|
||||
logger.error(
|
||||
"VideoPresentation %s: cannot resolve billing for "
|
||||
"search_space=%s: %s",
|
||||
video_pres.id,
|
||||
search_space_id,
|
||||
resolve_err,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"video_title": video_pres.title,
|
||||
|
|
@ -110,9 +142,39 @@ async def _generate_video_presentation(
|
|||
db_session=session,
|
||||
)
|
||||
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
||||
usage_type="video_presentation_generation",
|
||||
thread_id=video_pres.thread_id,
|
||||
call_details={
|
||||
"video_presentation_id": video_pres.id,
|
||||
"title": video_pres.title,
|
||||
},
|
||||
):
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
except QuotaInsufficientError as exc:
|
||||
logger.info(
|
||||
"VideoPresentation %s denied: out of premium credits "
|
||||
"(used=%d/%d remaining=%d)",
|
||||
video_pres.id,
|
||||
exc.used_micros,
|
||||
exc.limit_micros,
|
||||
exc.remaining_micros,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
|
||||
# Serialize slides (parsed content + audio info merged)
|
||||
slides_raw = graph_result.get("slides", [])
|
||||
|
|
|
|||
|
|
@ -2236,8 +2236,10 @@ async def stream_new_chat(
|
|||
|
||||
accumulator = start_turn()
|
||||
|
||||
# Premium quota tracking state
|
||||
_premium_reserved = 0
|
||||
# Premium credit (USD micro-units) tracking state. Stores the
|
||||
# amount reserved up front so we can release it on cancellation
|
||||
# and finalize-debit the actual provider cost reported by LiteLLM.
|
||||
_premium_reserved_micros = 0
|
||||
_premium_request_id: str | None = None
|
||||
|
||||
_emit_stream_error = partial(
|
||||
|
|
@ -2331,23 +2333,28 @@ async def stream_new_chat(
|
|||
if _needs_premium_quota:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
|
||||
_premium_request_id = _uuid.uuid4().hex[:16]
|
||||
reserve_amount = min(
|
||||
agent_config.quota_reserve_tokens
|
||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_agent_litellm_params = agent_config.litellm_params or {}
|
||||
_agent_base_model = (
|
||||
_agent_litellm_params.get("base_model") or agent_config.model_name or ""
|
||||
)
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=_agent_base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
reserve_tokens=reserve_amount,
|
||||
reserve_micros=reserve_amount_micros,
|
||||
)
|
||||
_premium_reserved = reserve_amount
|
||||
_premium_reserved_micros = reserve_amount_micros
|
||||
if not quota_result.allowed:
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
|
|
@ -2382,7 +2389,7 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
_premium_reserved_micros = 0
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="premium_quota_exhausted",
|
||||
|
|
@ -3020,9 +3027,10 @@ async def stream_new_chat(
|
|||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3033,6 +3041,7 @@ async def stream_new_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3060,7 +3069,11 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Finalize premium quota with actual tokens.
|
||||
# Finalize premium credit debit with the actual provider cost
|
||||
# reported by LiteLLM, summed across every call in the turn.
|
||||
# Mirrors the pre-cost behaviour of "premium turn → all calls
|
||||
# count" so free sub-agent calls during a premium turn still
|
||||
# contribute to the bill (they're $0 in practice anyway).
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
|
@ -3070,11 +3083,11 @@ async def stream_new_chat(
|
|||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_premium_reserved,
|
||||
actual_micros=accumulator.total_cost_micros,
|
||||
reserved_micros=_premium_reserved_micros,
|
||||
)
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
|
|
@ -3084,9 +3097,10 @@ async def stream_new_chat(
|
|||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3097,6 +3111,7 @@ async def stream_new_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3190,7 +3205,7 @@ async def stream_new_chat(
|
|||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||
if _premium_request_id and _premium_reserved_micros > 0 and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
|
|
@ -3198,9 +3213,9 @@ async def stream_new_chat(
|
|||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_tokens=_premium_reserved,
|
||||
reserved_micros=_premium_reserved_micros,
|
||||
)
|
||||
_premium_reserved = 0
|
||||
_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s", user_id
|
||||
|
|
@ -3369,8 +3384,8 @@ async def stream_resume_chat(
|
|||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Premium quota reservation (same logic as stream_new_chat)
|
||||
_resume_premium_reserved = 0
|
||||
# Premium credit reservation (same logic as stream_new_chat).
|
||||
_resume_premium_reserved_micros = 0
|
||||
_resume_premium_request_id: str | None = None
|
||||
_resume_needs_premium = (
|
||||
agent_config is not None and user_id and agent_config.is_premium
|
||||
|
|
@ -3378,23 +3393,30 @@ async def stream_resume_chat(
|
|||
if _resume_needs_premium:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
|
||||
_resume_premium_request_id = _uuid.uuid4().hex[:16]
|
||||
reserve_amount = min(
|
||||
agent_config.quota_reserve_tokens
|
||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_resume_litellm_params = agent_config.litellm_params or {}
|
||||
_resume_base_model = (
|
||||
_resume_litellm_params.get("base_model")
|
||||
or agent_config.model_name
|
||||
or ""
|
||||
)
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=_resume_base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
reserve_tokens=reserve_amount,
|
||||
reserve_micros=reserve_amount_micros,
|
||||
)
|
||||
_resume_premium_reserved = reserve_amount
|
||||
_resume_premium_reserved_micros = reserve_amount_micros
|
||||
if not quota_result.allowed:
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
|
|
@ -3429,7 +3451,7 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_reserved_micros = 0
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="premium_quota_exhausted",
|
||||
|
|
@ -3746,9 +3768,10 @@ async def stream_resume_chat(
|
|||
if stream_result.is_interrupted:
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3759,6 +3782,7 @@ async def stream_resume_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3768,7 +3792,9 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Finalize premium quota for resume path
|
||||
# Finalize premium credit debit for resume path with the actual
|
||||
# provider cost reported by LiteLLM (sum of cost across all
|
||||
# calls in the turn).
|
||||
if _resume_premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
|
@ -3778,11 +3804,11 @@ async def stream_resume_chat(
|
|||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
actual_micros=accumulator.total_cost_micros,
|
||||
reserved_micros=_resume_premium_reserved_micros,
|
||||
)
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s (resume)",
|
||||
|
|
@ -3792,9 +3818,10 @@ async def stream_resume_chat(
|
|||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3805,6 +3832,7 @@ async def stream_resume_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3855,7 +3883,11 @@ async def stream_resume_chat(
|
|||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||
if (
|
||||
_resume_premium_request_id
|
||||
and _resume_premium_reserved_micros > 0
|
||||
and user_id
|
||||
):
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
|
|
@ -3863,9 +3895,9 @@ async def stream_resume_chat(
|
|||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
reserved_micros=_resume_premium_reserved_micros,
|
||||
)
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s (resume)", user_id
|
||||
|
|
|
|||
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""Unit tests for the image-generation route's billing-resolution helper.
|
||||
|
||||
End-to-end "POST /image-generations returns 402" coverage requires the
|
||||
integration harness (real DB, real auth) and lives in
|
||||
``tests/integration/document_upload/`` alongside the other quota tests.
|
||||
This unit test focuses on the new ``_resolve_billing_for_image_gen``
|
||||
helper which:
|
||||
|
||||
* Returns ``free`` for Auto mode, even when premium configs exist
|
||||
(Auto-mode billing-tier surfacing is a follow-up).
|
||||
* Returns ``free`` for user-owned BYOK configs (positive IDs).
|
||||
* Returns the global config's ``billing_tier`` for negative IDs.
|
||||
* Honours the per-config ``quota_reserve_micros`` override when present.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_for_auto_mode(monkeypatch):
|
||||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, # Not consumed on this code path.
|
||||
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space=search_space,
|
||||
)
|
||||
assert tier == "free"
|
||||
assert model == "auto"
|
||||
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"quota_reserve_micros": 75_000,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
|
||||
# Premium with override.
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=-1, search_space=search_space
|
||||
)
|
||||
assert tier == "premium"
|
||||
assert model == "openai/gpt-image-1"
|
||||
assert reserve == 75_000
|
||||
|
||||
# Free, no override → falls back to default.
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=-2, search_space=search_space
|
||||
)
|
||||
assert tier == "free"
|
||||
# Provider-prefixed model string for OpenRouter.
|
||||
assert "google/gemini-2.5-flash-image" in model
|
||||
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_for_user_owned_byok_is_free():
|
||||
"""User-owned BYOK configs (positive IDs) cost the user nothing on
|
||||
our side — they pay the provider directly. Always free.
|
||||
"""
|
||||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=42, search_space=search_space
|
||||
)
|
||||
assert tier == "free"
|
||||
assert model == "user_byok"
|
||||
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
||||
"""When the request omits ``image_generation_config_id``, the helper
|
||||
must consult the search space's default — so a search space pinned
|
||||
to a premium global config still gates new requests by quota.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -7,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=-7)
|
||||
(
|
||||
tier,
|
||||
model,
|
||||
_reserve,
|
||||
) = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=None, search_space=search_space
|
||||
)
|
||||
assert tier == "premium"
|
||||
assert model == "openai/gpt-image-1"
|
||||
|
|
@ -0,0 +1,436 @@
|
|||
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
|
||||
|
||||
Validates the resolver used by Celery podcast/video tasks to compute
|
||||
``(owner_user_id, billing_tier, base_model)`` from a search space and its
|
||||
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
|
||||
``stream_new_chat.py:2294-2351`` and is the single integration point that
|
||||
prevents Auto-mode podcast/video from leaking premium credit.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
|
||||
global → returns ``("premium", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
|
||||
global → returns ``("free", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
|
||||
→ always ``"free"``.
|
||||
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
|
||||
hitting the pin service.
|
||||
* Negative id (no Auto) → uses ``get_global_llm_config``'s
|
||||
``billing_tier``.
|
||||
* Positive id (user BYOK) → always ``"free"``.
|
||||
* Search space not found → raises ``ValueError``.
|
||||
* ``agent_llm_id`` is None → raises ``ValueError``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Tiny AsyncSession stub.
|
||||
|
||||
``responses`` is a list of objects to return from successive
|
||||
``execute()`` calls (in order). The resolver makes at most two
|
||||
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
|
||||
lookup), so two queued responses cover the matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: list):
|
||||
self._responses = list(responses)
|
||||
|
||||
async def execute(self, _stmt):
|
||||
if not self._responses:
|
||||
return _FakeExecResult(None)
|
||||
return _FakeExecResult(self._responses.pop(0))
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakePinResolution:
|
||||
resolved_llm_config_id: int
|
||||
resolved_tier: str = "premium"
|
||||
from_existing_pin: bool = False
|
||||
|
||||
|
||||
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=42,
|
||||
agent_llm_id=agent_llm_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_byok_config(
|
||||
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
model_name=model_name,
|
||||
litellm_params={"base_model": base_model} if base_model else {},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
    """Auto + thread → pin service lands on a negative-id premium config, so
    the resolver reports ``("premium", <base_model>)``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    # Pin-service stub: hand back a concrete premium global config id.
    async def _pin_stub(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        assert selected_llm_config_id == 0
        assert thread_id == 99
        return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")

    # Global-config stub: a premium entry for id -1 only.
    def _global_stub(cfg_id):
        if cfg_id != -1:
            return None
        return {
            "id": -1,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    # The resolver imports these lazily — patch the *target* modules so the
    # imported names pick up our stubs.
    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)
    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _pin_stub
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (owner_id, "premium", "gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
    """Auto + thread → pin lands on a negative-id FREE config, so the resolver
    reports ``("free", <base_model>)`` — the same path the pin service takes
    for out-of-credit users (graceful degradation)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    async def _pin_stub(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")

    def _global_stub(cfg_id):
        if cfg_id != -3:
            return None
        return {
            "id": -3,
            "model_name": "openrouter/free-model",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/free-model"},
        }

    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)
    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _pin_stub
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (owner_id, "free", "openrouter/free-model")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
    """Auto + thread → pin lands on a positive-id BYOK config, which is always
    billed free (per ``AgentConfig.from_new_llm_config``)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    byok = _make_byok_config(
        id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
    )
    db = _FakeSession(
        [_make_search_space(agent_llm_id=0, user_id=owner_id), byok]
    )

    async def _pin_stub(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")

    import app.services.auto_model_pin_service as pin_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _pin_stub
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (
        owner_id,
        "free",
        "anthropic/claude-3-haiku",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
    """Auto + ``thread_id=None`` short-circuits to ``("free", "auto")`` without
    ever touching the pin service — a forward-compat fallback for any future
    direct-API entrypoint that has no chat thread."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    result = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=None
    )

    assert result == (owner_id, "free", "auto")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
    """A ``ValueError`` from the pin service (thread missing / mismatched
    search space) must degrade to free instead of killing the whole task."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    async def _exploding_pin(*_args, **_kwargs):
        raise ValueError("thread missing")

    import app.services.auto_model_pin_service as pin_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _exploding_pin
    )

    result = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=99
    )

    assert result == (owner_id, "free", "auto")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
    """An explicit negative ``agent_llm_id`` is looked up via
    ``get_global_llm_config`` and billed at that config's ``billing_tier``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=owner_id)])

    def _global_stub(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (owner_id, "premium", "gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
    """An explicit negative ``agent_llm_id`` pointing at a free global config
    bills as free."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=owner_id)])

    def _global_stub(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "openrouter/some-free",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/some-free"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42, thread_id=None
    )

    assert (owner, tier, base_model) == (owner_id, "free", "openrouter/some-free")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
    """Without ``litellm_params.base_model`` the resolver falls back to the
    config's ``model_name`` — mirroring chat's behavior."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=owner_id)])

    def _global_stub(cfg_id):
        # Deliberately omits "litellm_params".
        return {
            "id": cfg_id,
            "model_name": "fallback-model",
            "billing_tier": "premium",
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)

    _owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        db, search_space_id=42
    )

    assert (tier, base_model) == ("premium", "fallback-model")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
    """A positive ``agent_llm_id`` (user-owned BYOK NewLLMConfig) is always
    billed free, whatever the underlying provider tier."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    byok = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
    db = _FakeSession(
        [_make_search_space(agent_llm_id=23, user_id=owner_id), byok]
    )

    result = await _resolve_agent_billing_for_search_space(db, search_space_id=42)

    assert result == (owner_id, "free", "anthropic/claude-3.5-sonnet")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
    """A dangling BYOK reference (config row deleted while the search space
    still points at it) resolves to free with an empty base_model —
    billable_call's premium path is simply skipped, so nothing is debited."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    db = _FakeSession([_make_search_space(agent_llm_id=99, user_id=owner_id)])

    result = await _resolve_agent_billing_for_search_space(db, search_space_id=42)

    assert result == (owner_id, "free", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
    """An unknown search-space id raises ``ValueError``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    db = _FakeSession([None])

    with pytest.raises(ValueError, match="Search space"):
        await _resolve_agent_billing_for_search_space(db, search_space_id=999)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
    """A search space whose ``agent_llm_id`` is ``None`` raises ``ValueError``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    db = _FakeSession(
        [_make_search_space(agent_llm_id=None, user_id=uuid4())]
    )

    with pytest.raises(ValueError, match="agent_llm_id"):
        await _resolve_agent_billing_for_search_space(db, search_space_id=42)
|
||||
432
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
432
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
|
|
@ -0,0 +1,432 @@
|
|||
"""Unit tests for the ``billable_call`` async context manager.
|
||||
|
||||
Covers the per-call premium-credit lifecycle for image generation and
|
||||
vision LLM extraction:
|
||||
|
||||
* Free configs bypass reserve/finalize but still write an audit row.
|
||||
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
|
||||
route layer).
|
||||
* Successful premium calls reserve, yield the accumulator, then finalize
|
||||
with the LiteLLM-reported actual cost — and write an audit row.
|
||||
* Failed premium calls release the reservation so credit isn't leaked.
|
||||
* All quota DB ops happen inside their OWN ``shielded_async_session``,
|
||||
isolating them from the caller's transaction (issue A).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
# Mark every test in this module as a unit test (selected via `-m unit`).
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeQuotaResult:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
allowed: bool,
|
||||
used: int = 0,
|
||||
limit: int = 5_000_000,
|
||||
remaining: int = 5_000_000,
|
||||
) -> None:
|
||||
self.allowed = allowed
|
||||
self.used = used
|
||||
self.limit = limit
|
||||
self.remaining = remaining
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Minimal AsyncSession stub — record commits for assertion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.added: list[Any] = []
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _fake_shielded_session():
    """Yield a fresh ``_FakeSession``, recording it in ``_SESSIONS_USED`` so
    tests can prove each quota op opened its own session."""
    session = _FakeSession()
    _SESSIONS_USED.append(session)
    yield session
|
||||
|
||||
|
||||
# Every session handed out by _fake_shielded_session, in order — used by the
# session-isolation assertions. Cleared by _patch_isolation_layer per test.
_SESSIONS_USED: list[_FakeSession] = []
|
||||
|
||||
|
||||
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
    """Wire fake reserve/finalize/release/record/session helpers onto
    ``app.services.billable_calls`` and return spy lists keyed by operation.

    ``reserve_result`` is returned from the stubbed reserve; ``finalize_result``
    from finalize (defaulting to an allowed ``_FakeQuotaResult``).
    """
    _SESSIONS_USED.clear()
    spies: dict[str, list[dict[str, Any]]] = {
        "reserve": [],
        "finalize": [],
        "release": [],
        "record": [],
    }

    async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
        spies["reserve"].append(
            {
                "user_id": user_id,
                "reserve_micros": reserve_micros,
                "request_id": request_id,
            }
        )
        return reserve_result

    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
        spies["finalize"].append(
            {
                "user_id": user_id,
                "actual_micros": actual_micros,
                "reserved_micros": reserved_micros,
            }
        )
        return finalize_result or _FakeQuotaResult(allowed=True)

    async def _fake_release(*, db_session, user_id, reserved_micros):
        spies["release"].append(
            {"user_id": user_id, "reserved_micros": reserved_micros}
        )

    async def _fake_record(session, **kwargs):
        spies["record"].append(kwargs)
        return object()

    # Patch every collaborator billable_call reaches for, by attribute path.
    replacements = {
        "TokenQuotaService.premium_reserve": _fake_reserve,
        "TokenQuotaService.premium_finalize": _fake_finalize,
        "TokenQuotaService.premium_release": _fake_release,
        "shielded_async_session": _fake_shielded_session,
        "record_token_usage": _fake_record,
    }
    for attr, replacement in replacements.items():
        monkeypatch.setattr(
            f"app.services.billable_calls.{attr}", replacement, raising=False
        )

    return spies
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    """Free-tier calls bypass reserve/finalize/release entirely but still
    emit exactly one audit row carrying the captured cost."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
    ) as acc:
        # In production the LiteLLM callback feeds the accumulator; here we
        # inject the captured cost by hand.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    for op in ("reserve", "finalize", "release"):
        assert spies[op] == []
    # Free still audits.
    (audit_row,) = spies["record"]
    assert audit_row["usage_type"] == "image_generation"
    assert audit_row["cost_micros"] == 37_000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    """A denied reservation surfaces as ``QuotaInsufficientError`` (HTTP 402
    in the route layer) before the body ever runs."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=5_000_000, limit=5_000_000, remaining=0
        ),
    )

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "image_generation"
    assert (err.used_micros, err.limit_micros, err.remaining_micros) == (
        5_000_000,
        5_000_000,
        0,
    )
    # Reserve was attempted; a denied reserve never actually held credit, so
    # there is nothing to finalize or release.
    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Denied premium calls do NOT create an audit row (no work happened).
    assert spies["record"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    """Successful premium calls reserve up front, then finalize with the
    LiteLLM-reported actual cost and write an audit row for that amount.

    Also proves session isolation (issue A): each quota DB op ran inside its
    own ``shielded_async_session`` rather than the caller's transaction.
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # LiteLLM callback would normally fill this — simulate a $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []
    # And audit row written with the actual debited cost.
    assert spies["record"][0]["cost_micros"] == 40_000
    # Each quota op opened its OWN session — proves session isolation.
    assert len(_SESSIONS_USED) >= 3
    # The stubbed TokenQuotaService.* calls never commit our fake sessions,
    # but the audit write does — so at least one session must have committed.
    # (The original no-op ``for _s in _SESSIONS_USED: pass`` loop is removed;
    # it iterated without asserting anything.)
    assert any(s.committed for s in _SESSIONS_USED)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    """If the wrapped call blows up, the held reservation is released so the
    user's credit isn't leaked."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    class _ProviderError(Exception):
        pass

    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    # Failure path: the held reservation must come back in full.
    (released,) = spies["release"]
    assert released["reserved_micros"] == 50_000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """With no ``quota_reserve_micros_override``, the reserve amount comes
    from ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)`` —
    the path taken by vision LLM calls (token-priced models)."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    estimator_calls: list[dict[str, Any]] = []

    def _estimate_stub(*, base_model, quota_reserve_tokens):
        estimator_calls.append(
            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        )
        return 12_345

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _estimate_stub,
        raising=False,
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    assert estimator_calls == [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    assert spies["reserve"][0]["reserve_micros"] == 12_345
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Podcast / video-presentation usage_type coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast runs skip reserve/finalize but still emit a ``TokenUsage``
    row tagged ``usage_type='podcast_generation'`` so free-tier agent runs
    keep full audit coverage."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as acc:
        # Two transcript LLM calls aggregated into one accumulator.
        acc.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    for op in ("reserve", "finalize", "release"):
        assert spies[op] == []

    (row,) = spies["record"]
    assert row["usage_type"] == "podcast_generation"
    assert row["thread_id"] == 99
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """Premium video-presentation runs hitting a denied reservation raise
    ``QuotaInsufficientError`` *before* the graph runs and leave no audit
    row behind (no work happened)."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
        ),
    )

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="gpt-5.4",
            quota_reserve_micros_override=1_000_000,
            usage_type="video_presentation_generation",
            thread_id=99,
            call_details={"video_presentation_id": 12, "title": "Test Video"},
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "video_presentation_generation"
    assert err.remaining_micros == 500_000
    assert spies["reserve"][0]["reserve_micros"] == 1_000_000
    for op in ("finalize", "release", "record"):
        assert spies[op] == []
|
||||
|
|
@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
|
|||
assert "openai/gpt-4o" in model_names
|
||||
assert "openai/dall-e" not in model_names
|
||||
assert "openai/completion-only" not in model_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_image_gen_configs / _generate_vision_llm_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_generate_image_gen_configs_filters_by_image_output():
    """Only models whose ``output_modalities`` include ``image`` are emitted.

    Tool-calling and context filters are intentionally NOT applied — image
    generation has nothing to do with tool calls and context windows.
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    catalogue = [
        # Pure image-gen model (small context, no tools — should still emit).
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Multi-modal: text+image output (should still emit).
        {
            "id": "google/gemini-2.5-flash-image",
            "architecture": {"output_modalities": ["text", "image"]},
            "context_length": 1_000_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000004"},
        },
        # Pure text model — must NOT emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {"output_modalities": ["text"]},
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
    ]

    configs = _generate_image_gen_configs(catalogue, dict(_SETTINGS_BASE))
    emitted = {c["model_name"] for c in configs}
    assert "openai/gpt-image-1" in emitted
    assert "google/gemini-2.5-flash-image" in emitted
    assert "openai/gpt-4o" not in emitted

    # Each config must carry ``billing_tier`` for routing in
    # image_generation_routes, plus the provider and dynamic markers.
    for cfg in configs:
        assert cfg["billing_tier"] in {"free", "premium"}
        assert cfg["provider"] == "OPENROUTER"
        assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
|
||||
|
||||
def test_generate_image_gen_configs_assigns_image_id_offset():
    """Image configs use their own id_offset (-20000) so their negative IDs
    can't collide with chat configs (-10000) or vision configs (-30000).
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    catalogue = [
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        }
    ]
    # No explicit image_id_offset → the module default (-20000) applies.
    configs = _generate_image_gen_configs(catalogue, dict(_SETTINGS_BASE))
    for cfg in configs:
        assert cfg["id"] <= -20_000
        assert cfg["id"] > -29_000_000
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
    """Vision LLMs must accept image input AND emit text — pure image-gen
    (no text out) and text-only (no image in) models are both excluded.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    catalogue = [
        # GPT-4o: vision LLM (image in, text out) — must emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        # Pure image generator — image *output*, no text out. Must NOT emit.
        {
            "id": "openai/gpt-image-1",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["image"],
            },
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Pure text model (no image in). Must NOT emit.
        {
            "id": "anthropic/claude-3-haiku",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "context_length": 200_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000005"},
        },
    ]

    configs = _generate_vision_llm_configs(catalogue, dict(_SETTINGS_BASE))
    assert {c["model_name"] for c in configs} == {"openai/gpt-4o"}

    vision_cfg = configs[0]
    assert vision_cfg["billing_tier"] == "premium"
    # Pricing is carried inline so pricing_registration can register vision
    # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache is
    # cleared.
    assert vision_cfg["input_cost_per_token"] == pytest.approx(5e-6)
    assert vision_cfg["output_cost_per_token"] == pytest.approx(15e-6)
    assert vision_cfg[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_drops_chat_only_filters():
|
||||
"""A small-context vision model that doesn't advertise tool calling is
|
||||
still a valid vision LLM for "describe this image" prompts. The chat
|
||||
filters (``supports_tool_calling``, ``has_sufficient_context``) must
|
||||
NOT be applied to vision emission.
|
||||
"""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_vision_llm_configs,
|
||||
)
|
||||
|
||||
raw = [
|
||||
{
|
||||
"id": "tiny/vision-mini",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"supported_parameters": [], # no tools
|
||||
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
|
||||
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
|
||||
}
|
||||
]
|
||||
|
||||
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||
assert len(cfgs) == 1
|
||||
assert cfgs[0]["model_name"] == "tiny/vision-mini"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,447 @@
|
|||
"""Pricing registration unit tests.
|
||||
|
||||
The pricing-registration module is what makes ``response_cost`` populate
|
||||
correctly for OpenRouter dynamic models and operator-defined Azure
|
||||
deployments — both of which LiteLLM doesn't natively know about. The tests
|
||||
exercise:
|
||||
|
||||
* The alias generators emit every shape that LiteLLM's cost-callback might
|
||||
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
|
||||
``provider/base_model``, ``provider/model_name``, plus the special
|
||||
``azure_openai`` → ``azure`` normalisation).
|
||||
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
|
||||
with the right alias set and pricing values per provider.
|
||||
* Configs without a resolvable pair of cost values are skipped — never
|
||||
registered as zero, since that would override pricing LiteLLM might
|
||||
already know natively.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Alias generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_openrouter_alias_set_includes_prefixed_and_bare():
|
||||
from app.services.pricing_registration import _alias_set_for_openrouter
|
||||
|
||||
aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
|
||||
assert aliases == [
|
||||
"openrouter/anthropic/claude-3-5-sonnet",
|
||||
"anthropic/claude-3-5-sonnet",
|
||||
]
|
||||
|
||||
|
||||
def test_openrouter_alias_set_dedupes():
|
||||
"""If the model id is already prefixed with ``openrouter/``, the alias
|
||||
set must not contain duplicates that would re-register the same key
|
||||
twice.
|
||||
"""
|
||||
from app.services.pricing_registration import _alias_set_for_openrouter
|
||||
|
||||
aliases = _alias_set_for_openrouter("openrouter/foo")
|
||||
# The bare and prefixed variants compute to the same string here, so we
|
||||
# at minimum require uniqueness.
|
||||
assert len(aliases) == len(set(aliases))
|
||||
|
||||
|
||||
def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
|
||||
"""``azure_openai`` (our YAML provider slug) must register under
|
||||
``azure/<name>`` so the LiteLLM Router's deployment-resolution path
|
||||
(which uses provider ``azure``) finds the pricing too.
|
||||
"""
|
||||
from app.services.pricing_registration import _alias_set_for_yaml
|
||||
|
||||
aliases = _alias_set_for_yaml(
|
||||
provider="AZURE_OPENAI",
|
||||
model_name="gpt-5.4",
|
||||
base_model="gpt-5.4",
|
||||
)
|
||||
assert "gpt-5.4" in aliases
|
||||
assert "azure_openai/gpt-5.4" in aliases
|
||||
assert "azure/gpt-5.4" in aliases
|
||||
|
||||
|
||||
def test_yaml_alias_set_distinguishes_model_name_and_base_model():
|
||||
"""When ``model_name`` differs from ``base_model`` (operator labelled a
|
||||
deployment), both must appear in the alias set since either may surface
|
||||
in callbacks depending on the call path.
|
||||
"""
|
||||
from app.services.pricing_registration import _alias_set_for_yaml
|
||||
|
||||
aliases = _alias_set_for_yaml(
|
||||
provider="OPENAI",
|
||||
model_name="my-deployment-label",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
assert "gpt-4o" in aliases
|
||||
assert "openai/gpt-4o" in aliases
|
||||
assert "my-deployment-label" in aliases
|
||||
assert "openai/my-deployment-label" in aliases
|
||||
|
||||
|
||||
def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
|
||||
from app.services.pricing_registration import _alias_set_for_yaml
|
||||
|
||||
aliases = _alias_set_for_yaml(
|
||||
provider="",
|
||||
model_name="foo",
|
||||
base_model="bar",
|
||||
)
|
||||
assert "bar" in aliases
|
||||
assert "foo" in aliases
|
||||
assert all("/" not in a for a in aliases)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_pricing_from_global_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _RegistrationSpy:
|
||||
"""Captures the dicts passed to ``litellm.register_model``.
|
||||
|
||||
Many calls may go through; we just record them all and let tests assert
|
||||
against the union.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
|
||||
def __call__(self, payload: dict[str, Any]) -> None:
|
||||
self.calls.append(payload)
|
||||
|
||||
@property
|
||||
def all_keys(self) -> set[str]:
|
||||
keys: set[str] = set()
|
||||
for payload in self.calls:
|
||||
keys.update(payload.keys())
|
||||
return keys
|
||||
|
||||
|
||||
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
|
||||
spy = _RegistrationSpy()
|
||||
monkeypatch.setattr(
|
||||
"app.services.pricing_registration.litellm.register_model",
|
||||
spy,
|
||||
raising=False,
|
||||
)
|
||||
return spy
|
||||
|
||||
|
||||
def _patch_openrouter_pricing(
|
||||
monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
|
||||
) -> None:
|
||||
"""Pretend the OpenRouter integration is initialised with ``mapping``."""
|
||||
|
||||
class _Stub:
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||
return mapping
|
||||
|
||||
class _StubService:
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> _Stub:
|
||||
return _Stub()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.openrouter_integration_service.OpenRouterIntegrationService",
|
||||
_StubService,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
|
||||
def test_openrouter_models_register_under_aliases(monkeypatch):
|
||||
"""An OpenRouter config whose ``model_name`` is in the cached raw
|
||||
pricing map is registered under both ``openrouter/X`` and bare ``X``.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch,
|
||||
{
|
||||
"anthropic/claude-3-5-sonnet": {
|
||||
"prompt": "0.000003",
|
||||
"completion": "0.000015",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
|
||||
assert "anthropic/claude-3-5-sonnet" in spy.all_keys
|
||||
# Costs are float-converted from the raw OpenRouter strings.
|
||||
payload = spy.calls[0]
|
||||
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
|
||||
"input_cost_per_token"
|
||||
] == pytest.approx(3e-6)
|
||||
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
|
||||
"output_cost_per_token"
|
||||
] == pytest.approx(15e-6)
|
||||
assert (
|
||||
payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
|
||||
== "openrouter"
|
||||
)
|
||||
|
||||
|
||||
def test_yaml_override_registers_under_alias_set(monkeypatch):
|
||||
"""Operator-declared ``input_cost_per_token`` /
|
||||
``output_cost_per_token`` on a YAML config registers under every
|
||||
alias the YAML alias generator produces — including the ``azure/``
|
||||
normalisation for ``azure_openai`` providers.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(monkeypatch, {})
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"model_name": "gpt-5.4",
|
||||
"litellm_params": {
|
||||
"base_model": "gpt-5.4",
|
||||
"input_cost_per_token": 2e-6,
|
||||
"output_cost_per_token": 8e-6,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
keys = spy.all_keys
|
||||
assert "gpt-5.4" in keys
|
||||
assert "azure_openai/gpt-5.4" in keys
|
||||
assert "azure/gpt-5.4" in keys
|
||||
|
||||
payload = spy.calls[0]
|
||||
entry = payload["gpt-5.4"]
|
||||
assert entry["input_cost_per_token"] == pytest.approx(2e-6)
|
||||
assert entry["output_cost_per_token"] == pytest.approx(8e-6)
|
||||
assert entry["litellm_provider"] == "azure"
|
||||
|
||||
|
||||
def test_no_override_means_no_registration(monkeypatch):
|
||||
"""A YAML config that *omits* both pricing fields must NOT be registered
|
||||
— registering as zero would override LiteLLM's native pricing for the
|
||||
``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
|
||||
bill drop to $0. Fail-safe is "skip and warn", not "register zero".
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(monkeypatch, {})
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"litellm_params": {"base_model": "gpt-4o"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert spy.calls == []
|
||||
|
||||
|
||||
def test_openrouter_skipped_when_pricing_missing(monkeypatch):
|
||||
"""If the OpenRouter raw-pricing cache doesn't carry an entry for a
|
||||
configured model (network blip during refresh, model added later, etc.),
|
||||
we skip it rather than registering zero pricing.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert spy.calls == []
|
||||
|
||||
|
||||
def test_register_continues_after_individual_failure(monkeypatch, caplog):
|
||||
"""A single bad ``register_model`` call (e.g. raising LiteLLM error)
|
||||
must not abort registration of the remaining configs.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
|
||||
successful_calls: list[dict[str, Any]] = []
|
||||
|
||||
def _maybe_fail(payload: dict[str, Any]) -> None:
|
||||
if any(k in failing_keys for k in payload):
|
||||
raise RuntimeError("boom")
|
||||
successful_calls.append(payload)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.pricing_registration.litellm.register_model",
|
||||
_maybe_fail,
|
||||
raising=False,
|
||||
)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch,
|
||||
{
|
||||
"anthropic/claude-3-5-sonnet": {
|
||||
"prompt": "0.000003",
|
||||
"completion": "0.000015",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "custom-deployment",
|
||||
"litellm_params": {
|
||||
"base_model": "custom-deployment",
|
||||
"input_cost_per_token": 1e-6,
|
||||
"output_cost_per_token": 2e-6,
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
# The good config still registered.
|
||||
assert any("custom-deployment" in payload for payload in successful_calls)
|
||||
|
||||
|
||||
def test_vision_configs_registered_with_chat_shape(monkeypatch):
|
||||
"""``register_pricing_from_global_configs`` walks
|
||||
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
|
||||
calls (during indexing) bill correctly. Vision configs use the same
|
||||
chat-shape token prices, but image-gen pricing is intentionally NOT
|
||||
registered here (handled via ``response_cost`` in LiteLLM).
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch,
|
||||
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
|
||||
)
|
||||
|
||||
# No chat configs — only vision. Proves the vision walk is a separate
|
||||
# iteration, not piggy-backed on the chat list.
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"billing_tier": "premium",
|
||||
"input_cost_per_token": 5e-6,
|
||||
"output_cost_per_token": 15e-6,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/openai/gpt-4o" in spy.all_keys
|
||||
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
|
||||
assert payload_value["mode"] == "chat"
|
||||
assert payload_value["litellm_provider"] == "openrouter"
|
||||
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
|
||||
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
|
||||
|
||||
|
||||
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
|
||||
"""If the OpenRouter pricing cache misses a vision model (different
|
||||
catalogue surface), the vision walk falls back to inline
|
||||
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(monkeypatch, {})
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash",
|
||||
"billing_tier": "premium",
|
||||
"input_cost_per_token": 1e-6,
|
||||
"output_cost_per_token": 4e-6,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
"""Unit tests for ``QuotaCheckedVisionLLM``.
|
||||
|
||||
Validates that:
|
||||
|
||||
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
|
||||
enforcement) and forwards the inner LLM's response on success.
|
||||
* The wrapper proxies non-overridden attributes to the inner LLM
|
||||
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
|
||||
still work without quota gating (they're not used in indexing today).
|
||||
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
|
||||
bubbles it up — the ETL pipeline catches that and falls back to OCR.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeInnerLLM:
|
||||
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
|
||||
|
||||
def __init__(self, response: Any = "OCR'd content") -> None:
|
||||
self._response = response
|
||||
self.ainvoke_calls: list[Any] = []
|
||||
|
||||
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
self.ainvoke_calls.append(input)
|
||||
return self._response
|
||||
|
||||
def some_other_method(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _passthrough_billable_call(**_kwargs):
|
||||
"""Stand-in for billable_call that always allows the call to run."""
|
||||
|
||||
class _Acc:
|
||||
total_cost_micros = 0
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
grand_total = 0
|
||||
calls: list[Any] = []
|
||||
|
||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||
return {}
|
||||
|
||||
yield _Acc()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_routes_through_billable_call(monkeypatch):
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
|
||||
captured_kwargs: list[dict[str, Any]] = []
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _spy_billable_call(**kwargs):
|
||||
captured_kwargs.append(kwargs)
|
||||
async with _passthrough_billable_call() as acc:
|
||||
yield acc
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.quota_checked_vision_llm.billable_call",
|
||||
_spy_billable_call,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
inner = _FakeInnerLLM(response="A red apple on a white table")
|
||||
user_id = uuid4()
|
||||
wrapper = QuotaCheckedVisionLLM(
|
||||
inner,
|
||||
user_id=user_id,
|
||||
search_space_id=99,
|
||||
billing_tier="premium",
|
||||
base_model="openai/gpt-4o",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
|
||||
result = await wrapper.ainvoke([{"text": "what is this?"}])
|
||||
assert result == "A red apple on a white table"
|
||||
assert len(inner.ainvoke_calls) == 1
|
||||
assert len(captured_kwargs) == 1
|
||||
bc_kwargs = captured_kwargs[0]
|
||||
assert bc_kwargs["user_id"] == user_id
|
||||
assert bc_kwargs["search_space_id"] == 99
|
||||
assert bc_kwargs["billing_tier"] == "premium"
|
||||
assert bc_kwargs["base_model"] == "openai/gpt-4o"
|
||||
assert bc_kwargs["quota_reserve_tokens"] == 4000
|
||||
assert bc_kwargs["usage_type"] == "vision_extraction"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
|
||||
from app.services.billable_calls import QuotaInsufficientError
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _denying_billable_call(**_kwargs):
|
||||
raise QuotaInsufficientError(
|
||||
usage_type="vision_extraction",
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield # unreachable but required for asynccontextmanager type
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.quota_checked_vision_llm.billable_call",
|
||||
_denying_billable_call,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
inner = _FakeInnerLLM()
|
||||
wrapper = QuotaCheckedVisionLLM(
|
||||
inner,
|
||||
user_id=uuid4(),
|
||||
search_space_id=1,
|
||||
billing_tier="premium",
|
||||
base_model="openai/gpt-4o",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
|
||||
with pytest.raises(QuotaInsufficientError):
|
||||
await wrapper.ainvoke([{"text": "x"}])
|
||||
|
||||
# Inner LLM never ran on a denied reservation.
|
||||
assert inner.ainvoke_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxies_non_overridden_attributes_to_inner():
|
||||
"""``__getattr__`` forwards anything not on the proxy itself, so any
|
||||
method we didn't explicitly override (``invoke``, ``astream``,
|
||||
``with_structured_output``, etc.) still works — just without quota
|
||||
gating, which is fine because the indexer only ever calls ainvoke.
|
||||
"""
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
|
||||
inner = _FakeInnerLLM()
|
||||
wrapper = QuotaCheckedVisionLLM(
|
||||
inner,
|
||||
user_id=uuid4(),
|
||||
search_space_id=1,
|
||||
billing_tier="premium",
|
||||
base_model="openai/gpt-4o",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
|
||||
# ``some_other_method`` is on the inner only.
|
||||
assert wrapper.some_other_method(7) == 14
|
||||
|
|
@ -0,0 +1,515 @@
|
|||
"""Cost-based premium quota unit tests.
|
||||
|
||||
Covers the USD-micro behaviour added in migration 140:
|
||||
|
||||
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
|
||||
calls in a turn — used as the debit amount when ``agent_config.is_premium``
|
||||
is true, regardless of which underlying model produced each call. This
|
||||
preserves the prior "premium turn → all calls in turn count" rule from the
|
||||
token-based system.
|
||||
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
|
||||
clamps to a sane floor when pricing is unknown, and respects the
|
||||
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
|
||||
can't lock the whole balance on one call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TurnTokenAccumulator — premium-turn debit semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_total_cost_micros_sums_premium_and_free_calls():
|
||||
"""A premium turn that also called a free sub-agent debits the union.
|
||||
|
||||
The plan deliberately preserved the existing "premium turn → all calls
|
||||
count" behaviour because per-call premium filtering relied on
|
||||
``LLMRouterService._premium_model_strings`` which only covers router-pool
|
||||
deployments. ``total_cost_micros`` therefore must include free-model
|
||||
calls (whose ``cost_micros`` is typically ``0``) as well as the premium
|
||||
call's actual provider cost.
|
||||
"""
|
||||
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||
|
||||
acc = TurnTokenAccumulator()
|
||||
# Premium model (e.g. claude-opus): non-zero cost.
|
||||
acc.add(
|
||||
model="anthropic/claude-3-5-sonnet",
|
||||
prompt_tokens=1200,
|
||||
completion_tokens=400,
|
||||
total_tokens=1600,
|
||||
cost_micros=12_345,
|
||||
)
|
||||
# Free sub-agent (e.g. title-gen on a free model): zero cost.
|
||||
acc.add(
|
||||
model="gpt-4o-mini",
|
||||
prompt_tokens=120,
|
||||
completion_tokens=20,
|
||||
total_tokens=140,
|
||||
cost_micros=0,
|
||||
)
|
||||
# A second premium-priced call within the same turn.
|
||||
acc.add(
|
||||
model="anthropic/claude-3-5-sonnet",
|
||||
prompt_tokens=800,
|
||||
completion_tokens=200,
|
||||
total_tokens=1000,
|
||||
cost_micros=7_500,
|
||||
)
|
||||
|
||||
assert acc.total_cost_micros == 12_345 + 0 + 7_500
|
||||
# Token totals stay correct so the FE display path still works.
|
||||
assert acc.grand_total == 1600 + 140 + 1000
|
||||
|
||||
|
||||
def test_total_cost_micros_zero_when_no_calls():
|
||||
"""An empty accumulator must report zero cost (no division-by-zero, no None)."""
|
||||
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||
|
||||
acc = TurnTokenAccumulator()
|
||||
assert acc.total_cost_micros == 0
|
||||
assert acc.grand_total == 0
|
||||
|
||||
|
||||
def test_per_message_summary_groups_cost_by_model():
|
||||
"""``per_message_summary`` must accumulate ``cost_micros`` per model so the
|
||||
SSE ``model_breakdown`` payload reports actual USD spend per provider.
|
||||
"""
|
||||
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||
|
||||
acc = TurnTokenAccumulator()
|
||||
acc.add(
|
||||
model="claude-3-5-sonnet",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
total_tokens=150,
|
||||
cost_micros=4_000,
|
||||
)
|
||||
acc.add(
|
||||
model="claude-3-5-sonnet",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
total_tokens=300,
|
||||
cost_micros=8_000,
|
||||
)
|
||||
acc.add(
|
||||
model="gpt-4o-mini",
|
||||
prompt_tokens=50,
|
||||
completion_tokens=10,
|
||||
total_tokens=60,
|
||||
cost_micros=200,
|
||||
)
|
||||
|
||||
summary = acc.per_message_summary()
|
||||
assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
|
||||
assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
|
||||
assert summary["gpt-4o-mini"]["cost_micros"] == 200
|
||||
|
||||
|
||||
def test_serialized_calls_includes_cost_micros():
|
||||
"""``serialized_calls`` is what flows into the SSE ``call_details``
|
||||
payload; cost_micros must be present on each entry so the FE message-info
|
||||
dropdown can render per-call USD.
|
||||
"""
|
||||
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||
|
||||
acc = TurnTokenAccumulator()
|
||||
acc.add(
|
||||
model="m",
|
||||
prompt_tokens=1,
|
||||
completion_tokens=1,
|
||||
total_tokens=2,
|
||||
cost_micros=42,
|
||||
)
|
||||
serialized = acc.serialized_calls()
|
||||
assert serialized == [
|
||||
{
|
||||
"model": "m",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost_micros": 42,
|
||||
"call_kind": "chat",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# estimate_call_reserve_micros — sizing and clamping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
|
||||
"""If LiteLLM doesn't know the model, ``get_model_info`` raises and the
|
||||
helper falls back to the 100-micro floor — small enough that a user with
|
||||
$0.0001 left can still send a tiny request, but non-zero so we still gate
|
||||
against an empty balance.
|
||||
"""
|
||||
import litellm
|
||||
|
||||
from app.services import token_quota_service
|
||||
|
||||
def _raise(_name):
|
||||
raise KeyError("unknown")
|
||||
|
||||
monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
|
||||
|
||||
micros = token_quota_service.estimate_call_reserve_micros(
|
||||
base_model="nonexistent-model",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
|
||||
assert micros == 100
|
||||
|
||||
|
||||
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
|
||||
"""LiteLLM may *return* a model with both cost-per-token fields at 0
|
||||
(pricing not yet registered). The helper must not multiply 0 x tokens
|
||||
and end up reserving 0 — it must clamp to the floor.
|
||||
"""
|
||||
import litellm
|
||||
|
||||
from app.services import token_quota_service
|
||||
|
||||
monkeypatch.setattr(
|
||||
litellm,
|
||||
"get_model_info",
|
||||
lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
|
||||
raising=False,
|
||||
)
|
||||
|
||||
micros = token_quota_service.estimate_call_reserve_micros(
|
||||
base_model="some-pending-model",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
|
||||
|
||||
|
||||
def test_reserve_scales_with_model_cost(monkeypatch):
|
||||
"""Claude-Opus-priced model with 4000 reserve_tokens reserves
|
||||
~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
|
||||
some small artificial cap — that was the bug the plan called out.
|
||||
"""
|
||||
import litellm
|
||||
|
||||
from app.config import config
|
||||
from app.services import token_quota_service
|
||||
|
||||
monkeypatch.setattr(
|
||||
litellm,
|
||||
"get_model_info",
|
||||
lambda _name: {
|
||||
"input_cost_per_token": 15e-6,
|
||||
"output_cost_per_token": 75e-6,
|
||||
},
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
|
||||
|
||||
micros = token_quota_service.estimate_call_reserve_micros(
|
||||
base_model="claude-3-opus",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
# 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
|
||||
assert micros == 360_000
|
||||
|
||||
|
||||
def test_reserve_clamps_to_max_ceiling(monkeypatch):
|
||||
"""A misconfigured "$1000 / M" model with 4000 reserve_tokens would
|
||||
nominally compute to $4 = 4_000_000 micros. The ceiling
|
||||
``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
|
||||
can't lock the user's whole balance on one call.
|
||||
"""
|
||||
import litellm
|
||||
|
||||
from app.config import config
|
||||
from app.services import token_quota_service
|
||||
|
||||
monkeypatch.setattr(
|
||||
litellm,
|
||||
"get_model_info",
|
||||
lambda _name: {
|
||||
"input_cost_per_token": 1e-3,
|
||||
"output_cost_per_token": 0,
|
||||
},
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
|
||||
|
||||
micros = token_quota_service.estimate_call_reserve_micros(
|
||||
base_model="oops-misconfigured",
|
||||
quota_reserve_tokens=4000,
|
||||
)
|
||||
assert micros == 1_000_000
|
||||
|
||||
|
||||
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
    """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
    zero, the helper must fall back to the global
    ``QUOTA_MAX_RESERVE_PER_CALL`` so anonymous-style configs still
    reserve the operator-tunable default.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    cheap_pricing = {
        "input_cost_per_token": 1e-6,
        "output_cost_per_token": 1e-6,
    }
    monkeypatch.setattr(
        litellm, "get_model_info", lambda _name: cheap_pricing, raising=False
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    # Both "absent" spellings must trigger the fallback:
    # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros.
    for sentinel in (None, 0):
        micros = token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=sentinel
        )
        assert micros == 4000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TokenTrackingCallback — image vs chat usage shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeImageUsage:
|
||||
"""Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
total_tokens: int | None = None,
|
||||
) -> None:
|
||||
self.input_tokens = input_tokens
|
||||
self.output_tokens = output_tokens
|
||||
if total_tokens is not None:
|
||||
self.total_tokens = total_tokens
|
||||
|
||||
|
||||
class _FakeImageResponse:
    """Mimics LiteLLM's ``ImageResponse`` — after the re-tag below it
    carries the same class name, so the callback's ``type(...).__name__``
    probe routes to the image branch.
    """

    def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
        self.usage = usage
        # Only attach _hidden_params when a cost is supplied, matching
        # responses where the provider omits cost entirely.
        if response_cost is None:
            return
        self._hidden_params = {"response_cost": response_cost}


# Re-tag the helper class as ``ImageResponse`` for the callback's
# type-name probe. An explicit rename is clearer than naming the class
# that way outright, since test modules can be imported in surprising
# ways by the runner.
_FakeImageResponse.__name__ = "ImageResponse"
|
||||
|
||||
|
||||
class _FakeChatUsage:
|
||||
def __init__(self, prompt: int, completion: int):
|
||||
self.prompt_tokens = prompt
|
||||
self.completion_tokens = completion
|
||||
self.total_tokens = prompt + completion
|
||||
|
||||
|
||||
class _FakeChatResponse:
    """Minimal chat response: only the ``usage`` attribute the callback reads."""

    def __init__(self, usage: _FakeChatUsage):
        self.usage = usage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
    """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
    for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
    prompt_tokens/completion_tokens which is the chat shape.
    """
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    # Image-shaped usage: only input_tokens/output_tokens are present.
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
        response_cost=0.04,  # $0.04 per image
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
            response_obj=response,
            start_time=None,
            end_time=None,
        )
    # A chat-shape read would have produced zeros here, not 42/8.
    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 42
    assert call.completion_tokens == 8
    assert call.total_tokens == 50
    # 0.04 USD = 40_000 micros
    assert call.cost_micros == 40_000
    assert call.call_kind == "image_generation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
    """Chat responses must still read prompt_tokens/completion_tokens."""
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={
                "model": "openrouter/anthropic/claude-3-5-sonnet",
                "response_cost": 0.0036,
            },
            response_obj=response,
            start_time=None,
            end_time=None,
        )
    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 120
    assert call.completion_tokens == 30
    assert call.total_tokens == 150
    # 0.0036 USD = 3_600 micros.
    assert call.cost_micros == 3_600
    assert call.call_kind == "chat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
    """When OpenRouter omits ``usage.cost`` LiteLLM's
    ``default_image_cost_calculator`` raises. The defensive image branch in
    ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
    chat-shaped and would raise too) — it returns 0 with a WARNING log.
    """
    import litellm

    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    # Force completion_cost to raise the same way OpenRouter image-gen fails.
    def _boom(*_args, **_kwargs):
        raise ValueError("model_cost: missing entry for openrouter image model")

    monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)

    # And make sure cost_per_token is NEVER called for the image path —
    # if it were, our ``is_image=True`` branch is broken.
    cost_per_token_calls: list = []

    def _record_cost_per_token(**kwargs):
        cost_per_token_calls.append(kwargs)
        return (0.0, 0.0)

    monkeypatch.setattr(
        litellm, "cost_per_token", _record_cost_per_token, raising=False
    )

    cb = TokenTrackingCallback()
    # No response_cost supplied anywhere — the callback must degrade to 0.
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
            response_obj=response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    assert acc.calls[0].cost_micros == 0
    assert acc.calls[0].call_kind == "image_generation"
    # The image branch must short-circuit before cost_per_token.
    assert cost_per_token_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scoped_turn — ContextVar reset semantics (issue B)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
    """``scoped_turn`` must restore the previous ContextVar value on exit
    so a per-call wrapper inside an outer chat turn doesn't leak its
    accumulator outward (which would cause double-debit at chat-turn exit).
    """
    from app.services.token_tracking_service import (
        get_current_accumulator,
        scoped_turn,
        start_turn,
    )

    outer = start_turn()
    assert get_current_accumulator() is outer

    # Nested scope must install a fresh accumulator, not reuse the outer.
    async with scoped_turn() as inner:
        assert get_current_accumulator() is inner
        assert inner is not outer
        inner.add(
            model="x",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost_micros=5,
        )

    # After exit the outer accumulator is restored unchanged.
    assert get_current_accumulator() is outer
    assert outer.total_cost_micros == 0
    assert len(outer.calls) == 0
    # The inner accumulator captured the call but didn't bleed into outer.
    assert inner.total_cost_micros == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
    """Running ``scoped_turn`` outside any chat turn (e.g. a background
    indexing job) must leave the ContextVar at ``None`` on exit so the
    next *unrelated* request starts clean.
    """
    from app.services.token_tracking_service import (
        _turn_accumulator,
        get_current_accumulator,
        scoped_turn,
    )

    # ContextVar default is None for a fresh test isolated context. We
    # simulate "no outer" explicitly to be robust against test order.
    token = _turn_accumulator.set(None)
    try:
        assert get_current_accumulator() is None
        async with scoped_turn() as acc:
            assert get_current_accumulator() is acc
        # Exit must reset to the pre-scope value (None), not leave acc set.
        assert get_current_accumulator() is None
    finally:
        _turn_accumulator.reset(token)
|
||||
325
surfsense_backend/tests/unit/tasks/test_podcast_billing.py
Normal file
325
surfsense_backend/tests/unit/tasks/test_podcast_billing.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
"""Unit tests for podcast Celery task billing integration.
|
||||
|
||||
Validates ``_generate_content_podcast`` correctly wraps
|
||||
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
|
||||
search-space owner's billing decision, and degrades cleanly when the
|
||||
resolver fails or premium credit is exhausted.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Happy-path free config: resolver → ``billable_call`` enters with
|
||||
``usage_type='podcast_generation'`` and the configured reserve override,
|
||||
graph runs, podcast row flips to ``READY``.
|
||||
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
|
||||
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
|
||||
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
|
||||
carries ``reason='premium_quota_exhausted'``.
|
||||
* Resolver failure: ``ValueError`` from the resolver → podcast row flips
|
||||
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, podcast):
|
||||
self._podcast = podcast
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self._podcast)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSessionMaker:
    """Callable factory that always hands back the same fake session."""

    def __init__(self, session: _FakeSession):
        self._session = session

    def __call__(self):
        return self._session
|
||||
|
||||
|
||||
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
|
||||
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
|
||||
inside helpers keeps this fixture cheap."""
|
||||
return SimpleNamespace(
|
||||
id=podcast_id,
|
||||
title="Test Podcast",
|
||||
thread_id=thread_id,
|
||||
status=None,
|
||||
podcast_transcript=None,
|
||||
file_location=None,
|
||||
)
|
||||
|
||||
|
||||
# Shared capture of every ``billable_call`` invocation's kwargs.
_CALL_LOG: list[dict[str, Any]] = []


@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
    """Permissive ``billable_call`` stand-in: records its kwargs and
    yields a no-op accumulator-shaped object."""
    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    """``billable_call`` stand-in that always denies the reservation with
    an exhausted-quota error."""
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    exhausted = 5_000_000
    raise QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=exhausted,
        limit_micros=exhausted,
        remaining_micros=0,
    )
    # Unreachable, but required so this function is an async generator.
    yield SimpleNamespace()  # pragma: no cover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_call_log():
    """Isolate tests: empty the shared ``_CALL_LOG`` before and after each."""
    _CALL_LOG.clear()
    yield
    _CALL_LOG.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Happy path: free billing tier still wraps the graph call so the
    audit row is recorded. Verifies kwargs threading."""
    from app.config import config as app_config
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=7, thread_id=99)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    # The in-resolver asserts prove the task forwards the right scope ids.
    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 555
        assert thread_id == 99
        return user_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {
            "podcast_transcript": [
                SimpleNamespace(speaker_id=0, dialog="Hi"),
                SimpleNamespace(speaker_id=1, dialog="Hello"),
            ],
            "final_podcast_file_path": "/tmp/podcast.wav",
        }

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hello world",
        search_space_id=555,
        user_prompt="make it short",
    )

    assert result["status"] == "ready"
    assert result["podcast_id"] == 7
    assert podcast.status == PodcastStatus.READY
    assert podcast.file_location == "/tmp/podcast.wav"

    # Exactly one billable_call envelope, with every kwarg threaded through.
    assert len(_CALL_LOG) == 1
    call = _CALL_LOG[0]
    assert call["user_id"] == user_id
    assert call["search_space_id"] == 555
    assert call["billing_tier"] == "free"
    assert call["base_model"] == "openrouter/some-free-model"
    assert call["usage_type"] == "podcast_generation"
    assert (
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
    )
    assert call["thread_id"] == 99
    assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution flows through to ``billable_call`` so the
    reserve/finalize path triggers."""
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast()
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return user_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    # Tier and model from the resolver must reach the billing envelope intact.
    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
    """When ``billable_call`` denies the reservation, the graph never
    runs and the podcast row flips to FAILED with the documented reason
    code."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=8)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    # Deny at reservation time — the expensive graph must not start.
    monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=8,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 8,
        "reason": "premium_quota_exhausted",
    }
    assert podcast.status == PodcastStatus.FAILED
    assert graph_invoked == []  # Graph never ran on denied reservation.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
    """If the resolver raises (e.g. search-space deleted), the task fails
    cleanly without invoking the graph."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=9)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 555 not found")

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
    )

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=9,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    # Distinct reason code from quota exhaustion so callers can tell apart.
    assert result == {
        "status": "failed",
        "podcast_id": 9,
        "reason": "billing_resolution_failed",
    }
    assert podcast.status == PodcastStatus.FAILED
    assert graph_invoked == []
|
||||
|
|
@ -0,0 +1,330 @@
|
|||
"""Unit tests for video-presentation Celery task billing integration.
|
||||
|
||||
Mirrors ``test_podcast_billing.py`` for the video-presentation task.
|
||||
Validates the same wrap-graph-in-billable_call pattern and ensures the
|
||||
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
|
||||
threaded through.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Free config: graph runs, ``billable_call`` invoked with the video
|
||||
reserve override.
|
||||
* Premium config: same wiring with ``billing_tier='premium'``.
|
||||
* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
|
||||
* Resolver failure: row → FAILED with ``billing_resolution_failed``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, video):
|
||||
self._video = video
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self._video)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSessionMaker:
    """Callable factory that always hands back the same fake session."""

    def __init__(self, session: _FakeSession):
        self._session = session

    def __call__(self):
        return self._session
|
||||
|
||||
|
||||
def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=video_id,
|
||||
title="Test Presentation",
|
||||
thread_id=thread_id,
|
||||
status=None,
|
||||
slides=None,
|
||||
scene_codes=None,
|
||||
)
|
||||
|
||||
|
||||
# Shared capture of every ``billable_call`` invocation's kwargs.
_CALL_LOG: list[dict[str, Any]] = []


@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
    """Permissive ``billable_call`` stand-in: records its kwargs and
    yields a no-op accumulator-shaped object."""
    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    """``billable_call`` stand-in that always denies the reservation with
    an exhausted-quota error."""
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    exhausted = 5_000_000
    raise QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=exhausted,
        limit_micros=exhausted,
        remaining_micros=0,
    )
    # Unreachable, but required so this function is an async generator.
    yield SimpleNamespace()  # pragma: no cover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_call_log():
    """Isolate tests: empty the shared ``_CALL_LOG`` before and after each."""
    _CALL_LOG.clear()
    yield
    _CALL_LOG.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Happy path: free tier still wraps the video graph in
    ``billable_call`` with the video-specific reserve override."""
    from app.config import config as app_config
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=11, thread_id=99)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    # The in-resolver asserts prove the task forwards the right scope ids.
    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 777
        assert thread_id == 99
        return user_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result["status"] == "ready"
    assert result["video_presentation_id"] == 11
    assert video.status == VideoPresentationStatus.READY

    # Exactly one billable_call envelope, with every kwarg threaded through.
    assert len(_CALL_LOG) == 1
    call = _CALL_LOG[0]
    assert call["user_id"] == user_id
    assert call["search_space_id"] == 777
    assert call["billing_tier"] == "free"
    assert call["base_model"] == "openrouter/some-free-model"
    assert call["usage_type"] == "video_presentation_generation"
    assert (
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
    )
    assert call["thread_id"] == 99
    assert call["call_details"] == {
        "video_presentation_id": 11,
        "title": "Test Presentation",
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution flows through to ``billable_call`` unchanged."""
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video()
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return user_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    # Tier and model from the resolver must reach the billing envelope intact.
    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
    """Denied reservation: graph never runs, row flips to FAILED with the
    ``premium_quota_exhausted`` reason code."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=12)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    # Deny at reservation time — the expensive graph must not start.
    monkeypatch.setattr(
        video_presentation_tasks, "billable_call", _denying_billable_call
    )

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=12,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "video_presentation_id": 12,
        "reason": "premium_quota_exhausted",
    }
    assert video.status == VideoPresentationStatus.FAILED
    assert graph_invoked == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
    """Resolver failure (e.g. search space missing): the task fails
    cleanly without invoking the graph."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=13)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 777 not found")

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _failing_resolver,
    )

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=13,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    # Distinct reason code from quota exhaustion so callers can tell apart.
    assert result == {
        "status": "failed",
        "video_presentation_id": 13,
        "reason": "billing_resolution_failed",
    }
    assert video.status == VideoPresentationStatus.FAILED
    assert graph_invoked == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue