diff --git a/docker/.env.example b/docker/.env.example
index 95de0cf85..c2e87a619 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
# STRIPE_RECONCILIATION_BATCH_SIZE=100
-# Premium token purchases ($1 per 1M tokens for premium-tier models)
+# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
+# credit; premium turns debit the actual per-call provider cost
+# reported by LiteLLM, so cheap and expensive models bill proportionally)
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-# STRIPE_TOKENS_PER_UNIT=1000000
+# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
# ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text)
@@ -315,9 +318,24 @@ STT_SERVICE=local/base
# Pages limit per user for ETL (default: unlimited)
# PAGES_LIMIT=500
-# Premium token quota per registered user (default: 5M)
-# Only applies to models with billing_tier=premium in global_llm_config.yaml
-# PREMIUM_TOKEN_LIMIT=5000000
+# Premium credit quota per registered user, in micro-USD (default: $5).
+# Premium turns are debited at the actual per-call provider cost reported
+# by LiteLLM. Only applies to models with billing_tier=premium.
+# PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
+
+# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
+# QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
+# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation for the podcast Celery task ($0.20 default).
+# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation for the video Celery task ($1.00 default).
+# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — public users can chat without an account
# Set TRUE to enable /free pages and anonymous chat API
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index a793f33d1..1b1478ae6 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
# Set FALSE to disable new checkout session creation temporarily
STRIPE_PAGE_BUYING_ENABLED=TRUE
-# Premium token purchases via Stripe (for premium-tier model usage)
-# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
+# Premium credit purchases via Stripe (for premium-tier model usage).
+# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
+# (default 1_000_000 = $1.00). Premium turns are billed at the actual
+# per-call provider cost reported by LiteLLM.
STRIPE_TOKEN_BUYING_ENABLED=FALSE
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-STRIPE_TOKENS_PER_UNIT=1000000
+STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
+# STRIPE_TOKENS_PER_UNIT=1000000
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
PAGES_LIMIT=500
-# Premium token quota per registered user (default: 3,000,000)
-# Applies only to models with billing_tier=premium in global_llm_config.yaml
-PREMIUM_TOKEN_LIMIT=3000000
+# Premium credit quota per registered user, in micro-USD
+# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
+# actual per-call provider cost reported by LiteLLM, so cheap and expensive
+# models bill proportionally. Applies only to models with
+# billing_tier=premium in global_llm_config.yaml.
+PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
+# PREMIUM_TOKEN_LIMIT=5000000
+
+# Safety ceiling on per-call premium reservation, in micro-USD.
+# stream_new_chat estimates an upper-bound cost from the model's
+# litellm-published per-token rates × the config's quota_reserve_tokens
+# and clamps to this value so a misconfigured model can't lock the
+# user's whole balance on one call. Default $1.00.
+QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation (in micro-USD) for the POST /image-generations
+# endpoint. Bypassed for free configs. Default $0.05.
+QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
+# Single envelope covers one transcript-generation LLM call. Default $0.20.
+QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation (in micro-USD) used by the video
+# presentation Celery task. Covers worst-case fan-out of N slide-scene
+# generations + refines. Default $1.00. NOTE: tasks using the override
+# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — allows public users to chat without an account
# Set TRUE to enable /free pages and anonymous chat API
diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
new file mode 100644
index 000000000..64aa699e8
--- /dev/null
+++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py
@@ -0,0 +1,291 @@
+"""rename premium token columns to credit micros and add cost_micros to token_usage
+
+Migrates the premium quota system from a flat token counter to a USD-cost
+based credit system, where 1 credit = 1 micro-USD ($0.000001).
+
+Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
+price means every existing value is already correct in the new unit, no
+data transformation needed):
+
+ user.premium_tokens_limit -> premium_credit_micros_limit
+ user.premium_tokens_used -> premium_credit_micros_used
+ user.premium_tokens_reserved -> premium_credit_micros_reserved
+
+ premium_token_purchases.tokens_granted -> credit_micros_granted
+
+New column for cost auditing per turn:
+
+ token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
+
+The "user" table is in zero_publication's column list (added in 139), so
+this migration must drop and recreate the publication with the renamed
+column names, otherwise zero-cache will replicate stale column names and
+the FE Zero schema will fail to bind.
+
+IMPORTANT - run this migration within the following sequence (a Docker
+Compose sketch follows the list):
+ 1. Stop zero-cache (it holds replication locks that will deadlock DDL)
+ 2. Run: alembic upgrade head
+ 3. Delete / reset the zero-cache data volume
+ 4. Restart zero-cache (it will do a fresh initial sync)
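+
+For a Docker Compose deployment the sequence looks roughly like this
+(service and volume names are illustrative; substitute your own):
+
+    docker compose stop zero-cache
+    alembic upgrade head
+    docker volume rm <project>_zero-cache-data
+    docker compose up -d zero-cache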
+
+Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
+"user". Skipping the data-volume reset will leave IndexedDB clients seeing
+column-not-found errors from a stale catalog snapshot.
+
+Revision ID: 140
+Revises: 139
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+
+from alembic import op
+
+revision: str = "140"
+down_revision: str | None = "139"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+PUBLICATION_NAME = "zero_publication"
+
+# Replicates 139's document column list verbatim — must stay in sync.
+DOCUMENT_COLS = [
+ "id",
+ "title",
+ "document_type",
+ "search_space_id",
+ "folder_id",
+ "created_by_id",
+ "status",
+ "created_at",
+ "updated_at",
+]
+
+# Same five live-meter fields as 139, with the renamed column names.
+USER_COLS = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_credit_micros_limit",
+ "premium_credit_micros_used",
+]
+
+
+def _terminate_blocked_pids(conn, table: str) -> None:
+ """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
+ conn.execute(
+ sa.text(
+ "SELECT pg_terminate_backend(l.pid) "
+ "FROM pg_locks l "
+ "JOIN pg_class c ON c.oid = l.relation "
+ "WHERE c.relname = :tbl "
+ " AND l.pid != pg_backend_pid()"
+ ),
+ {"tbl": table},
+ )
+
+
+def _has_zero_version(conn, table: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = '_0_version'"
+ ),
+ {"tbl": table},
+ ).fetchone()
+ is not None
+ )
+
+
+def _column_exists(conn, table: str, column: str) -> bool:
+ return (
+ conn.execute(
+ sa.text(
+ "SELECT 1 FROM information_schema.columns "
+ "WHERE table_name = :tbl AND column_name = :col"
+ ),
+ {"tbl": table, "col": column},
+ ).fetchone()
+ is not None
+ )
+
+
+def _build_publication_ddl(
+ user_cols: list[str],
+ *,
+ documents_has_zero_ver: bool,
+ user_has_zero_ver: bool,
+) -> str:
+ doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
+ user_col_list_with_meta = user_cols + (
+ ['"_0_version"'] if user_has_zero_ver else []
+ )
+ doc_col_list = ", ".join(doc_cols)
+ user_col_list = ", ".join(user_col_list_with_meta)
+ return (
+ f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
+ f"notifications, "
+ f"documents ({doc_col_list}), "
+ f"folders, "
+ f"search_source_connectors, "
+ f"new_chat_messages, "
+ f"chat_comments, "
+ f"chat_session_state, "
+ f'"user" ({user_col_list})'
+ )
+
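+# Example of the DDL _build_publication_ddl emits (both ``_0_version``
+# columns present; documents column list abridged for readability):
+#
+#   CREATE PUBLICATION zero_publication FOR TABLE notifications,
+#     documents (id, title, ..., "_0_version"), folders,
+#     search_source_connectors, new_chat_messages, chat_comments,
+#     chat_session_state,
+#     "user" (id, pages_limit, pages_used, premium_credit_micros_limit,
+#             premium_credit_micros_used, "_0_version")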
+
+def upgrade() -> None:
+ conn = op.get_bind()
+
+ # ------------------------------------------------------------------
+ # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
+ # dev environments are safe.
+ # ------------------------------------------------------------------
+ if not _column_exists(conn, "token_usage", "cost_micros"):
+ op.add_column(
+ "token_usage",
+ sa.Column(
+ "cost_micros",
+ sa.BigInteger(),
+ nullable=False,
+ server_default="0",
+ ),
+ )
+
+ # ------------------------------------------------------------------
+ # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
+ # ------------------------------------------------------------------
+ if _column_exists(
+ conn, "premium_token_purchases", "tokens_granted"
+ ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
+ op.alter_column(
+ "premium_token_purchases",
+ "tokens_granted",
+ new_column_name="credit_micros_granted",
+ )
+
+ # ------------------------------------------------------------------
+ # 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
+ #
+ # We must drop the publication first (it references the old column
+    # names) and take an ACCESS EXCLUSIVE lock for the DDL. asyncpg
+    # requires LOCK TABLE inside a transaction block; Alembic's outer
+    # transaction already provides one, and the SAVEPOINT (begin_nested)
+    # keeps the LOCK + DDL atomic without committing that transaction.
+ # ------------------------------------------------------------------
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
+ # publications require at least the PK to be in the column list,
+ # which is true for both the old and new shape.
+ conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
+
+ # Drop the publication BEFORE renaming columns, otherwise Postgres
+ # rejects the rename: "cannot drop column ... referenced by
+ # publication".
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ for old, new in (
+ ("premium_tokens_limit", "premium_credit_micros_limit"),
+ ("premium_tokens_used", "premium_credit_micros_used"),
+ ("premium_tokens_reserved", "premium_credit_micros_reserved"),
+ ):
+ if _column_exists(conn, "user", old) and not _column_exists(
+ conn, "user", new
+ ):
+ op.alter_column("user", old, new_column_name=new)
+
+ # Update the server_default on the renamed limit column so newly
+ # inserted users get $5 of credit (== 5_000_000 micros) by
+ # default. Existing rows are unaffected.
+ op.alter_column(
+ "user",
+ "premium_credit_micros_limit",
+ server_default="5000000",
+ )
+
+ # Recreate the publication with the new column names.
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+ conn.execute(
+ sa.text(
+ _build_publication_ddl(
+ USER_COLS,
+ documents_has_zero_ver=documents_has_zero_ver,
+ user_has_zero_ver=user_has_zero_ver,
+ )
+ )
+ )
+
+
+def downgrade() -> None:
+ """Revert the rename and drop ``cost_micros``.
+
+ Mirrors ``upgrade``: drop the publication, rename columns back, drop
+ the new column, recreate the publication with the old column list.
+ Same zero-cache stop/reset runbook applies in reverse.
+ """
+ conn = op.get_bind()
+
+ tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
+ with tx:
+ conn.execute(sa.text("SET lock_timeout = '10s'"))
+
+ _terminate_blocked_pids(conn, "user")
+ conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
+
+ conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
+
+ for new, old in (
+ ("premium_credit_micros_limit", "premium_tokens_limit"),
+ ("premium_credit_micros_used", "premium_tokens_used"),
+ ("premium_credit_micros_reserved", "premium_tokens_reserved"),
+ ):
+ if _column_exists(conn, "user", new) and not _column_exists(
+ conn, "user", old
+ ):
+ op.alter_column("user", new, new_column_name=old)
+
+        # Restore the legacy server default (the pre-140 code default
+        # was 3_000_000 tokens, not the new 5_000_000-micro figure).
+        op.alter_column(
+            "user",
+            "premium_tokens_limit",
+            server_default="3000000",
+        )
+
+ legacy_user_cols = [
+ "id",
+ "pages_limit",
+ "pages_used",
+ "premium_tokens_limit",
+ "premium_tokens_used",
+ ]
+ documents_has_zero_ver = _has_zero_version(conn, "documents")
+ user_has_zero_ver = _has_zero_version(conn, "user")
+ conn.execute(
+ sa.text(
+ _build_publication_ddl(
+ legacy_user_cols,
+ documents_has_zero_ver=documents_has_zero_ver,
+ user_has_zero_ver=user_has_zero_ver,
+ )
+ )
+ )
+
+ if _column_exists(
+ conn, "premium_token_purchases", "credit_micros_granted"
+ ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
+ op.alter_column(
+ "premium_token_purchases",
+ "credit_micros_granted",
+ new_column_name="tokens_granted",
+ )
+
+ if _column_exists(conn, "token_usage", "cost_micros"):
+ op.drop_column("token_usage", "cost_micros")
diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py
index 016c2de42..14d7f4d23 100644
--- a/surfsense_backend/app/app.py
+++ b/surfsense_backend/app/app.py
@@ -31,6 +31,7 @@ from app.config import (
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
+ initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
@@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables()
initialize_openrouter_integration()
_start_openrouter_background_refresh()
+ initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py
index 58a8b0f39..74710d5e1 100644
--- a/surfsense_backend/app/celery_app.py
+++ b/surfsense_backend/app/celery_app.py
@@ -22,10 +22,12 @@ def init_worker(**kwargs):
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
+ initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
+ initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index 675b05d2c..2aeeafb34 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -138,7 +138,11 @@ def load_global_image_gen_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
- return data.get("global_image_generation_configs", [])
+ configs = data.get("global_image_generation_configs", []) or []
+ for cfg in configs:
+ if isinstance(cfg, dict):
+ cfg.setdefault("billing_tier", "free")
+ return configs
except Exception as e:
print(f"Warning: Failed to load global image generation configs: {e}")
return []
@@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
- return data.get("global_vision_llm_configs", [])
+ configs = data.get("global_vision_llm_configs", []) or []
+ for cfg in configs:
+ if isinstance(cfg, dict):
+ cfg.setdefault("billing_tier", "free")
+ return configs
except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}")
return []
@@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
"anonymous_enabled_free", settings["anonymous_enabled"]
)
+        # Emitting image-generation and vision-LLM configs is opt-in (issue L).
+ # OpenRouter's catalogue contains hundreds of image / vision
+ # capable models; auto-injecting all of them into every
+ # deployment would explode the model selector and surprise
+ # operators upgrading from prior versions. Default to False so
+ # admins must explicitly turn them on.
+ settings.setdefault("image_generation_enabled", False)
+ settings.setdefault("vision_enabled", False)
+
return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
@@ -296,10 +313,60 @@ def initialize_openrouter_integration():
)
else:
print("Info: OpenRouter integration enabled but no models fetched")
+
+        # Emitting image-generation and vision-LLM models is opt-in (issue L).
+ # Both reuse the catalogue already cached by ``service.initialize``
+ # so we don't make additional network calls here.
+ if settings.get("image_generation_enabled"):
+ try:
+ image_configs = service.get_image_generation_configs()
+ if image_configs:
+ config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
+ print(
+ f"Info: OpenRouter integration added {len(image_configs)} "
+ f"image-generation models"
+ )
+ except Exception as e:
+ print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
+
+ if settings.get("vision_enabled"):
+ try:
+ vision_configs = service.get_vision_llm_configs()
+ if vision_configs:
+ config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
+ print(
+ f"Info: OpenRouter integration added {len(vision_configs)} "
+ f"vision LLM models"
+ )
+ except Exception as e:
+ print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
+def initialize_pricing_registration():
+ """
+ Teach LiteLLM the per-token cost of every deployment in
+ ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
+ from the OpenRouter catalogue + any operator-declared YAML pricing).
+
+ Must run AFTER ``initialize_openrouter_integration()`` so the
+ OpenRouter catalogue is populated and BEFORE the first LLM call so
+ ``response_cost`` is available in ``TokenTrackingCallback``.
+
+ Failures are logged but never raised — startup must not be blocked
+    by a missing pricing entry; the worst case is that an unpriced
+    model debits $0 and logs a WARNING at call time.
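+
+    A minimal sketch of the mechanism, assuming LiteLLM's
+    ``litellm.register_model`` API (the real logic lives in
+    ``app.services.pricing_registration``)::
+
+        import litellm
+
+        litellm.register_model({
+            "my-custom-azure-deploy": {
+                "input_cost_per_token": 0.000003,   # $3.00 / 1M input
+                "output_cost_per_token": 0.000015,  # $15.00 / 1M output
+                "litellm_provider": "azure",
+                "mode": "chat",
+            }
+        })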
+ """
+ try:
+ from app.services.pricing_registration import (
+ register_pricing_from_global_configs,
+ )
+
+ register_pricing_from_global_configs()
+ except Exception as e:
+ print(f"Warning: Failed to register LiteLLM pricing: {e}")
+
+
def initialize_llm_router():
"""
Initialize the LLM Router service for Auto mode.
@@ -444,14 +511,54 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
- # Premium token quota settings
- PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
+ # Premium credit (micro-USD) quota settings.
+ #
+ # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
+ # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
+ # still honoured for one release as fall-back values — the prior
+ # $1-per-1M-tokens Stripe price means every existing value maps 1:1
+ # to micros, so operators upgrading without changing their .env still
+ # get correct behaviour. A startup deprecation warning fires below if
+ # they're set.
+ PREMIUM_CREDIT_MICROS_LIMIT = int(
+ os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
+ or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
+ )
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
- STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
+ STRIPE_CREDIT_MICROS_PER_UNIT = int(
+ os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
+ or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
+ )
STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
)
+ # Safety ceiling on the per-call premium reservation. ``stream_new_chat``
+ # estimates an upper-bound cost from ``litellm.get_model_info`` x the
+ # config's ``quota_reserve_tokens`` and clamps the result to this value
+ # so a misconfigured "$1000/M" model can't lock the user's whole balance
+ # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
+ # reserve_tokens ≈ $0.36) with headroom.
+ QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
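+
+    # Rough shape of that estimate (a sketch only; the real estimator is
+    # ``token_quota_service.estimate_call_reserve_micros``):
+    #
+    #   info = litellm.get_model_info(base_model)
+    #   per_token = info["input_cost_per_token"] + info["output_cost_per_token"]
+    #   reserve = min(int(per_token * reserve_tokens * 1_000_000),
+    #                 QUOTA_MAX_RESERVE_MICROS)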
+
+ if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
+ "PREMIUM_CREDIT_MICROS_LIMIT"
+ ):
+ print(
+ "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
+ "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
+ "current Stripe price). The old key will be removed in a "
+ "future release."
+ )
+ if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
+ "STRIPE_CREDIT_MICROS_PER_UNIT"
+ ):
+ print(
+ "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
+ "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
+ "The old key will be removed in a future release."
+ )
+
# Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
@@ -464,6 +571,35 @@ class Config:
# Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
+ # Per-image reservation (in micro-USD) used by ``billable_call`` for the
+ # ``POST /image-generations`` endpoint when the global config does not
+ # override it. $0.05 covers realistic worst-cases for current OpenAI /
+ # OpenRouter image-gen pricing. Bypassed entirely for free configs.
+ QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
+ )
+
+ # Per-podcast reservation (in micro-USD). One agent LLM call generating
+ # a transcript, typically 5k-20k completion tokens. $0.20 covers a long
+ # premium-model run. Tune via env.
+ QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
+ )
+
+ # Per-video-presentation reservation (in micro-USD). Fan-out of N
+ # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
+ # plus refine retries; can produce many premium completions. $1.00
+ # covers worst-case. Tune via env.
+ #
+ # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
+ # 1_000_000. The override path in ``billable_call`` bypasses the
+ # per-call clamp in ``estimate_call_reserve_micros``, so this is the
+ # *actual* hold — raising it via env is fine but means a single video
+ # task can lock $1+ of credit.
+ QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
+ os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
+ )
+
# Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml
index 79cbe1e51..d92640c8d 100644
--- a/surfsense_backend/app/config/global_llm_config.example.yaml
+++ b/surfsense_backend/app/config/global_llm_config.example.yaml
@@ -19,6 +19,24 @@
# Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
+#
+# COST-BASED PREMIUM CREDITS:
+# Each premium config bills the user's USD-credit balance based on the
+# actual provider cost reported by LiteLLM. For models LiteLLM already
+# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
+# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
+# or any model LiteLLM doesn't have in its built-in pricing table, declare
+# per-token costs inline so they bill correctly:
+#
+# litellm_params:
+# base_model: "my-custom-azure-deploy"
+# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
+# input_cost_per_token: 0.000003
+# output_cost_per_token: 0.000015
+#
+# OpenRouter dynamic models pull pricing automatically from OpenRouter's
+# API — no inline declaration needed. Models without resolvable pricing
+# debit $0 from the user's balance and log a WARNING.
# Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models
@@ -292,6 +310,17 @@ openrouter_integration:
free_rpm: 20
free_tpm: 100000
+  # Emitting image-generation and vision-LLM models is OPT-IN. OpenRouter's catalogue
+ # contains hundreds of image- and vision-capable models; turning these on
+ # injects them into the global Image-Generation / Vision-LLM model
+ # selectors alongside any static configs. Tier (free/premium) is derived
+ # per model the same way it is for chat (`:free` suffix or zero pricing).
+ # When a user picks a premium image/vision model the call debits the
+ # shared $5 USD-cost-based premium credit pool — so leaving these off
+ # avoids surprise quota burn on existing deployments. Default: false.
+ image_generation_enabled: false
+ vision_enabled: false
+
litellm_params:
max_tokens: 16384
system_instructions: ""
diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py
index 2fe478d9b..aef959ec9 100644
--- a/surfsense_backend/app/db.py
+++ b/surfsense_backend/app/db.py
@@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)
+ cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
model_breakdown = Column(JSONB, nullable=True)
call_details = Column(JSONB, nullable=True)
@@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
class PremiumTokenPurchase(Base, TimestampMixin):
- """Tracks Stripe checkout sessions used to grant additional premium token credits."""
+ """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
+
+ Note: the table name is preserved (``premium_token_purchases``) for
+ operational continuity even though the unit is now USD micro-credits
+ instead of raw tokens. The ``credit_micros_granted`` column replaced
+ the legacy ``tokens_granted`` in migration 140; the stored values
+ were not transformed because the prior $1 = 1M tokens Stripe price
+ makes the unit conversion 1:1 numerically.
+ """
__tablename__ = "premium_token_purchases"
__allow_unmapped__ = True
@@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
)
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
quantity = Column(Integer, nullable=False)
- tokens_granted = Column(BigInteger, nullable=False)
+ credit_micros_granted = Column(BigInteger, nullable=False)
amount_total = Column(Integer, nullable=True)
currency = Column(String(10), nullable=True)
status = Column(
@@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
- premium_tokens_limit = Column(
+ premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
- default=config.PREMIUM_TOKEN_LIMIT,
- server_default=str(config.PREMIUM_TOKEN_LIMIT),
+ default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+ server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
- premium_tokens_used = Column(
+ premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
- premium_tokens_reserved = Column(
+ premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
@@ -2241,16 +2250,16 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
- premium_tokens_limit = Column(
+ premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
- default=config.PREMIUM_TOKEN_LIMIT,
- server_default=str(config.PREMIUM_TOKEN_LIMIT),
+ default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+ server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
- premium_tokens_used = Column(
+ premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
- premium_tokens_reserved = Column(
+ premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
index 4bb38b7b0..d45bd780c 100644
--- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
+++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
@@ -68,12 +68,25 @@ class EtlPipelineService:
etl_service="VISION_LLM",
content_type="image",
)
- except Exception:
- logging.warning(
- "Vision LLM failed for %s, falling back to document parser",
- request.filename,
- exc_info=True,
- )
+ except Exception as exc:
+ # Special-case quota exhaustion so we log a clearer message
+ # — the vision LLM didn't "fail", the user just ran out of
+ # premium credit. Falling through to the document parser
+ # is a graceful degradation: OCR/Unstructured still
+ # extracts text from the image without burning credit.
+ from app.services.billable_calls import QuotaInsufficientError
+
+ if isinstance(exc, QuotaInsufficientError):
+ logging.info(
+ "Vision LLM quota exhausted for %s; falling back to document parser",
+ request.filename,
+ )
+ else:
+ logging.warning(
+ "Vision LLM failed for %s, falling back to document parser",
+ request.filename,
+ exc_info=True,
+ )
else:
logging.info(
"No vision LLM provided, falling back to document parser for %s",
diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py
index 97a3559b9..34ed80207 100644
--- a/surfsense_backend/app/routes/image_generation_routes.py
+++ b/surfsense_backend/app/routes/image_generation_routes.py
@@ -36,6 +36,11 @@ from app.schemas import (
ImageGenerationListRead,
ImageGenerationRead,
)
+from app.services.billable_calls import (
+ DEFAULT_IMAGE_RESERVE_MICROS,
+ QuotaInsufficientError,
+ billable_call,
+)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
@@ -92,6 +97,50 @@ def _build_model_string(
return f"{prefix}/{model_name}"
+async def _resolve_billing_for_image_gen(
+ session: AsyncSession,
+ config_id: int | None,
+ search_space: SearchSpace,
+) -> tuple[str, str, int]:
+ """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
+
+ The resolution mirrors ``_execute_image_generation``'s lookup tree but
+ only extracts the fields needed for billing — we do this *before*
+ ``billable_call`` so the reservation is correctly sized for the
+ config that will actually run, and so we don't open an
+ ``ImageGeneration`` row for a request that's about to 402.
+
+ User-owned (positive ID) BYOK configs are always free — they cost
+ the user nothing on our side. Auto mode currently treats as free
+ because the underlying router can dispatch to either premium or
+    the user nothing on our side. Auto mode is currently treated as free
+    because the underlying router can dispatch to either premium or
+ threading the chosen deployment back from ``ImageGenRouterService``.
+ """
+ resolved_id = config_id
+ if resolved_id is None:
+ resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
+
+ if is_image_gen_auto_mode(resolved_id):
+ return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+
+ if resolved_id < 0:
+ cfg = _get_global_image_gen_config(resolved_id) or {}
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
+ base_model = _build_model_string(
+ cfg.get("provider", ""),
+ cfg.get("model_name", ""),
+ cfg.get("custom_provider"),
+ )
+ reserve_micros = int(
+ cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
+ )
+ return (billing_tier, base_model, reserve_micros)
+
+ # Positive ID = user-owned BYOK image-gen config — always free.
+ return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
+
+
async def _execute_image_generation(
session: AsyncSession,
image_gen: ImageGeneration,
@@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
+ # Auto mode currently treated as free until per-deployment
+ # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
+ "billing_tier": "free",
}
)
@@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
+ "quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
@@ -454,7 +508,26 @@ async def create_image_generation(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
- """Create and execute an image generation request."""
+ """Create and execute an image generation request.
+
+ Premium configs are gated by the user's shared premium credit pool.
+ The flow is:
+
+ 1. Permission check + load the search space (cheap, no provider call).
+ 2. Resolve which config will run so we know its billing tier and the
+ worst-case reservation size *before* opening any DB rows.
+ 3. Wrap the entire ImageGeneration row insert + provider call in
+ ``billable_call``. If quota is denied, ``billable_call`` raises
+ ``QuotaInsufficientError`` *before* we flush a row, which we
+ translate to HTTP 402 (no orphaned rows on the user's account,
+ no inserted error rows for "you ran out of credit").
+ 4. On success, the actual ``response_cost`` flows through the
+ LiteLLM callback into the accumulator, and ``billable_call``
+ finalizes the debit at exit. Inner ``try/except`` still catches
+ provider errors and stores them on ``error_message`` (HTTP 200
+ with ``error_message`` set is preserved for failed-but-not-quota
+ scenarios — clients already know how to surface those).
+ """
try:
await check_permission(
session,
@@ -471,33 +544,70 @@ async def create_image_generation(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
- db_image_gen = ImageGeneration(
- prompt=data.prompt,
- model=data.model,
- n=data.n,
- quality=data.quality,
- size=data.size,
- style=data.style,
- response_format=data.response_format,
- image_generation_config_id=data.image_generation_config_id,
- search_space_id=data.search_space_id,
- created_by_id=user.id,
+ billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
+ session, data.image_generation_config_id, search_space
)
- session.add(db_image_gen)
- await session.flush()
- try:
- await _execute_image_generation(session, db_image_gen, search_space)
- except Exception as e:
- logger.exception("Image generation call failed")
- db_image_gen.error_message = str(e)
+ # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
+ # propagates to the outer ``except QuotaInsufficientError`` handler
+ # below as HTTP 402 — it is intentionally NOT swallowed into
+ # ``error_message`` because that would (1) imply a successful row
+ # exists when none does, and (2) return HTTP 200 to a client
+ # whose request was actively *denied* (issue K).
+ async with billable_call(
+ user_id=search_space.user_id,
+ search_space_id=data.search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=reserve_micros,
+ usage_type="image_generation",
+ call_details={"model": base_model, "prompt": data.prompt[:100]},
+ ):
+ db_image_gen = ImageGeneration(
+ prompt=data.prompt,
+ model=data.model,
+ n=data.n,
+ quality=data.quality,
+ size=data.size,
+ style=data.style,
+ response_format=data.response_format,
+ image_generation_config_id=data.image_generation_config_id,
+ search_space_id=data.search_space_id,
+ created_by_id=user.id,
+ )
+ session.add(db_image_gen)
+ await session.flush()
- await session.commit()
- await session.refresh(db_image_gen)
- return db_image_gen
+ try:
+ await _execute_image_generation(session, db_image_gen, search_space)
+ except Exception as e:
+ logger.exception("Image generation call failed")
+ db_image_gen.error_message = str(e)
+
+ await session.commit()
+ await session.refresh(db_image_gen)
+ return db_image_gen
except HTTPException:
raise
+ except QuotaInsufficientError as exc:
+ # The user's premium credit pool is empty. No DB row is created
+ # because ``billable_call`` denies before yielding (issue K).
+ await session.rollback()
+ raise HTTPException(
+ status_code=402,
+ detail={
+ "error_code": "premium_quota_exhausted",
+ "usage_type": exc.usage_type,
+ "used_micros": exc.used_micros,
+ "limit_micros": exc.limit_micros,
+ "remaining_micros": exc.remaining_micros,
+ "message": (
+ "Out of premium credits for image generation. "
+ "Purchase additional credits or switch to a free model."
+ ),
+ },
+ ) from exc
except SQLAlchemyError:
await session.rollback()
raise HTTPException(
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index 28b197ca2..d3bd51129 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -1366,7 +1366,11 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
- # Persist token usage if provided (for assistant messages)
+ # Persist token usage if provided (for assistant messages).
+ # ``cost_micros`` is the provider USD cost reported by LiteLLM,
+ # forwarded by the FE through the appendMessage round-trip so
+ # the historical TokenUsage row matches the credit debit applied
+ # at finalize time.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
@@ -1377,6 +1381,7 @@ async def append_message(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
+ cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,
diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py
index 72715ea5b..5ecfb1814 100644
--- a/surfsense_backend/app/routes/search_spaces_routes.py
+++ b/surfsense_backend/app/routes/search_spaces_routes.py
@@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
+ "billing_tier": "free",
}
if config_id < 0:
@@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
}
return None
@@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
+ "billing_tier": "free",
}
if config_id < 0:
@@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
}
return None
diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py
index cfdd4b52a..aed74ec8d 100644
--- a/surfsense_backend/app/routes/stripe_routes.py
+++ b/surfsense_backend/app/routes/stripe_routes.py
@@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
metadata = _get_metadata(checkout_session)
user_id = metadata.get("user_id")
quantity = int(metadata.get("quantity", "0"))
- tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
+ # Read the new metadata key first, fall back to the legacy one so
+ # in-flight checkout sessions created before the cost-credits
+ # release still fulfil correctly (the unit is numerically the
+ # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
+ credit_micros_per_unit = int(
+ metadata.get("credit_micros_per_unit")
+ or metadata.get("tokens_per_unit", "0")
+ )
- if not user_id or quantity <= 0 or tokens_per_unit <= 0:
+ if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
logger.error(
"Skipping token fulfillment for session %s: incomplete metadata %s",
checkout_session_id,
@@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
getattr(checkout_session, "payment_intent", None)
),
quantity=quantity,
- tokens_granted=quantity * tokens_per_unit,
+ credit_micros_granted=quantity * credit_micros_per_unit,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
purchase.stripe_payment_intent_id = _normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
)
- user.premium_tokens_limit = (
- max(user.premium_tokens_used, user.premium_tokens_limit)
- + purchase.tokens_granted
+ # Top up the user's credit balance by the granted micro-USD amount.
+    # ``max(used, limit)`` handles the case where ``used`` has drifted
+    # above ``limit`` (e.g. a finalize that debited more than was
+    # reserved), so adding ``credit_micros_granted`` always lifts the
+    # limit by the full pack size rather than letting part of the pack
+    # vanish into past overuse.
+ user.premium_credit_micros_limit = (
+ max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+ + purchase.credit_micros_granted
)
await db_session.commit()
@@ -532,12 +544,18 @@ async def create_token_checkout_session(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
- """Create a Stripe Checkout Session for buying premium token packs."""
+ """Create a Stripe Checkout Session for buying premium credit packs.
+
+ Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
+ credit (default 1_000_000 = $1.00). The user's balance is debited
+ at the actual provider cost reported by LiteLLM at finalize time,
+ so $1 of credit always buys $1 worth of provider usage at cost.
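+
+    Example: quantity=3 at the default unit grants
+    3 * 1_000_000 = 3_000_000 micro-USD, i.e. $3.00 of credit.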
+ """
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
- tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
+ credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
@@ -556,8 +574,8 @@ async def create_token_checkout_session(
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
- "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
- "purchase_type": "premium_tokens",
+ "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
+ "purchase_type": "premium_credit",
},
}
)
@@ -583,7 +601,7 @@ async def create_token_checkout_session(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
- tokens_granted=tokens_granted,
+ credit_micros_granted=credit_micros_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@@ -598,14 +616,19 @@ async def create_token_checkout_session(
async def get_token_status(
user: User = Depends(current_active_user),
):
- """Return token-buying availability and current premium quota for frontend."""
- used = user.premium_tokens_used
- limit = user.premium_tokens_limit
+ """Return token-buying availability and current premium credit quota for frontend.
+
+ Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
+ when displaying. The route name is preserved for back-compat with
+ pinned client deployments.
+ """
+ used = user.premium_credit_micros_used
+ limit = user.premium_credit_micros_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
- premium_tokens_used=used,
- premium_tokens_limit=limit,
- premium_tokens_remaining=max(0, limit - used),
+ premium_credit_micros_used=used,
+ premium_credit_micros_limit=limit,
+ premium_credit_micros_remaining=max(0, limit - used),
)
diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py
index 315c7c9fe..4f7e9f725 100644
--- a/surfsense_backend/app/routes/vision_llm_routes.py
+++ b/surfsense_backend/app/routes/vision_llm_routes.py
@@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
+ # Auto mode treated as free until per-deployment billing-tier
+ # surfacing lands; see ``get_vision_llm`` for parity.
+ "billing_tier": "free",
}
)
@@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
+ "billing_tier": cfg.get("billing_tier", "free"),
+ "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+ "input_cost_per_token": cfg.get("input_cost_per_token"),
+ "output_cost_per_token": cfg.get("output_cost_per_token"),
}
)
diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py
index 69f534e20..facca7b86 100644
--- a/surfsense_backend/app/schemas/image_generation.py
+++ b/surfsense_backend/app/schemas/image_generation.py
@@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
Schema for reading global image generation configs from YAML.
Global configs have negative IDs. API key is hidden.
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
+
+ The ``billing_tier`` field allows the frontend to show a Premium/Free
+ badge and (more importantly) tells the backend whether to debit the
+ user's premium credit pool when this config is used. ``"free"`` is
+ the default for backward compatibility — admins must explicitly opt
+ a global config into ``"premium"``.
"""
id: int = Field(
@@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
+ billing_tier: str = Field(
+ default="free",
+ description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+ )
+ quota_reserve_micros: int | None = Field(
+ default=None,
+ description=(
+ "Optional override for the reservation amount (in micro-USD) used when "
+ "this image generation is premium. Falls back to "
+ "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
+ ),
+ )
diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py
index ec5eefc07..892ff9693 100644
--- a/surfsense_backend/app/schemas/new_chat.py
+++ b/surfsense_backend/app/schemas/new_chat.py
@@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
+ cost_micros: int = 0
model_breakdown: dict | None = None
model_config = ConfigDict(from_attributes=True)
diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py
index 3edd3e9e4..57265ec8e 100644
--- a/surfsense_backend/app/schemas/stripe.py
+++ b/surfsense_backend/app/schemas/stripe.py
@@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
class TokenPurchaseRead(BaseModel):
- """Serialized premium token purchase record."""
+ """Serialized premium credit purchase record.
+
+ ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
+ schema name kept ``Token`` for API back-compat with pinned clients.
+ """
id: uuid.UUID
stripe_checkout_session_id: str
stripe_payment_intent_id: str | None = None
quantity: int
- tokens_granted: int
+ credit_micros_granted: int
amount_total: int | None = None
currency: str | None = None
status: str
@@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
class TokenPurchaseHistoryResponse(BaseModel):
- """Response containing the user's premium token purchases."""
+ """Response containing the user's premium credit purchases."""
purchases: list[TokenPurchaseRead]
class TokenStripeStatusResponse(BaseModel):
- """Response describing token-buying availability and current quota."""
+ """Response describing premium-credit-buying availability and balance.
+
+ All ``premium_credit_micros_*`` fields are in micro-USD; the FE
+ divides by 1_000_000 to display USD.
+ """
token_buying_enabled: bool
- premium_tokens_used: int = 0
- premium_tokens_limit: int = 0
- premium_tokens_remaining: int = 0
+ premium_credit_micros_used: int = 0
+ premium_credit_micros_limit: int = 0
+ premium_credit_micros_remaining: int = 0
diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py
index ab2e609dc..e55333a9d 100644
--- a/surfsense_backend/app/schemas/vision_llm.py
+++ b/surfsense_backend/app/schemas/vision_llm.py
@@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
class GlobalVisionLLMConfigRead(BaseModel):
+ """Schema for reading global vision LLM configs from YAML.
+
+ The ``billing_tier`` field allows the frontend to show a Premium/Free
+ badge and (more importantly) tells the backend whether to debit the
+ user's premium credit pool when this config is used. ``"free"`` is
+ the default for backward compatibility — admins must explicitly opt
+ a global config into ``"premium"``.
+ """
+
id: int = Field(...)
name: str
description: str | None = None
@@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
+ billing_tier: str = Field(
+ default="free",
+ description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+ )
+ quota_reserve_tokens: int | None = Field(
+ default=None,
+ description=(
+ "Optional override for the per-call reservation in *tokens* — "
+ "converted to micro-USD via the model's input/output prices at "
+ "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
+ ),
+ )
+ input_cost_per_token: float | None = Field(
+ default=None,
+ description=(
+ "Optional input price in USD/token. Used by pricing_registration to "
+ "register custom Azure / OpenRouter aliases with LiteLLM at startup."
+ ),
+ )
+ output_cost_per_token: float | None = Field(
+ default=None,
+ description="Optional output price in USD/token. Pair with input_cost_per_token.",
+ )
diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py
new file mode 100644
index 000000000..f5ca9818e
--- /dev/null
+++ b/surfsense_backend/app/services/billable_calls.py
@@ -0,0 +1,430 @@
+"""
+Per-call billable wrapper for image generation, vision LLM extraction, and
+any other short-lived premium operation that must charge against the user's
+shared premium credit pool.
+
+The ``billable_call`` async context manager encapsulates the standard
+"reserve → execute → finalize / release → record audit row" lifecycle in a
+single primitive so callers (the image-generation REST route and the
+vision-LLM wrapper used during indexing) don't have to re-implement it.
+
+KEY DESIGN POINTS (issues A and B):
+
+1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
+ argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
+ insert each run inside their own ``shielded_async_session()``. This
+ guarantees that a quota commit/rollback can never accidentally flush or
+ roll back rows the caller has staged in the request's main session
+ (e.g. a freshly-created ``ImageGeneration`` row).
+
+2. **ContextVar safety.** The accumulator is scoped via
+ :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
+ nested ``billable_call`` inside an outer chat turn cannot corrupt the
+ chat turn's accumulator.
+
+3. **Free configs are still audited.** Free calls bypass the reserve /
+ finalize dance entirely but still record a ``TokenUsage`` audit row with
+ the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
+ pipeline complete for analytics even when nothing is debited.
+
+4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
+ responsible for translating that into HTTP 402. We *do not* catch the
+ denial inside ``billable_call`` — letting it propagate also prevents
+ the image-generation route from creating an ``ImageGeneration`` row
+ for a request that never actually ran.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from typing import Any
+from uuid import UUID, uuid4
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import config
+from app.db import shielded_async_session
+from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+)
+from app.services.token_tracking_service import (
+ TurnTokenAccumulator,
+ record_token_usage,
+ scoped_turn,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class QuotaInsufficientError(Exception):
+ """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
+ call because the user has exhausted their premium credit pool.
+
+ The route handler should catch this and return HTTP 402 Payment
+ Required (or the equivalent for the surface area). Outside of the HTTP
+ layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
+ callers may catch this and degrade gracefully — e.g. fall back to OCR
+ when vision is unavailable.
+ """
+
+ def __init__(
+ self,
+ *,
+ usage_type: str,
+ used_micros: int,
+ limit_micros: int,
+ remaining_micros: int,
+ ) -> None:
+ self.usage_type = usage_type
+ self.used_micros = used_micros
+ self.limit_micros = limit_micros
+ self.remaining_micros = remaining_micros
+ super().__init__(
+ f"Premium credit exhausted for {usage_type}: "
+ f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
+ )
+
+
+@asynccontextmanager
+async def billable_call(
+ *,
+ user_id: UUID,
+ search_space_id: int,
+ billing_tier: str,
+ base_model: str,
+ quota_reserve_tokens: int | None = None,
+ quota_reserve_micros_override: int | None = None,
+ usage_type: str,
+ thread_id: int | None = None,
+ message_id: int | None = None,
+ call_details: dict[str, Any] | None = None,
+) -> AsyncIterator[TurnTokenAccumulator]:
+ """Wrap a single billable LLM/image call.
+
+ Args:
+ user_id: Owner of the credit pool to debit. For vision-LLM during
+ indexing this is the *search-space owner* (issue M), not the
+ triggering user.
+ search_space_id: Required — recorded on the ``TokenUsage`` audit row.
+ billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
+ the reserve/finalize dance but still records an audit row with
+ the captured cost.
+ base_model: Used by :func:`estimate_call_reserve_micros` to compute
+ a worst-case reservation from LiteLLM's pricing table.
+ quota_reserve_tokens: Optional per-config override for the chat-style
+ reserve estimator (vision LLM uses this).
+ quota_reserve_micros_override: Optional flat micro-USD reservation
+ (image generation uses this — its cost shape is per-image, not
+ per-token).
+ usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
+ Recorded on the ``TokenUsage`` row.
+ thread_id, message_id: Optional FK columns on ``TokenUsage``.
+ call_details: Optional per-call metadata (model name, parameters)
+ forwarded to ``record_token_usage``.
+
+ Yields:
+ The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
+ the underlying LLM/image API while inside the ``async with``; the
+ ``TokenTrackingCallback`` populates the accumulator automatically.
+
+ Raises:
+ QuotaInsufficientError: when premium and ``premium_reserve`` denies.
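+
+    Example (hypothetical caller; names are illustrative)::
+
+        async with billable_call(
+            user_id=owner_id,
+            search_space_id=space_id,
+            billing_tier="premium",
+            base_model="openai/gpt-4o-mini",
+            usage_type="vision_extraction",
+        ) as acc:
+            await extract_with_vision_llm(...)  # cost lands in ``acc``
+
+    On clean exit the reservation is finalized against the captured
+    ``acc.total_cost_micros``; on any exception it is released in full.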
+ """
+ is_premium = billing_tier == "premium"
+
+ async with scoped_turn() as acc:
+ # ---------- Free path: just audit -------------------------------
+ if not is_premium:
+ try:
+ yield acc
+ finally:
+ # Always audit, even on exception, so we capture cost when
+ # provider returns successfully but the caller raises later.
+ try:
+ async with shielded_async_session() as audit_session:
+ await record_token_usage(
+ audit_session,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=acc.total_cost_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ )
+ await audit_session.commit()
+ except Exception:
+ logger.exception(
+ "[billable_call] free-path audit insert failed for "
+ "usage_type=%s user_id=%s",
+ usage_type,
+ user_id,
+ )
+ return
+
+ # ---------- Premium path: reserve → execute → finalize ----------
+ if quota_reserve_micros_override is not None:
+ reserve_micros = max(1, int(quota_reserve_micros_override))
+ else:
+ reserve_micros = estimate_call_reserve_micros(
+ base_model=base_model or "",
+ quota_reserve_tokens=quota_reserve_tokens,
+ )
+
+ request_id = str(uuid4())
+
+ async with shielded_async_session() as quota_session:
+ reserve_result = await TokenQuotaService.premium_reserve(
+ db_session=quota_session,
+ user_id=user_id,
+ request_id=request_id,
+ reserve_micros=reserve_micros,
+ )
+
+ if not reserve_result.allowed:
+ logger.info(
+ "[billable_call] reserve DENIED user=%s usage_type=%s "
+ "reserve=%d used=%d limit=%d remaining=%d",
+ user_id,
+ usage_type,
+ reserve_micros,
+ reserve_result.used,
+ reserve_result.limit,
+ reserve_result.remaining,
+ )
+ raise QuotaInsufficientError(
+ usage_type=usage_type,
+ used_micros=reserve_result.used,
+ limit_micros=reserve_result.limit,
+ remaining_micros=reserve_result.remaining,
+ )
+
+ logger.info(
+ "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
+ "(remaining=%d)",
+ user_id,
+ usage_type,
+ reserve_micros,
+ reserve_result.remaining,
+ )
+
+ try:
+ yield acc
+ except BaseException:
+ # Release on any failure (including QuotaInsufficientError raised
+ # from a downstream call, asyncio cancellation, etc.). We use
+ # BaseException so cancellation also releases.
+ try:
+ async with shielded_async_session() as quota_session:
+ await TokenQuotaService.premium_release(
+ db_session=quota_session,
+ user_id=user_id,
+ reserved_micros=reserve_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] premium_release failed for user=%s "
+ "reserve_micros=%d (reservation will be GC'd by quota "
+ "reconciliation if/when implemented)",
+ user_id,
+ reserve_micros,
+ )
+ raise
+
+ # ---------- Success: finalize + audit ----------------------------
+ actual_micros = acc.total_cost_micros
+ try:
+ async with shielded_async_session() as quota_session:
+ final_result = await TokenQuotaService.premium_finalize(
+ db_session=quota_session,
+ user_id=user_id,
+ request_id=request_id,
+ actual_micros=actual_micros,
+ reserved_micros=reserve_micros,
+ )
+ logger.info(
+ "[billable_call] finalize user=%s usage_type=%s actual=%d "
+ "reserved=%d → used=%d/%d (remaining=%d)",
+ user_id,
+ usage_type,
+ actual_micros,
+ reserve_micros,
+ final_result.used,
+ final_result.limit,
+ final_result.remaining,
+ )
+ except Exception:
+ # Last-ditch: if finalize itself fails, we must at least release
+ # so the reservation doesn't leak.
+ logger.exception(
+ "[billable_call] premium_finalize failed for user=%s; "
+ "attempting release",
+ user_id,
+ )
+ try:
+ async with shielded_async_session() as quota_session:
+ await TokenQuotaService.premium_release(
+ db_session=quota_session,
+ user_id=user_id,
+ reserved_micros=reserve_micros,
+ )
+ except Exception:
+ logger.exception(
+ "[billable_call] release after finalize failure ALSO failed "
+ "for user=%s",
+ user_id,
+ )
+
+ try:
+ async with shielded_async_session() as audit_session:
+ await record_token_usage(
+ audit_session,
+ usage_type=usage_type,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ prompt_tokens=acc.total_prompt_tokens,
+ completion_tokens=acc.total_completion_tokens,
+ total_tokens=acc.grand_total,
+ cost_micros=actual_micros,
+ model_breakdown=acc.per_message_summary(),
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ )
+ await audit_session.commit()
+ except Exception:
+ logger.exception(
+ "[billable_call] premium-path audit insert failed for "
+ "usage_type=%s user_id=%s (debit was applied)",
+ usage_type,
+ user_id,
+ )
+
+
+async def _resolve_agent_billing_for_search_space(
+ session: AsyncSession,
+ search_space_id: int,
+ *,
+ thread_id: int | None = None,
+) -> tuple[UUID, str, str]:
+ """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
+ agent LLM.
+
+ Used by Celery tasks (podcast generation, video presentation) to bill the
+ search-space owner's premium credit pool when the agent LLM is premium.
+
+ Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
+
+ - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
+ - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
+ * ``thread_id`` is set: delegate to
+ ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
+ recurse into the resolved id. Reuses chat's existing pin if present
+ so the same model bills for chat + downstream podcast/video. If the
+ user is not premium-eligible, the pin service auto-restricts to free
+ deployments — denial only happens later in
+ ``billable_call.premium_reserve`` if the pin really is premium and
+ credit ran out mid-flow.
+      * ``thread_id`` is None: fall back to ``(owner, "free", "auto")``.
+        Forward compatibility for any future direct-API path; today both
+        Celery tasks always pass ``thread_id``.
+ - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
+ (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
+ ``base_model = litellm_params.get("base_model") or model_name`` —
+ NOT provider-prefixed, matching chat's cost-map lookup convention.
+ - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
+ ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
+ ``base_model`` from ``litellm_params`` or ``model_name``.
+
+ Note on imports: ``llm_service``, ``auto_model_pin_service``, and
+ ``llm_router_service`` are imported lazily inside the function body to
+ avoid hoisting litellm side-effects (``litellm.callbacks =
+ [token_tracker]``, ``litellm.drop_params``, etc.) into
+ ``billable_calls.py``'s module load path.
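+
+    Usage sketch (Celery-task side; the ``usage_type`` value is
+    illustrative)::
+
+        owner_id, tier, base_model = await _resolve_agent_billing_for_search_space(
+            session, search_space_id, thread_id=thread_id
+        )
+        async with billable_call(
+            user_id=owner_id,
+            search_space_id=search_space_id,
+            billing_tier=tier,
+            base_model=base_model,
+            usage_type="podcast_generation",
+        ) as acc:
+            ...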
+ """
+ from sqlalchemy import select
+
+ from app.db import NewLLMConfig, SearchSpace
+
+ result = await session.execute(
+ select(SearchSpace).where(SearchSpace.id == search_space_id)
+ )
+ search_space = result.scalars().first()
+ if search_space is None:
+ raise ValueError(f"Search space {search_space_id} not found")
+
+ agent_llm_id = search_space.agent_llm_id
+ if agent_llm_id is None:
+ raise ValueError(
+ f"Search space {search_space_id} has no agent_llm_id configured"
+ )
+
+ owner_user_id: UUID = search_space.user_id
+
+ from app.services.auto_model_pin_service import (
+ AUTO_FASTEST_ID,
+ resolve_or_get_pinned_llm_config_id,
+ )
+
+ if agent_llm_id == AUTO_FASTEST_ID:
+ if thread_id is None:
+ return owner_user_id, "free", "auto"
+ try:
+ resolution = await resolve_or_get_pinned_llm_config_id(
+ session,
+ thread_id=thread_id,
+ search_space_id=search_space_id,
+ user_id=str(owner_user_id),
+ selected_llm_config_id=AUTO_FASTEST_ID,
+ )
+ except ValueError:
+ logger.warning(
+ "[agent_billing] Auto-mode pin resolution failed for "
+ "search_space=%s thread=%s; falling back to free",
+ search_space_id,
+ thread_id,
+ exc_info=True,
+ )
+ return owner_user_id, "free", "auto"
+ agent_llm_id = resolution.resolved_llm_config_id
+
+ if agent_llm_id < 0:
+ from app.services.llm_service import get_global_llm_config
+
+ cfg = get_global_llm_config(agent_llm_id) or {}
+ billing_tier = str(cfg.get("billing_tier", "free")).lower()
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
+ return owner_user_id, billing_tier, base_model
+
+ nlc_result = await session.execute(
+ select(NewLLMConfig).where(
+ NewLLMConfig.id == agent_llm_id,
+ NewLLMConfig.search_space_id == search_space_id,
+ )
+ )
+ nlc = nlc_result.scalars().first()
+ base_model = ""
+ if nlc is not None:
+ litellm_params = nlc.litellm_params or {}
+ base_model = litellm_params.get("base_model") or nlc.model_name or ""
+ return owner_user_id, "free", base_model
+
+
+__all__ = [
+    "DEFAULT_IMAGE_RESERVE_MICROS",
+    "QuotaInsufficientError",
+    "_resolve_agent_billing_for_search_space",
+    "billable_call",
+]
+
+
+# Re-export the config knob so callers don't have to import config just for
+# the default image reserve.
+DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 8a7b2919a..1e9d235c8 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -134,42 +134,16 @@ PROVIDER_MAP = {
}
-# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
-# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
-# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
-# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
-# request to an Azure endpoint, which then 404s with ``Resource not found``.
-# Only providers with a well-known, stable public base URL are listed here —
-# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
-# huggingface, databricks, cloudflare, replicate) are intentionally omitted
-# so their existing config-driven behaviour is preserved.
-PROVIDER_DEFAULT_API_BASE = {
- "openrouter": "https://openrouter.ai/api/v1",
- "groq": "https://api.groq.com/openai/v1",
- "mistral": "https://api.mistral.ai/v1",
- "perplexity": "https://api.perplexity.ai",
- "xai": "https://api.x.ai/v1",
- "cerebras": "https://api.cerebras.ai/v1",
- "deepinfra": "https://api.deepinfra.com/v1/openai",
- "fireworks_ai": "https://api.fireworks.ai/inference/v1",
- "together_ai": "https://api.together.xyz/v1",
- "anyscale": "https://api.endpoints.anyscale.com/v1",
- "cometapi": "https://api.cometapi.com/v1",
- "sambanova": "https://api.sambanova.ai/v1",
-}
-
-
-# Canonical provider → base URL when a config uses a generic ``openai``-style
-# prefix but the ``provider`` field tells us which API it really is
-# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
-# each has its own base URL).
-PROVIDER_KEY_DEFAULT_API_BASE = {
- "DEEPSEEK": "https://api.deepseek.com/v1",
- "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
- "MOONSHOT": "https://api.moonshot.ai/v1",
- "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
- "MINIMAX": "https://api.minimax.io/v1",
-}
+# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
+# hoisted to ``app.services.provider_api_base`` so vision and image-gen
+# call sites can share the exact same defense (OpenRouter / Groq / etc.
+# 404-ing against an inherited Azure endpoint). Re-exported here for
+# backward compatibility with any external import.
+from app.services.provider_api_base import ( # noqa: E402
+ PROVIDER_DEFAULT_API_BASE,
+ PROVIDER_KEY_DEFAULT_API_BASE,
+ resolve_api_base,
+)
class LLMRouterService:
@@ -466,14 +440,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
- # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
+ # requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
- api_base = config.get("api_base")
- if not api_base:
- api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
- if not api_base:
- api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
if api_base:
litellm_params["api_base"] = api_base
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index 942a9b7af..72c10035d 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -496,8 +496,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
+
+ Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
+ so each ``ainvoke`` debits the search-space owner's premium credit
+ pool. User-owned BYOK configs and free global configs are returned
+ unwrapped — they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
@@ -519,6 +525,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
+ owner_user_id = search_space.user_id
+
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@@ -526,6 +534,13 @@ async def get_vision_llm(
)
return None
try:
+ # Auto mode is currently treated as free at the wrapper
+ # level — the underlying router can dispatch to either
+ # premium or free YAML configs but routing decisions are
+ # opaque. If/when we want to bill Auto-routed vision
+ # calls we'd need to thread the resolved deployment's
+ # billing_tier back from the router. For now we keep
+ # parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
@@ -562,8 +577,21 @@ async def get_vision_llm(
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
- return SanitizedChatLiteLLM(**litellm_kwargs)
+ inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
+ billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
+ if billing_tier == "premium":
+ return QuotaCheckedVisionLLM(
+ inner_llm,
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=model_string,
+ quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
+ )
+ return inner_llm
+
+ # User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,
diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py
index 7e856d015..0d030f04f 100644
--- a/surfsense_backend/app/services/openrouter_integration_service.py
+++ b/surfsense_backend/app/services/openrouter_integration_service.py
@@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
+def _is_image_output_model(model: dict) -> bool:
+ """Return True if the model can produce image output.
+
+ OpenRouter's ``architecture.output_modalities`` is a list (e.g.
+ ``["image"]`` for pure image generators, ``["text", "image"]`` for
+ multi-modal generators that also emit captions). We accept any model
+ that can output images; the call site decides whether to use the
+ image-generation API or chat completion.
+ """
+ output_mods = model.get("architecture", {}).get("output_modalities", []) or []
+ return "image" in output_mods
+
+
+def _is_vision_input_model(model: dict) -> bool:
+ """Return True if the model can ingest an image AND emit text.
+
+ OpenRouter's ``architecture.input_modalities`` lists what the model
+ accepts; ``output_modalities`` lists what it produces. A vision LLM
+ is a model that takes images in and produces text out — i.e. it can
+ answer questions about a screenshot or extract content from an
+ image. Pure image-to-image models (e.g. style transfer) and
+ text-only models are excluded.
+ """
+ arch = model.get("architecture", {}) or {}
+ input_mods = arch.get("input_modalities", []) or []
+ output_mods = arch.get("output_modalities", []) or []
+ return "image" in input_mods and "text" in output_mods
+
+
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
return None
+def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
+ """Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
+
+ Pricing values are kept as the raw OpenRouter strings (e.g.
+ ``"0.000003"``); ``pricing_registration`` converts them to floats
+    when registering with LiteLLM. Models with no pricing at all are
+    omitted; present-but-malformed values are kept verbatim and resolve
+    to 0.0 downstream in ``_safe_float`` (an operator-side risk if any
+    of those models are premium).
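+
+    Example::
+
+        _extract_raw_pricing(
+            [{"id": "openai/gpt-4o", "pricing": {"prompt": "0.0000025"}}]
+        )
+        # → {"openai/gpt-4o": {"prompt": "0.0000025", "completion": ""}}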
+ """
+ pricing: dict[str, dict[str, str]] = {}
+ for model in raw_models:
+ model_id = str(model.get("id") or "").strip()
+ if not model_id:
+ continue
+ p = model.get("pricing") or {}
+ prompt = p.get("prompt")
+ completion = p.get("completion")
+ if prompt is None and completion is None:
+ continue
+ pricing[model_id] = {
+ "prompt": str(prompt) if prompt is not None else "",
+ "completion": str(completion) if completion is not None else "",
+ }
+ return pricing
+
+
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
@@ -282,6 +337,162 @@ def _generate_configs(
return configs
+# ID-offset bands used to keep dynamic OpenRouter configs in their own
+# namespace per surface. Image / vision get separate bands so a single
+# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
+_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
+_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
+
+
+def _generate_image_gen_configs(
+ raw_models: list[dict], settings: dict[str, Any]
+) -> list[dict]:
+ """Convert OpenRouter image-generation models into global image-gen
+ config dicts (matches the YAML shape consumed by ``image_generation_routes``).
+
+ Filter:
+ - architecture.output_modalities contains "image"
+ - compatible provider (excluded slugs blocked)
+ - allowed model id (excluded list blocked)
+
+ Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
+ ``_has_sufficient_context``) because tool calls and context windows are
+ irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
+ derived per model the same way as chat (``_openrouter_tier``).
+
+ Cost is intentionally *not* registered with LiteLLM at startup
+ (``pricing_registration`` skips image gen): OpenRouter image-gen
+ models are not in LiteLLM's native cost map and OpenRouter populates
+ ``response_cost`` directly from the response header. A defensive
+ branch in ``_extract_cost_usd`` handles the rare case where
+ ``usage.cost`` is missing — see ``token_tracking_service``.
+ """
+ id_offset: int = int(
+ settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
+ )
+ api_key: str = settings.get("api_key", "")
+ rpm: int = settings.get("rpm", 200)
+ free_rpm: int = settings.get("free_rpm", 20)
+ litellm_params: dict = settings.get("litellm_params") or {}
+
+ image_models = [
+ m
+ for m in raw_models
+ if _is_image_output_model(m)
+ and _is_compatible_provider(m)
+ and _is_allowed_model(m)
+ and "/" in m.get("id", "")
+ ]
+
+ configs: list[dict] = []
+ taken: set[int] = set()
+ for model in image_models:
+ model_id: str = model["id"]
+ name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+
+ cfg: dict[str, Any] = {
+ "id": _stable_config_id(model_id, id_offset, taken),
+ "name": name,
+ "description": f"{name} via OpenRouter (image generation)",
+ "provider": "OPENROUTER",
+ "model_name": model_id,
+ "api_key": api_key,
+ "api_base": "",
+ "api_version": None,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "litellm_params": dict(litellm_params),
+ "billing_tier": tier,
+ _OPENROUTER_DYNAMIC_MARKER: True,
+ }
+ configs.append(cfg)
+
+ return configs
+
+
+def _generate_vision_llm_configs(
+ raw_models: list[dict], settings: dict[str, Any]
+) -> list[dict]:
+ """Convert OpenRouter vision-capable LLMs into global vision-LLM config
+ dicts (matches the YAML shape consumed by ``vision_llm_routes``).
+
+ Filter:
+ - architecture.input_modalities contains "image"
+ - architecture.output_modalities contains "text"
+ - compatible provider (excluded slugs blocked)
+ - allowed model id (excluded list blocked)
+
+ Vision-LLM is invoked from the indexer (image extraction during
+ document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
+ the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
+ filters do not apply: a small-context vision model that doesn't
+ advertise tool-calling is still perfectly viable for "describe this
+ image" prompts.
+ """
+ id_offset: int = int(
+ settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
+ )
+ api_key: str = settings.get("api_key", "")
+ rpm: int = settings.get("rpm", 200)
+ tpm: int = settings.get("tpm", 1_000_000)
+ free_rpm: int = settings.get("free_rpm", 20)
+ free_tpm: int = settings.get("free_tpm", 100_000)
+ quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
+ litellm_params: dict = settings.get("litellm_params") or {}
+
+ vision_models = [
+ m
+ for m in raw_models
+ if _is_vision_input_model(m)
+ and _is_compatible_provider(m)
+ and _is_allowed_model(m)
+ and "/" in m.get("id", "")
+ ]
+
+ configs: list[dict] = []
+ taken: set[int] = set()
+ for model in vision_models:
+ model_id: str = model["id"]
+ name: str = model.get("name", model_id)
+ tier = _openrouter_tier(model)
+ pricing = model.get("pricing") or {}
+
+ # Capture per-token prices so ``pricing_registration`` can
+ # register them with LiteLLM at startup (and so the cost
+ # estimator in ``estimate_call_reserve_micros`` can resolve
+ # them at reserve time).
+ try:
+ input_cost = float(pricing.get("prompt", 0) or 0)
+ except (TypeError, ValueError):
+ input_cost = 0.0
+ try:
+ output_cost = float(pricing.get("completion", 0) or 0)
+ except (TypeError, ValueError):
+ output_cost = 0.0
+
+ cfg: dict[str, Any] = {
+ "id": _stable_config_id(model_id, id_offset, taken),
+ "name": name,
+ "description": f"{name} via OpenRouter (vision)",
+ "provider": "OPENROUTER",
+ "model_name": model_id,
+ "api_key": api_key,
+ "api_base": "",
+ "api_version": None,
+ "rpm": free_rpm if tier == "free" else rpm,
+ "tpm": free_tpm if tier == "free" else tpm,
+ "litellm_params": dict(litellm_params),
+ "billing_tier": tier,
+ "quota_reserve_tokens": quota_reserve_tokens,
+ "input_cost_per_token": input_cost or None,
+ "output_cost_per_token": output_cost or None,
+ _OPENROUTER_DYNAMIC_MARKER: True,
+ }
+ configs.append(cfg)
+
+ return configs
+
+
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
@@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
+ # Raw OpenRouter pricing per model_id, captured at the same time
+ # we generate configs. Consumed by ``pricing_registration`` to
+ # teach LiteLLM the per-token cost of every dynamic deployment so
+ # the success-callback can populate ``response_cost`` correctly.
+ self._raw_pricing: dict[str, dict[str, str]] = {}
+ # Cached raw catalogue from the most recent fetch. Image / vision
+ # emitters reuse this to avoid a second network call per surface.
+ self._raw_models: list[dict] = []
+ # Image / vision config caches (only populated when the matching
+ # opt-in flag is true on initialize). Refreshed in lockstep with
+ # the chat catalogue.
+ self._image_configs: list[dict] = []
+ self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
self._initialized = True
return []
+ self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
+ self._raw_pricing = _extract_raw_pricing(raw_models)
+
+ # Populate image / vision caches when their opt-in flag is set.
+ # Empty otherwise so the accessors return [] without re-running
+ # filters every refresh.
+ if settings.get("image_generation_enabled"):
+ self._image_configs = _generate_image_gen_configs(raw_models, settings)
+ logger.info(
+ "OpenRouter integration: image-gen emission ON (%d models)",
+ len(self._image_configs),
+ )
+ else:
+ self._image_configs = []
+
+ if settings.get("vision_enabled"):
+ self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
+ logger.info(
+ "OpenRouter integration: vision LLM emission ON (%d models)",
+ len(self._vision_configs),
+ )
+ else:
+ self._vision_configs = []
+
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
+ self._raw_pricing = _extract_raw_pricing(raw_models)
+ self._raw_models = raw_models
from app.config import config as app_config
@@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
+ # Image / vision lists are atomic-swapped the same way: filter out
+ # the previous dynamic entries from the live config list and append
+ # the freshly generated ones. No-ops when the opt-in flag is off.
+ if self._settings.get("image_generation_enabled"):
+ new_image = _generate_image_gen_configs(raw_models, self._settings)
+ static_image = [
+ c
+ for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
+ if not c.get(_OPENROUTER_DYNAMIC_MARKER)
+ ]
+ app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
+ self._image_configs = new_image
+
+ if self._settings.get("vision_enabled"):
+ new_vision = _generate_vision_llm_configs(raw_models, self._settings)
+ static_vision = [
+ c
+ for c in app_config.GLOBAL_VISION_LLM_CONFIGS
+ if not c.get(_OPENROUTER_DYNAMIC_MARKER)
+ ]
+ app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
+ self._vision_configs = new_vision
+
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
+ # Re-register LiteLLM pricing for the freshly fetched catalogue
+ # so newly added OR models bill correctly on their first call.
+ # Runs before the router rebuild because the router may issue
+ # cost-table lookups during deployment registration.
+ try:
+ from app.services.pricing_registration import (
+ register_pricing_from_global_configs,
+ )
+
+ register_pricing_from_global_configs()
+ except Exception as exc:
+ logger.warning(
+ "OpenRouter refresh: pricing re-registration skipped (%s)", exc
+ )
+
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
@@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)
+
+ def get_image_generation_configs(self) -> list[dict]:
+ """Return the dynamic OpenRouter image-generation configs (empty
+ list when the ``image_generation_enabled`` flag is off).
+
+ Each entry already has ``billing_tier`` derived per-model from
+ OpenRouter's signals and is shaped to drop directly into
+ ``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
+ """
+ return list(self._image_configs)
+
+ def get_vision_llm_configs(self) -> list[dict]:
+ """Return the dynamic OpenRouter vision-LLM configs (empty list
+ when the ``vision_enabled`` flag is off).
+
+ Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
+ so ``pricing_registration`` can teach LiteLLM the cost of these
+ models the same way it does for chat — which keeps the billable
+ wrapper able to debit accurate micro-USD on a vision call.
+ """
+ return list(self._vision_configs)
+
+ def get_raw_pricing(self) -> dict[str, dict[str, str]]:
+ """Return the cached raw OpenRouter pricing map.
+
+ Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
+ values are the strings OpenRouter publishes (USD per token),
+ never converted to floats here so the caller can decide how to
+ handle malformed or unset entries.
+ """
+ return dict(self._raw_pricing)
diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py
new file mode 100644
index 000000000..de98e50c2
--- /dev/null
+++ b/surfsense_backend/app/services/pricing_registration.py
@@ -0,0 +1,274 @@
+"""
+Pricing registration with LiteLLM.
+
+Many models reach our LiteLLM callback without LiteLLM knowing their
+per-token cost — namely:
+
+* The ~300 dynamic OpenRouter deployments (their pricing only lives on
+ OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
+ pricing table).
+* Static YAML deployments whose ``base_model`` name is operator-defined
+ (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
+ not in LiteLLM's table either.
+
+Without registration, ``kwargs["response_cost"]`` is 0 for those calls
+and the user gets billed nothing — a fail-safe but wrong answer for a
+cost-based credit system. This module runs once at startup, after the
+OpenRouter integration has fetched its catalogue, and registers each
+known model's pricing with ``litellm.register_model()`` under multiple
+plausible alias keys (LiteLLM's cost lookup may use any of them
+depending on whether the call went through the Router, ChatLiteLLM,
+or a direct ``acompletion``).
+
+Operators who run a custom Azure deployment whose ``base_model`` name
+isn't in LiteLLM's table can declare per-token pricing inline in
+``global_llm_config.yaml`` via ``input_cost_per_token`` and
+``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
+that declaration the model's calls debit 0 — never overbilled.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+import litellm
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_float(value: Any) -> float:
+ """Return ``float(value)`` if it parses to a positive number, else 0.0."""
+ if value is None:
+ return 0.0
+ try:
+ f = float(value)
+ except (TypeError, ValueError):
+ return 0.0
+ return f if f > 0 else 0.0
+
+
+def _alias_set_for_openrouter(model_id: str) -> list[str]:
+ """Return the alias keys to register an OpenRouter model under.
+
+ LiteLLM's cost-callback lookup key varies by call path:
+ - Router with ``model="openrouter/X"`` → kwargs["model"] is
+ typically ``openrouter/X``.
+ - LiteLLM's own provider routing may strip the prefix and pass the
+ bare ``X`` to the cost-table lookup.
+ Registering under both keeps the lookup hermetic regardless of
+ which path the call took.
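+
+    Example::
+
+        _alias_set_for_openrouter("anthropic/claude-3-haiku")
+        # → ["openrouter/anthropic/claude-3-haiku",
+        #    "anthropic/claude-3-haiku"]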
+ """
+ aliases = [f"openrouter/{model_id}", model_id]
+ return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
+ """Return the alias keys to register a static YAML deployment under.
+
+    Same reasoning as the OpenRouter set: cover the bare ``base_model``,
+    the ``provider/model`` form LiteLLM Router constructs, and the
+    bare ``model_name`` because callbacks sometimes see whichever was
+    configured first.
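+
+    Example (deployment names illustrative)::
+
+        _alias_set_for_yaml("AZURE_OPENAI", "my-gpt4o", "gpt-4o")
+        # → ["gpt-4o", "azure_openai/gpt-4o", "my-gpt4o",
+        #    "azure_openai/my-gpt4o", "azure/gpt-4o", "azure/my-gpt4o"]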
+ """
+ provider_lower = (provider or "").lower()
+ aliases: list[str] = []
+ if base_model:
+ aliases.append(base_model)
+ if provider_lower and base_model:
+ aliases.append(f"{provider_lower}/{base_model}")
+ if model_name and model_name != base_model:
+ aliases.append(model_name)
+ if provider_lower and model_name and model_name != base_model:
+ aliases.append(f"{provider_lower}/{model_name}")
+ # Azure deployments often surface as "azure/"; normalise the
+ # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
+ if provider_lower == "azure_openai":
+ if base_model:
+ aliases.append(f"azure/{base_model}")
+ if model_name and model_name != base_model:
+ aliases.append(f"azure/{model_name}")
+ return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _register(
+ aliases: list[str],
+ *,
+ input_cost: float,
+ output_cost: float,
+ provider: str,
+ mode: str = "chat",
+) -> int:
+ """Register a single pricing entry under every alias in ``aliases``.
+
+ Returns the count of aliases successfully registered.
+ """
+ payload: dict[str, dict[str, Any]] = {}
+ for alias in aliases:
+ payload[alias] = {
+ "input_cost_per_token": input_cost,
+ "output_cost_per_token": output_cost,
+ "litellm_provider": provider,
+ "mode": mode,
+ }
+ if not payload:
+ return 0
+ try:
+ litellm.register_model(payload)
+ except Exception as exc:
+ logger.warning(
+ "[PricingRegistration] register_model failed for aliases=%s: %s",
+ aliases,
+ exc,
+ )
+ return 0
+ return len(payload)
+
+
+def _register_chat_shape_configs(
+ configs: list[dict],
+ *,
+ or_pricing: dict[str, dict[str, str]],
+ label: str,
+) -> tuple[int, int, int, list[str]]:
+ """Common loop that registers per-token pricing for a list of "chat-shape"
+ configs (chat or vision LLM — both use ``input_cost_per_token`` /
+ ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
+
+ Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
+ """
+ registered_models = 0
+ registered_aliases = 0
+ skipped_no_pricing = 0
+ sample_keys: list[str] = []
+
+ for cfg in configs:
+ provider = str(cfg.get("provider") or "").upper()
+ model_name = str(cfg.get("model_name") or "").strip()
+ litellm_params = cfg.get("litellm_params") or {}
+ base_model = str(litellm_params.get("base_model") or model_name).strip()
+
+ if provider == "OPENROUTER":
+ entry = or_pricing.get(model_name)
+ if entry:
+ input_cost = _safe_float(entry.get("prompt"))
+ output_cost = _safe_float(entry.get("completion"))
+ else:
+ # Vision configs from ``_generate_vision_llm_configs``
+ # carry their pricing inline because the OpenRouter
+ # raw-pricing cache is keyed by chat-catalogue model_id;
+ # vision flows pick up the inline values here.
+ input_cost = _safe_float(cfg.get("input_cost_per_token"))
+ output_cost = _safe_float(cfg.get("output_cost_per_token"))
+ if input_cost == 0.0 and output_cost == 0.0:
+ skipped_no_pricing += 1
+ continue
+ aliases = _alias_set_for_openrouter(model_name)
+ count = _register(
+ aliases,
+ input_cost=input_cost,
+ output_cost=output_cost,
+ provider="openrouter",
+ )
+ if count > 0:
+ registered_models += 1
+ registered_aliases += count
+ if len(sample_keys) < 6:
+ sample_keys.extend(aliases[:2])
+ continue
+
+ input_cost = _safe_float(
+ cfg.get("input_cost_per_token")
+ or litellm_params.get("input_cost_per_token")
+ )
+ output_cost = _safe_float(
+ cfg.get("output_cost_per_token")
+ or litellm_params.get("output_cost_per_token")
+ )
+ if input_cost == 0.0 and output_cost == 0.0:
+ skipped_no_pricing += 1
+ continue
+ aliases = _alias_set_for_yaml(provider, model_name, base_model)
+ provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
+ count = _register(
+ aliases,
+ input_cost=input_cost,
+ output_cost=output_cost,
+ provider=provider_slug,
+ )
+ if count > 0:
+ registered_models += 1
+ registered_aliases += count
+ if len(sample_keys) < 6:
+ sample_keys.extend(aliases[:2])
+
+ logger.info(
+ "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
+ "%d configs had no pricing data; sample registered keys=%s",
+ label,
+ registered_models,
+ registered_aliases,
+ skipped_no_pricing,
+ sample_keys,
+ )
+ return registered_models, registered_aliases, skipped_no_pricing, sample_keys
+
+
+def register_pricing_from_global_configs() -> None:
+ """Register pricing for every known LLM deployment with LiteLLM.
+
+ Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
+ so vision calls (during indexing) can resolve cost the same way chat
+ calls do — namely:
+
+ 1. ``OPENROUTER``: pulls the cached raw pricing from
+ ``OpenRouterIntegrationService`` (populated during its own
+ startup fetch) and converts the per-token strings to floats. For
+ vision configs that carry pricing inline (``input_cost_per_token`` /
+ ``output_cost_per_token`` set on the cfg itself) we fall back to
+ those values when the OR cache misses the model.
+ 2. Anything else: looks for operator-declared
+ ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
+ config block (top-level or nested under ``litellm_params``).
+
+ **Image generation is intentionally NOT registered here.** The cost
+ shape for image-gen is per-image (``output_cost_per_image``), not
+ per-token, and LiteLLM's ``register_model`` doesn't accept those
+ keys via the chat-cost path. OpenRouter image-gen models populate
+ ``response_cost`` directly from their response header instead, and
+ Azure-native image-gen models are already in LiteLLM's cost map.
+
+ Calls without a resolved pair of costs are skipped, not registered
+ with zeros — operators who forget pricing get a "$0 debit" warning
+ in ``TokenTrackingCallback`` rather than silently overwriting any
+ pricing LiteLLM might know natively.
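+
+    Effective registration for one OpenRouter model (prices
+    illustrative)::
+
+        litellm.register_model({
+            "openrouter/anthropic/claude-3-haiku": {
+                "input_cost_per_token": 2.5e-07,
+                "output_cost_per_token": 1.25e-06,
+                "litellm_provider": "openrouter",
+                "mode": "chat",
+            },
+            "anthropic/claude-3-haiku": {
+                "input_cost_per_token": 2.5e-07,
+                "output_cost_per_token": 1.25e-06,
+                "litellm_provider": "openrouter",
+                "mode": "chat",
+            },
+        })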
+ """
+ from app.config import config as app_config
+
+ chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
+ vision_configs: list[dict] = list(
+ getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
+ )
+ if not chat_configs and not vision_configs:
+ logger.info("[PricingRegistration] no global configs to register")
+ return
+
+ or_pricing: dict[str, dict[str, str]] = {}
+ try:
+ from app.services.openrouter_integration_service import (
+ OpenRouterIntegrationService,
+ )
+
+ if OpenRouterIntegrationService.is_initialized():
+ or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
+ except Exception as exc:
+ logger.debug(
+ "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
+ )
+
+ if chat_configs:
+ _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
+ if vision_configs:
+ _register_chat_shape_configs(
+ vision_configs, or_pricing=or_pricing, label="vision"
+ )
diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py
new file mode 100644
index 000000000..979d7d3a1
--- /dev/null
+++ b/surfsense_backend/app/services/provider_api_base.py
@@ -0,0 +1,107 @@
+"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
+
+LiteLLM falls back to the module-global ``litellm.api_base`` when an
+individual call doesn't pass one, which silently inherits provider-agnostic
+env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
+explicit ``api_base``, an ``openrouter/`` request can end up at an
+Azure endpoint and 404 with ``Resource not found`` (real reproducer:
+[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
+``/chat/completions`` to whatever inherited base it gets, regardless of
+provider).
+
+The chat router has had this defense for a while
+(``llm_router_service.py:466-478``). This module hoists the maps + cascade
+into a tiny standalone helper so vision and image-gen can share the same
+source of truth without an inter-service circular import.
+"""
+
+from __future__ import annotations
+
+
+PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
+ "openrouter": "https://openrouter.ai/api/v1",
+ "groq": "https://api.groq.com/openai/v1",
+ "mistral": "https://api.mistral.ai/v1",
+ "perplexity": "https://api.perplexity.ai",
+ "xai": "https://api.x.ai/v1",
+ "cerebras": "https://api.cerebras.ai/v1",
+ "deepinfra": "https://api.deepinfra.com/v1/openai",
+ "fireworks_ai": "https://api.fireworks.ai/inference/v1",
+ "together_ai": "https://api.together.xyz/v1",
+ "anyscale": "https://api.endpoints.anyscale.com/v1",
+ "cometapi": "https://api.cometapi.com/v1",
+ "sambanova": "https://api.sambanova.ai/v1",
+}
+"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
+
+Only providers with a well-known, stable public base URL are listed —
+self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
+huggingface, databricks, cloudflare, replicate) are intentionally omitted
+so their existing config-driven behaviour is preserved."""
+
+
+PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
+ "DEEPSEEK": "https://api.deepseek.com/v1",
+ "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+ "MOONSHOT": "https://api.moonshot.ai/v1",
+ "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
+ "MINIMAX": "https://api.minimax.io/v1",
+}
+"""Canonical provider key (uppercase) → base URL.
+
+Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
+config's ``provider`` field tells us which API it actually is (DeepSeek,
+Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
+has its own base URL)."""
+
+
+def resolve_api_base(
+ *,
+ provider: str | None,
+ provider_prefix: str | None,
+ config_api_base: str | None,
+) -> str | None:
+ """Resolve a non-Azure-leaking ``api_base`` for a deployment.
+
+ Cascade (first non-empty wins):
+ 1. The config's own ``api_base`` (whitespace-only treated as missing).
+ 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
+ 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
+ 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
+ provider integration apply its own default (e.g. AzureOpenAI's
+ deployment-derived URL, custom provider's per-deployment URL).
+
+ Args:
+ provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
+ ``"DEEPSEEK"``). Case-insensitive.
+ provider_prefix: The LiteLLM model-string prefix the same call
+ site builds for the model id (e.g. ``"openrouter"``,
+ ``"groq"``). Case-insensitive.
+ config_api_base: ``api_base`` from the global YAML / DB row /
+ OpenRouter dynamic config. Empty / whitespace-only means
+ "missing" — the resolver still applies the cascade.
+
+ Returns:
+ A URL string, or ``None`` if no default applies for this provider.
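+
+    Example::
+
+        resolve_api_base(
+            provider="OPENROUTER",
+            provider_prefix="openrouter",
+            config_api_base="",
+        )
+        # → "https://openrouter.ai/api/v1" (step 3 of the cascade)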
+ """
+ if config_api_base and config_api_base.strip():
+ return config_api_base
+
+ if provider:
+ key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
+ if key_default:
+ return key_default
+
+ if provider_prefix:
+ prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
+ if prefix_default:
+ return prefix_default
+
+ return None
+
+
+__all__ = [
+ "PROVIDER_DEFAULT_API_BASE",
+ "PROVIDER_KEY_DEFAULT_API_BASE",
+ "resolve_api_base",
+]
diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py
new file mode 100644
index 000000000..0040e5a5b
--- /dev/null
+++ b/surfsense_backend/app/services/quota_checked_vision_llm.py
@@ -0,0 +1,105 @@
+"""
+Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
+
+Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
+indexing pipeline (file processors, connector indexers, etl pipeline) can
+keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
+— without threading ``user_id`` through every parser. The wrapper looks like
+a chat model from the outside; on the inside it routes each call through
+``billable_call`` so the user's premium credit pool is reserved → finalized
+or released, and a ``TokenUsage`` audit row is written.
+
+Free configs are returned unwrapped from ``get_vision_llm`` (they do not
+need quota enforcement) so this class only ever wraps premium configs.
+
+Why a wrapper instead of plumbing ``user_id`` through every caller:
+
+* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
+ Dropbox, local-folder, file-processor, ETL pipeline) each calling
+ ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
+ invasive, error-prone, and easy for a future indexer to forget.
+* Per the design (issue M), we always debit the *search-space owner*, not
+ the triggering user, so ``user_id`` is fully derivable from the search
+ space the caller is already operating on. The wrapper captures it once
+ at construction time.
+* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
+ call run this coroutine"; subclassing isn't safe across versions because
+ it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
+ Composition via attribute proxying (``__getattr__``) is robust to
+ upstream changes — every method other than ``ainvoke`` falls through to
+ the inner LLM unchanged.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+from uuid import UUID
+
+from app.services.billable_calls import QuotaInsufficientError, billable_call
+
+logger = logging.getLogger(__name__)
+
+
+class QuotaCheckedVisionLLM:
+ """Composition wrapper around a langchain chat model that enforces
+ premium credit quota on every ``ainvoke``.
+
+ Anything other than ``ainvoke`` is forwarded to the inner model so
+ ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
+ still work — they simply bypass quota enforcement, which is fine
+ because the indexing pipeline only ever calls ``ainvoke`` today.
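+
+    Caller-side sketch (mirrors the ``etl_pipeline_service`` fallback
+    described in the module docstring; names assumed)::
+
+        try:
+            message = await vision_llm.ainvoke(messages)
+        except QuotaInsufficientError:
+            message = None  # degrade to the non-vision document parser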
+ """
+
+ def __init__(
+ self,
+ inner_llm: Any,
+ *,
+ user_id: UUID,
+ search_space_id: int,
+ billing_tier: str,
+ base_model: str,
+ quota_reserve_tokens: int | None,
+ usage_type: str = "vision_extraction",
+ ) -> None:
+ self._inner = inner_llm
+ self._user_id = user_id
+ self._search_space_id = search_space_id
+ self._billing_tier = billing_tier
+ self._base_model = base_model
+ self._quota_reserve_tokens = quota_reserve_tokens
+ self._usage_type = usage_type
+
+ async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+ """Proxied async invoke that runs the underlying call inside
+ ``billable_call``.
+
+ Raises:
+ QuotaInsufficientError: when the user has exhausted their
+ premium credit pool. Caller (``etl_pipeline_service._extract_image``)
+ catches this and falls back to the document parser.
+ """
+ async with billable_call(
+ user_id=self._user_id,
+ search_space_id=self._search_space_id,
+ billing_tier=self._billing_tier,
+ base_model=self._base_model,
+ quota_reserve_tokens=self._quota_reserve_tokens,
+ usage_type=self._usage_type,
+ call_details={"model": self._base_model},
+ ):
+ return await self._inner.ainvoke(input, *args, **kwargs)
+
+ def __getattr__(self, name: str) -> Any:
+ """Forward everything else (``invoke``, ``astream``, ``bind``,
+ ``with_structured_output``, …) to the inner model.
+
+ ``__getattr__`` is only consulted when the attribute is *not*
+ already found on the proxy, which is exactly the contract we
+ want — methods we override stay on the proxy, the rest fall
+ through.
+ """
+ return getattr(self._inner, name)
+
+
+__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py
index a3ec7aed0..310c3eb5e 100644
--- a/surfsense_backend/app/services/token_quota_service.py
+++ b/surfsense_backend/app/services/token_quota_service.py
@@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)
+# ---------------------------------------------------------------------------
+# Per-call reservation estimator (USD micro-units)
+# ---------------------------------------------------------------------------
+
+# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
+# request, and so models without registered pricing reserve at least
+# something while the call runs (debited 0 at finalize anyway when their
+# cost can't be resolved).
+_QUOTA_MIN_RESERVE_MICROS = 100
+
+
+def estimate_call_reserve_micros(
+ *,
+ base_model: str,
+ quota_reserve_tokens: int | None,
+) -> int:
+ """Return the number of micro-USD to reserve for one premium call.
+
+ Computes a worst-case upper bound from LiteLLM's per-token pricing
+ table:
+
+        reserve_usd ≈ reserve_tokens * (input_cost + output_cost)
+
+ so the math scales with model cost — Claude Opus + 4K reserve_tokens
+ naturally reserves ≈ $0.36, while a cheap model reserves only a few
+ cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
+ so a misconfigured "$1000/M" model can't lock the whole balance on
+ one call.
+
+ If ``litellm.get_model_info`` raises (model unknown) we fall back to
+ the floor — 100 micros / $0.0001 — which is enough to gate a sane
+ request without over-reserving for a model whose pricing the
+ operator hasn't declared yet.
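+
+    Worked example (Opus-class pricing, illustrative)::
+
+        # input $15/M, output $75/M:
+        #   4000 * (0.000015 + 0.000075) = $0.36 → 360_000 micros,
+        #   within the default $1.00 QUOTA_MAX_RESERVE_MICROS ceiling.
+        estimate_call_reserve_micros(
+            base_model="claude-3-opus", quota_reserve_tokens=4000
+        )  # → 360_000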
+ """
+ reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
+ if reserve_tokens <= 0:
+ reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
+
+ try:
+ from litellm import get_model_info
+
+ info = get_model_info(base_model) if base_model else {}
+ input_cost = float(info.get("input_cost_per_token") or 0.0)
+ output_cost = float(info.get("output_cost_per_token") or 0.0)
+ except Exception as exc:
+ logger.debug(
+ "[quota_reserve] cost lookup failed for base_model=%s: %s",
+ base_model,
+ exc,
+ )
+ input_cost = 0.0
+ output_cost = 0.0
+
+ if input_cost == 0.0 and output_cost == 0.0:
+ return _QUOTA_MIN_RESERVE_MICROS
+
+ reserve_usd = reserve_tokens * (input_cost + output_cost)
+ reserve_micros = round(reserve_usd * 1_000_000)
+ if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
+ reserve_micros = _QUOTA_MIN_RESERVE_MICROS
+ if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
+ reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
+ return reserve_micros
+
+
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
@@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
- reserve_tokens: int,
+ reserve_micros: int,
) -> QuotaResult:
+ """Reserve ``reserve_micros`` (USD micro-units) from the user's
+ premium credit balance.
+
+ ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
+ all in micro-USD on this code path; callers (chat stream,
+ token-status route, FE display) convert to dollars by dividing
+ by 1_000_000.
+ """
from app.db import User
user = (
@@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0,
)
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
- effective = used + reserved + reserve_tokens
+ effective = used + reserved + reserve_micros
if effective > limit:
remaining = max(0, limit - used - reserved)
await db_session.rollback()
@@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining,
)
- user.premium_tokens_reserved = reserved + reserve_tokens
+ user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit()
- new_reserved = reserved + reserve_tokens
+ new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8)
@@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
- actual_tokens: int,
- reserved_tokens: int,
+ actual_micros: int,
+ reserved_micros: int,
) -> QuotaResult:
+ """Settle the reservation: release ``reserved_micros`` and debit
+ ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
+ """
from app.db import User
user = (
@@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
- user.premium_tokens_reserved = max(
- 0, user.premium_tokens_reserved - reserved_tokens
+ user.premium_credit_micros_reserved = max(
+ 0, user.premium_credit_micros_reserved - reserved_micros
+ )
+ user.premium_credit_micros_used = (
+ user.premium_credit_micros_used + actual_micros
)
- user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit()
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
@@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release(
db_session: AsyncSession,
user_id: Any,
- reserved_tokens: int,
+ reserved_micros: int,
) -> None:
+ """Release ``reserved_micros`` previously held by ``premium_reserve``.
+
+ Used when a request fails before finalize (so the reservation
+ doesn't leak credit).
+ """
from app.db import User
user = (
@@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none()
)
if user is not None:
- user.premium_tokens_reserved = max(
- 0, user.premium_tokens_reserved - reserved_tokens
+ user.premium_credit_micros_reserved = max(
+ 0, user.premium_credit_micros_reserved - reserved_micros
)
await db_session.commit()
@@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
- limit = user.premium_tokens_limit
- used = user.premium_tokens_used
- reserved = user.premium_tokens_reserved
+ limit = user.premium_credit_micros_limit
+ used = user.premium_credit_micros_used
+ reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py
index 9aa8c6e70..9406d9be4 100644
--- a/surfsense_backend/app/services/token_tracking_service.py
+++ b/surfsense_backend/app/services/token_tracking_service.py
@@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses
import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
+import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
@@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int
completion_tokens: int
total_tokens: int
+ cost_micros: int = 0
+ call_kind: str = "chat"
@dataclass
@@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
+ cost_micros: int = 0,
+ call_kind: str = "chat",
) -> None:
self.calls.append(
TokenCallRecord(
@@ -56,20 +63,28 @@ class TurnTokenAccumulator:
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ call_kind=call_kind,
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
- """Return token counts grouped by model name."""
+ """Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
- {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+ {
+ "prompt_tokens": 0,
+ "completion_tokens": 0,
+ "total_tokens": 0,
+ "cost_micros": 0,
+ },
)
entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens
+ entry["cost_micros"] += c.cost_micros
return by_model
@property
@@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls)
+ @property
+ def total_cost_micros(self) -> int:
+ """Sum of per-call ``cost_micros`` across the entire turn.
+
+ Used by ``stream_new_chat`` to debit a premium turn's actual
+ provider cost (in micro-USD) from the user's premium credit
+ balance. ``cost_micros`` per call is captured by
+ ``TokenTrackingCallback.async_log_success_event`` from
+ ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
+ with multiple fallback paths so OpenRouter dynamic models and
+        custom Azure deployments still bill correctly, provided
+        ``pricing_registration`` registered their pricing at startup.
+ """
+ return sum(c.cost_micros for c in self.calls)
+
def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls]
@@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator:
- """Create a fresh accumulator for the current async context and return it."""
+ """Create a fresh accumulator for the current async context and return it.
+
+ NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
+ short-lived per-call billable wrappers (image generation REST endpoint,
+ vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
+ ContextVar reset token to restore the *previous* accumulator on exit and
+ avoids leaking call records across reservations (issue B).
+ """
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
+@asynccontextmanager
+async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
+ """Async context manager that scopes a fresh ``TurnTokenAccumulator``
+ for the duration of the ``async with`` block, then *resets* the
+ ContextVar to its previous value on exit.
+
+ This is the safe primitive for per-call billable operations
+ (image generation, vision LLM extraction, podcasts) that may run
+ inside an outer chat turn or be called sequentially from the same
+ background worker. Using ``ContextVar.set`` without ``reset`` (as
+ :func:`start_turn` does) would leak the inner accumulator into the
+ outer scope, causing the outer chat turn to debit cost twice.
+
+ Usage::
+
+ async with scoped_turn() as acc:
+ await llm.ainvoke(...)
+ # acc.total_cost_micros captures cost from the LiteLLM callback
+ # Outer accumulator (if any) is restored here.
+ """
+ acc = TurnTokenAccumulator()
+ token = _turn_accumulator.set(acc)
+ logger.debug(
+ "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
+ id(acc),
+ token,
+ )
+ try:
+ yield acc
+ finally:
+ _turn_accumulator.reset(token)
+ logger.debug(
+ "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
+ id(acc),
+ len(acc.calls),
+ acc.total_cost_micros,
+ )
+
+
+def _extract_cost_usd(
+ kwargs: dict[str, Any],
+ response_obj: Any,
+ model: str,
+ prompt_tokens: int,
+ completion_tokens: int,
+ is_image: bool = False,
+) -> float:
+ """Best-effort USD cost extraction for a single LLM/image call.
+
+ Tries four sources in priority order and returns the first that
+ yields a positive number; returns 0.0 if all four fail (the call
+ will then debit nothing from the user's balance — fail-safe).
+
+ Sources:
+ 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
+ field, populated for ``Router.acompletion`` since PR #12500.
+ 2. ``response_obj._hidden_params["response_cost"]`` — same value
+ exposed on the response itself.
+ 3. ``litellm.completion_cost(completion_response=response_obj)``
+ — recompute from the response and LiteLLM's pricing table.
+ 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
+ — manual fallback for OpenRouter/custom-Azure models that
+ only resolve via aliases registered by
+ ``pricing_registration`` at startup. **Skipped for image
+ responses** — ``cost_per_token`` does not support ``ImageResponse``
+ and would raise; the cost map for image-gen lives in different
+ keys (``output_cost_per_image``) handled by ``completion_cost``.
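+
+    For illustration, source 1 short-circuits the rest::
+
+        _extract_cost_usd(
+            kwargs={"response_cost": 0.0025},
+            response_obj=resp,  # any response object (hypothetical here)
+            model="openai/gpt-4o",
+            prompt_tokens=100,
+            completion_tokens=50,
+        )  # -> 0.0025 (the callback then rounds to 2_500 micros)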
+ """
+ cost = kwargs.get("response_cost")
+ if cost is not None:
+ try:
+ value = float(cost)
+ except (TypeError, ValueError):
+ value = 0.0
+ if value > 0:
+ return value
+
+ hidden = getattr(response_obj, "_hidden_params", None) or {}
+ if isinstance(hidden, dict):
+ cost = hidden.get("response_cost")
+ if cost is not None:
+ try:
+ value = float(cost)
+ except (TypeError, ValueError):
+ value = 0.0
+ if value > 0:
+ return value
+
+ try:
+ value = float(litellm.completion_cost(completion_response=response_obj))
+ if value > 0:
+ return value
+ except Exception as exc:
+ if is_image:
+ # Image-gen path: OpenRouter's image responses can omit
+ # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
+ # then *raises* (no cost map for OpenRouter image models).
+ # Bail out with a warning rather than falling through to
+ # cost_per_token (which is also incompatible with ImageResponse).
+ logger.warning(
+ "[TokenTracking] completion_cost failed for image model=%s "
+ "(provider may have omitted usage.cost). Debiting 0. "
+ "Cause: %s",
+ model,
+ exc,
+ )
+ return 0.0
+ logger.debug(
+ "[TokenTracking] completion_cost failed for model=%s: %s", model, exc
+ )
+
+ if is_image:
+ # Never call cost_per_token for ImageResponse — keys mismatch and
+ # the function is documented chat-only.
+ return 0.0
+
+ if model and (prompt_tokens > 0 or completion_tokens > 0):
+ try:
+ prompt_cost, completion_cost = litellm.cost_per_token(
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ value = float(prompt_cost) + float(completion_cost)
+ if value > 0:
+ return value
+ except Exception as exc:
+ logger.debug(
+ "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
+ )
+
+ return 0.0
+
+
class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator."""
@@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
)
return
+ # Detect image generation responses — they have a different usage
+ # shape (ImageUsage with input_tokens/output_tokens) and require a
+ # different cost-extraction path. We probe by class name to avoid a
+ # hard import dependency on litellm internals.
+ response_cls = type(response_obj).__name__
+ is_image = response_cls == "ImageResponse"
+
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug(
@@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
)
return
- prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
- completion_tokens = getattr(usage, "completion_tokens", 0) or 0
- total_tokens = getattr(usage, "total_tokens", 0) or 0
+ if is_image:
+ # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
+ # (not prompt_tokens/completion_tokens). Several providers
+ # populate only one or neither (e.g. OpenRouter's gpt-image-1
+ # passes through `input_tokens` from the prompt but no
+ # completion); fall through gracefully to 0.
+ prompt_tokens = getattr(usage, "input_tokens", 0) or 0
+ completion_tokens = getattr(usage, "output_tokens", 0) or 0
+ total_tokens = (
+ getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
+ )
+ call_kind = "image_generation"
+ else:
+ prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+ completion_tokens = getattr(usage, "completion_tokens", 0) or 0
+ total_tokens = getattr(usage, "total_tokens", 0) or 0
+ call_kind = "chat"
model = kwargs.get("model", "unknown")
+ cost_usd = _extract_cost_usd(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ is_image=is_image,
+ )
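+        # USD → integer micro-USD: e.g. $0.042513 rounds to 42_513 micros.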
+ cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
+
+ if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+ logger.warning(
+ "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
+ "kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
+ "input_cost_per_token/output_cost_per_token (or rely on response_cost "
+ "for image generation).",
+ model,
+ prompt_tokens,
+ completion_tokens,
+ call_kind,
+ )
+
acc.add(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
+ call_kind=call_kind,
)
logger.info(
- "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
+ "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
+ "cost=$%.6f (%d micros) (accumulator now has %d calls)",
model,
+ call_kind,
prompt_tokens,
completion_tokens,
total_tokens,
+ cost_usd,
+ cost_micros,
len(acc.calls),
)
@@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
+ cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
@@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
+ cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
@@ -194,11 +416,12 @@ async def record_token_usage(
)
session.add(record)
logger.debug(
- "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
+ "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type,
prompt_tokens,
completion_tokens,
total_tokens,
+ cost_micros,
)
return record
except Exception:
diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py
index 0d782ab2b..ed5de921c 100644
--- a/surfsense_backend/app/services/vision_llm_router_service.py
+++ b/surfsense_backend/app/services/vision_llm_router_service.py
@@ -3,6 +3,8 @@ from typing import Any
from litellm import Router
+from app.services.provider_api_base import resolve_api_base
+
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
@@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
+ provider = config.get("provider", "").upper()
if config.get("custom_provider"):
- model_string = f"{config['custom_provider']}/{config['model_name']}"
+ provider_prefix = config["custom_provider"]
+ model_string = f"{provider_prefix}/{config['model_name']}"
else:
- provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
}
- if config.get("api_base"):
- litellm_params["api_base"] = config["api_base"]
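+        # resolve_api_base returns the effective endpoint — presumably the
+        # explicit config value when set, else a provider default (an
+        # assumption from the helper's signature; config_api_base reads as
+        # an override passed through).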
+ api_base = resolve_api_base(
+ provider=provider,
+ provider_prefix=provider_prefix,
+ config_api_base=config.get("api_base"),
+ )
+ if api_base:
+ litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
index 953011ecf..937877473 100644
--- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
+from app.config import config as app_config
from app.db import Podcast, PodcastStatus
+from app.services.billable_calls import (
+ QuotaInsufficientError,
+ _resolve_agent_billing_for_search_space,
+ billable_call,
+)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@@ -96,6 +102,31 @@ async def _generate_content_podcast(
podcast.status = PodcastStatus.GENERATING
await session.commit()
+ try:
+ (
+ owner_user_id,
+ billing_tier,
+ base_model,
+ ) = await _resolve_agent_billing_for_search_space(
+ session,
+ search_space_id,
+ thread_id=podcast.thread_id,
+ )
+ except ValueError as resolve_err:
+ logger.error(
+ "Podcast %s: cannot resolve billing for search_space=%s: %s",
+ podcast.id,
+ search_space_id,
+ resolve_err,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "billing_resolution_failed",
+ }
+
graph_config = {
"configurable": {
"podcast_title": podcast.title,
@@ -109,9 +140,39 @@ async def _generate_content_podcast(
db_session=session,
)
- graph_result = await podcaster_graph.ainvoke(
- initial_state, config=graph_config
- )
+ try:
+ async with billable_call(
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
+ usage_type="podcast_generation",
+ thread_id=podcast.thread_id,
+ call_details={
+ "podcast_id": podcast.id,
+ "title": podcast.title,
+ },
+ ):
+ graph_result = await podcaster_graph.ainvoke(
+ initial_state, config=graph_config
+ )
+ except QuotaInsufficientError as exc:
+ logger.info(
+ "Podcast %s denied: out of premium credits "
+ "(used=%d/%d remaining=%d)",
+ podcast.id,
+ exc.used_micros,
+ exc.limit_micros,
+ exc.remaining_micros,
+ )
+ podcast.status = PodcastStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "podcast_id": podcast.id,
+ "reason": "premium_quota_exhausted",
+ }
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
index 7880b385f..4f0c427d9 100644
--- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py
@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app
+from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
+from app.services.billable_calls import (
+ QuotaInsufficientError,
+ _resolve_agent_billing_for_search_space,
+ billable_call,
+)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@@ -97,6 +103,32 @@ async def _generate_video_presentation(
video_pres.status = VideoPresentationStatus.GENERATING
await session.commit()
+ try:
+ (
+ owner_user_id,
+ billing_tier,
+ base_model,
+ ) = await _resolve_agent_billing_for_search_space(
+ session,
+ search_space_id,
+ thread_id=video_pres.thread_id,
+ )
+ except ValueError as resolve_err:
+ logger.error(
+ "VideoPresentation %s: cannot resolve billing for "
+ "search_space=%s: %s",
+ video_pres.id,
+ search_space_id,
+ resolve_err,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "billing_resolution_failed",
+ }
+
graph_config = {
"configurable": {
"video_title": video_pres.title,
@@ -110,9 +142,39 @@ async def _generate_video_presentation(
db_session=session,
)
- graph_result = await video_presentation_graph.ainvoke(
- initial_state, config=graph_config
- )
+ try:
+ async with billable_call(
+ user_id=owner_user_id,
+ search_space_id=search_space_id,
+ billing_tier=billing_tier,
+ base_model=base_model,
+ quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
+ usage_type="video_presentation_generation",
+ thread_id=video_pres.thread_id,
+ call_details={
+ "video_presentation_id": video_pres.id,
+ "title": video_pres.title,
+ },
+ ):
+ graph_result = await video_presentation_graph.ainvoke(
+ initial_state, config=graph_config
+ )
+ except QuotaInsufficientError as exc:
+ logger.info(
+ "VideoPresentation %s denied: out of premium credits "
+ "(used=%d/%d remaining=%d)",
+ video_pres.id,
+ exc.used_micros,
+ exc.limit_micros,
+ exc.remaining_micros,
+ )
+ video_pres.status = VideoPresentationStatus.FAILED
+ await session.commit()
+ return {
+ "status": "failed",
+ "video_presentation_id": video_pres.id,
+ "reason": "premium_quota_exhausted",
+ }
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index dbfe9a67b..31c0d7d6d 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -2236,8 +2236,10 @@ async def stream_new_chat(
accumulator = start_turn()
- # Premium quota tracking state
- _premium_reserved = 0
+    # Premium credit (micro-USD) tracking state. Stores the amount
+    # reserved up front so we can release it on cancellation and, at
+    # finalize, debit the actual provider cost reported by LiteLLM.
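+    # e.g. reserve up to 1_000_000 micros ($1.00); a turn that actually
+    # costs $0.0132 is finalized at 13_200 micros, freeing the remainder.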
+ _premium_reserved_micros = 0
_premium_request_id: str | None = None
_emit_stream_error = partial(
@@ -2331,23 +2333,28 @@ async def stream_new_chat(
if _needs_premium_quota:
import uuid as _uuid
- from app.config import config as _app_config
- from app.services.token_quota_service import TokenQuotaService
+ from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+ )
_premium_request_id = _uuid.uuid4().hex[:16]
- reserve_amount = min(
- agent_config.quota_reserve_tokens
- or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
- _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+ _agent_litellm_params = agent_config.litellm_params or {}
+ _agent_base_model = (
+ _agent_litellm_params.get("base_model") or agent_config.model_name or ""
+ )
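+        # estimate_call_reserve_micros prices the reserve-token budget for
+        # the base model in micro-USD (and, we assume, clamps to
+        # QUOTA_MAX_RESERVE_MICROS, replacing the explicit min() removed
+        # above).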
+ reserve_amount_micros = estimate_call_reserve_micros(
+ base_model=_agent_base_model,
+ quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
- reserve_tokens=reserve_amount,
+ reserve_micros=reserve_amount_micros,
)
- _premium_reserved = reserve_amount
+ _premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
if requested_llm_config_id == 0:
try:
@@ -2382,7 +2389,7 @@ async def stream_new_chat(
yield streaming_service.format_done()
return
_premium_request_id = None
- _premium_reserved = 0
+ _premium_reserved_micros = 0
_log_chat_stream_error(
flow=flow,
error_kind="premium_quota_exhausted",
@@ -3020,9 +3027,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
+ "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3033,6 +3041,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3060,7 +3069,11 @@ async def stream_new_chat(
chat_id, generated_title
)
- # Finalize premium quota with actual tokens.
+ # Finalize premium credit debit with the actual provider cost
+ # reported by LiteLLM, summed across every call in the turn.
+            # Mirrors the pre-cost-billing behaviour ("premium turn → all
+            # calls count"), so free sub-agent calls made during a premium
+            # turn still contribute to the bill (at $0 in practice).
if _premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3070,11 +3083,11 @@ async def stream_new_chat(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
- actual_tokens=accumulator.grand_total,
- reserved_tokens=_premium_reserved,
+ actual_micros=accumulator.total_cost_micros,
+ reserved_micros=_premium_reserved_micros,
)
_premium_request_id = None
- _premium_reserved = 0
+ _premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s",
@@ -3084,9 +3097,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] normal new_chat: calls=%d total=%d summary=%s",
+ "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3097,6 +3111,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3190,7 +3205,7 @@ async def stream_new_chat(
end_turn(str(chat_id))
# Release premium reservation if not finalized
- if _premium_request_id and _premium_reserved > 0 and user_id:
+ if _premium_request_id and _premium_reserved_micros > 0 and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3198,9 +3213,9 @@ async def stream_new_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
- reserved_tokens=_premium_reserved,
+ reserved_micros=_premium_reserved_micros,
)
- _premium_reserved = 0
+ _premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s", user_id
@@ -3369,8 +3384,8 @@ async def stream_resume_chat(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
- # Premium quota reservation (same logic as stream_new_chat)
- _resume_premium_reserved = 0
+ # Premium credit reservation (same logic as stream_new_chat).
+ _resume_premium_reserved_micros = 0
_resume_premium_request_id: str | None = None
_resume_needs_premium = (
agent_config is not None and user_id and agent_config.is_premium
@@ -3378,23 +3393,30 @@ async def stream_resume_chat(
if _resume_needs_premium:
import uuid as _uuid
- from app.config import config as _app_config
- from app.services.token_quota_service import TokenQuotaService
+ from app.services.token_quota_service import (
+ TokenQuotaService,
+ estimate_call_reserve_micros,
+ )
_resume_premium_request_id = _uuid.uuid4().hex[:16]
- reserve_amount = min(
- agent_config.quota_reserve_tokens
- or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
- _app_config.QUOTA_MAX_RESERVE_PER_CALL,
+ _resume_litellm_params = agent_config.litellm_params or {}
+ _resume_base_model = (
+ _resume_litellm_params.get("base_model")
+ or agent_config.model_name
+ or ""
+ )
+ reserve_amount_micros = estimate_call_reserve_micros(
+ base_model=_resume_base_model,
+ quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
- reserve_tokens=reserve_amount,
+ reserve_micros=reserve_amount_micros,
)
- _resume_premium_reserved = reserve_amount
+ _resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
if requested_llm_config_id == 0:
try:
@@ -3429,7 +3451,7 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
_resume_premium_request_id = None
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
_log_chat_stream_error(
flow="resume",
error_kind="premium_quota_exhausted",
@@ -3746,9 +3768,10 @@ async def stream_resume_chat(
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
+ "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3759,6 +3782,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3768,7 +3792,9 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
- # Finalize premium quota for resume path
+ # Finalize premium credit debit for resume path with the actual
+ # provider cost reported by LiteLLM (sum of cost across all
+ # calls in the turn).
if _resume_premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3778,11 +3804,11 @@ async def stream_resume_chat(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
- actual_tokens=accumulator.grand_total,
- reserved_tokens=_resume_premium_reserved,
+ actual_micros=accumulator.total_cost_micros,
+ reserved_micros=_resume_premium_reserved_micros,
)
_resume_premium_request_id = None
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)",
@@ -3792,9 +3818,10 @@ async def stream_resume_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
- "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
+ "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
+ accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@@ -3805,6 +3832,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
+ "cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@@ -3855,7 +3883,11 @@ async def stream_resume_chat(
end_turn(str(chat_id))
# Release premium reservation if not finalized
- if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
+ if (
+ _resume_premium_request_id
+ and _resume_premium_reserved_micros > 0
+ and user_id
+ ):
try:
from app.services.token_quota_service import TokenQuotaService
@@ -3863,9 +3895,9 @@ async def stream_resume_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
- reserved_tokens=_resume_premium_reserved,
+ reserved_micros=_resume_premium_reserved_micros,
)
- _resume_premium_reserved = 0
+ _resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s (resume)", user_id
diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py
new file mode 100644
index 000000000..636b7de31
--- /dev/null
+++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py
@@ -0,0 +1,138 @@
+"""Unit tests for the image-generation route's billing-resolution helper.
+
+End-to-end "POST /image-generations returns 402" coverage requires the
+integration harness (real DB, real auth) and lives in
+``tests/integration/document_upload/`` alongside the other quota tests.
+This unit test focuses on the new ``_resolve_billing_for_image_gen``
+helper which:
+
+* Returns ``free`` for Auto mode, even when premium configs exist
+ (Auto-mode billing-tier surfacing is a follow-up).
+* Returns ``free`` for user-owned BYOK configs (positive IDs).
+* Returns the global config's ``billing_tier`` for negative IDs.
+* Honours the per-config ``quota_reserve_micros`` override when present.
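+
+For illustration, the helper yields a ``(tier, model, reserve_micros)``
+triple::
+
+    tier, model, reserve = await _resolve_billing_for_image_gen(
+        session=session, config_id=-1, search_space=search_space
+    )  # e.g. ("premium", "openai/gpt-image-1", 75_000)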
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_auto_mode(monkeypatch):
+ from app.routes import image_generation_routes
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, # Not consumed on this code path.
+ config_id=0, # IMAGE_GEN_AUTO_MODE_ID
+ search_space=search_space,
+ )
+ assert tier == "free"
+ assert model == "auto"
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_premium_global_config(monkeypatch):
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_IMAGE_GEN_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "billing_tier": "premium",
+ "quota_reserve_micros": 75_000,
+ },
+ {
+ "id": -2,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash-image",
+ "billing_tier": "free",
+ },
+ ],
+ raising=False,
+ )
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+
+ # Premium with override.
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=-1, search_space=search_space
+ )
+ assert tier == "premium"
+ assert model == "openai/gpt-image-1"
+ assert reserve == 75_000
+
+ # Free, no override → falls back to default.
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=-2, search_space=search_space
+ )
+ assert tier == "free"
+ # Provider-prefixed model string for OpenRouter.
+ assert "google/gemini-2.5-flash-image" in model
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_for_user_owned_byok_is_free():
+ """User-owned BYOK configs (positive IDs) cost the user nothing on
+ our side — they pay the provider directly. Always free.
+ """
+ from app.routes import image_generation_routes
+ from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
+
+ search_space = SimpleNamespace(image_generation_config_id=None)
+ tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=42, search_space=search_space
+ )
+ assert tier == "free"
+ assert model == "user_byok"
+ assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
+
+
+@pytest.mark.asyncio
+async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
+ """When the request omits ``image_generation_config_id``, the helper
+ must consult the search space's default — so a search space pinned
+ to a premium global config still gates new requests by quota.
+ """
+ from app.config import config
+ from app.routes import image_generation_routes
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_IMAGE_GEN_CONFIGS",
+ [
+ {
+ "id": -7,
+ "provider": "OPENAI",
+ "model_name": "gpt-image-1",
+ "billing_tier": "premium",
+ }
+ ],
+ raising=False,
+ )
+
+ search_space = SimpleNamespace(image_generation_config_id=-7)
+ (
+ tier,
+ model,
+ _reserve,
+ ) = await image_generation_routes._resolve_billing_for_image_gen(
+ session=None, config_id=None, search_space=search_space
+ )
+ assert tier == "premium"
+ assert model == "openai/gpt-image-1"
diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
new file mode 100644
index 000000000..fa8819b39
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py
@@ -0,0 +1,436 @@
+"""Unit tests for ``_resolve_agent_billing_for_search_space``.
+
+Validates the resolver used by Celery podcast/video tasks to compute
+``(owner_user_id, billing_tier, base_model)`` from a search space and its
+agent LLM config. The resolver mirrors chat's billing-resolution pattern at
+``stream_new_chat.py:2294-2351`` and is the single integration point that
+prevents Auto-mode podcast/video from leaking premium credit.
+
+Coverage:
+
+* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
+ global → returns ``("premium", )``.
+* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
+ global → returns ``("free", )``.
+* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
+ → always ``"free"``.
+* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
+ hitting the pin service.
+* Negative id (no Auto) → uses ``get_global_llm_config``'s
+ ``billing_tier``.
+* Positive id (user BYOK) → always ``"free"``.
+* Search space not found → raises ``ValueError``.
+* ``agent_llm_id`` is None → raises ``ValueError``.
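+
+For illustration, a premium resolution returns the owner plus tier and
+base model::
+
+    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+        session, search_space_id=42, thread_id=99
+    )
+    # → (user_id, "premium", "gpt-5.4") in the first test below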
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from uuid import UUID, uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+
+class _FakeSession:
+ """Tiny AsyncSession stub.
+
+ ``responses`` is a list of objects to return from successive
+ ``execute()`` calls (in order). The resolver makes at most two
+ ``execute()`` calls (search-space lookup, then optionally NewLLMConfig
+ lookup), so two queued responses cover the matrix.
+ """
+
+ def __init__(self, responses: list):
+ self._responses = list(responses)
+
+ async def execute(self, _stmt):
+ if not self._responses:
+ return _FakeExecResult(None)
+ return _FakeExecResult(self._responses.pop(0))
+
+ async def commit(self) -> None:
+ pass
+
+
+@dataclass
+class _FakePinResolution:
+ resolved_llm_config_id: int
+ resolved_tier: str = "premium"
+ from_existing_pin: bool = False
+
+
+def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=42,
+ agent_llm_id=agent_llm_id,
+ user_id=user_id,
+ )
+
+
+def _make_byok_config(
+ *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
+) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=id_,
+ model_name=model_name,
+ litellm_params={"base_model": base_model} if base_model else {},
+ )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
+ """Auto + thread → pin service resolves to negative-id premium config →
+ resolver returns ``("premium", )``."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ # Mock the pin service to return a concrete premium config id.
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ assert selected_llm_config_id == 0
+ assert thread_id == 99
+ return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
+
+ # Mock global config lookup to return a premium entry.
+ def _fake_get_global(cfg_id):
+ if cfg_id == -1:
+ return {
+ "id": -1,
+ "model_name": "gpt-5.4",
+ "billing_tier": "premium",
+ "litellm_params": {"base_model": "gpt-5.4"},
+ }
+ return None
+
+ # Lazy imports inside the resolver — patch the *target* modules so the
+ # imported names resolve to our fakes.
+ import app.services.auto_model_pin_service as pin_module
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "premium"
+ assert base_model == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
+ """Auto + thread → pin returns negative-id free config → resolver
+ returns ``("free", )``. Same path the pin service takes for
+ out-of-credit users (graceful degradation)."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
+
+ def _fake_get_global(cfg_id):
+ if cfg_id == -3:
+ return {
+ "id": -3,
+ "model_name": "openrouter/free-model",
+ "billing_tier": "free",
+ "litellm_params": {"base_model": "openrouter/free-model"},
+ }
+ return None
+
+ import app.services.auto_model_pin_service as pin_module
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "openrouter/free-model"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
+ """Auto + thread → pin returns positive-id BYOK config → resolver
+ returns ``("free", ...)`` (BYOK is always free per
+ ``AgentConfig.from_new_llm_config``)."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
+ byok_cfg = _make_byok_config(
+ id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
+ )
+ session = _FakeSession([search_space, byok_cfg])
+
+ async def _fake_resolve_pin(
+ sess,
+ *,
+ thread_id,
+ search_space_id,
+ user_id,
+ selected_llm_config_id,
+ force_repin_free=False,
+ ):
+ return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
+
+ import app.services.auto_model_pin_service as pin_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "anthropic/claude-3-haiku"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_without_thread_id_falls_back_to_free():
+ """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
+ the pin service. Forward-compat fallback for any future direct-API
+ entrypoint that doesn't have a chat thread."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=None
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "auto"
+
+
+@pytest.mark.asyncio
+async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
+ """If the pin service raises ``ValueError`` (thread missing /
+ mismatched search space), the resolver should log and return free
+ rather than killing the whole task."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
+
+ async def _fake_resolve_pin(*args, **kwargs):
+ raise ValueError("thread missing")
+
+ import app.services.auto_model_pin_service as pin_module
+
+ monkeypatch.setattr(
+ pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
+ )
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "auto"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_premium_global_returns_premium(monkeypatch):
+ """Explicit negative agent_llm_id → ``get_global_llm_config`` →
+ return its ``billing_tier``."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "gpt-5.4",
+ "billing_tier": "premium",
+ "litellm_params": {"base_model": "gpt-5.4"},
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=99
+ )
+
+ assert owner == user_id
+ assert tier == "premium"
+ assert base_model == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_free_global_returns_free(monkeypatch):
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "openrouter/some-free",
+ "billing_tier": "free",
+ "litellm_params": {"base_model": "openrouter/some-free"},
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42, thread_id=None
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "openrouter/some-free"
+
+
+@pytest.mark.asyncio
+async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
+ """When the global config has no ``litellm_params.base_model``, the
+ resolver falls back to ``model_name`` — matching chat's behavior."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
+
+ def _fake_get_global(cfg_id):
+ return {
+ "id": cfg_id,
+ "model_name": "fallback-model",
+ "billing_tier": "premium",
+ # No litellm_params.
+ }
+
+ import app.services.llm_service as llm_module
+
+ monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
+
+ _, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert tier == "premium"
+ assert base_model == "fallback-model"
+
+
+@pytest.mark.asyncio
+async def test_positive_id_byok_is_always_free():
+ """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
+ regardless of underlying provider tier."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
+ byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
+ session = _FakeSession([search_space, byok_cfg])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == "anthropic/claude-3.5-sonnet"
+
+
+@pytest.mark.asyncio
+async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
+ """If the BYOK config row is missing/deleted but the search space still
+ points at it, the resolver still returns free (no debit) with an empty
+ base_model — billable_call's premium path is skipped, no harm done."""
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
+
+ owner, tier, base_model = await _resolve_agent_billing_for_search_space(
+ session, search_space_id=42
+ )
+
+ assert owner == user_id
+ assert tier == "free"
+ assert base_model == ""
+
+
+@pytest.mark.asyncio
+async def test_search_space_not_found_raises_value_error():
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ session = _FakeSession([None])
+
+ with pytest.raises(ValueError, match="Search space"):
+ await _resolve_agent_billing_for_search_space(session, search_space_id=999)
+
+
+@pytest.mark.asyncio
+async def test_agent_llm_id_none_raises_value_error():
+ from app.services.billable_calls import _resolve_agent_billing_for_search_space
+
+ user_id = uuid4()
+ session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
+
+ with pytest.raises(ValueError, match="agent_llm_id"):
+ await _resolve_agent_billing_for_search_space(session, search_space_id=42)
diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py
new file mode 100644
index 000000000..86de5f23d
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_billable_call.py
@@ -0,0 +1,432 @@
+"""Unit tests for the ``billable_call`` async context manager.
+
+Covers the per-call premium-credit lifecycle for image generation and
+vision LLM extraction:
+
+* Free configs bypass reserve/finalize but still write an audit row.
+* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
+ route layer).
+* Successful premium calls reserve, yield the accumulator, then finalize
+ with the LiteLLM-reported actual cost — and write an audit row.
+* Failed premium calls release the reservation so credit isn't leaked.
+* All quota DB ops happen inside their OWN ``shielded_async_session``,
+ isolating them from the caller's transaction (issue A).
+"""
+
+from __future__ import annotations
+
+import contextlib
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeQuotaResult:
+ def __init__(
+ self,
+ *,
+ allowed: bool,
+ used: int = 0,
+ limit: int = 5_000_000,
+ remaining: int = 5_000_000,
+ ) -> None:
+ self.allowed = allowed
+ self.used = used
+ self.limit = limit
+ self.remaining = remaining
+
+
+class _FakeSession:
+ """Minimal AsyncSession stub — record commits for assertion."""
+
+ def __init__(self) -> None:
+ self.committed = False
+ self.added: list[Any] = []
+
+ def add(self, obj: Any) -> None:
+ self.added.append(obj)
+
+ async def commit(self) -> None:
+ self.committed = True
+
+ async def close(self) -> None:
+ pass
+
+
+_SESSIONS_USED: list[_FakeSession] = []
+
+
+@contextlib.asynccontextmanager
+async def _fake_shielded_session():
+    s = _FakeSession()
+    _SESSIONS_USED.append(s)
+    yield s
+
+
+def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
+ """Wire fake reserve/finalize/release/session helpers."""
+ _SESSIONS_USED.clear()
+ reserve_calls: list[dict[str, Any]] = []
+ finalize_calls: list[dict[str, Any]] = []
+ release_calls: list[dict[str, Any]] = []
+
+ async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
+ reserve_calls.append(
+ {
+ "user_id": user_id,
+ "reserve_micros": reserve_micros,
+ "request_id": request_id,
+ }
+ )
+ return reserve_result
+
+ async def _fake_finalize(
+ *, db_session, user_id, request_id, actual_micros, reserved_micros
+ ):
+ finalize_calls.append(
+ {
+ "user_id": user_id,
+ "actual_micros": actual_micros,
+ "reserved_micros": reserved_micros,
+ }
+ )
+ return finalize_result or _FakeQuotaResult(allowed=True)
+
+ async def _fake_release(*, db_session, user_id, reserved_micros):
+ release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
+
+ record_calls: list[dict[str, Any]] = []
+
+ async def _fake_record(session, **kwargs):
+ record_calls.append(kwargs)
+ return object()
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_reserve",
+ _fake_reserve,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_finalize",
+ _fake_finalize,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.TokenQuotaService.premium_release",
+ _fake_release,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.shielded_async_session",
+ _fake_shielded_session,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "app.services.billable_calls.record_token_usage",
+ _fake_record,
+ raising=False,
+ )
+
+ return {
+ "reserve": reserve_calls,
+ "finalize": finalize_calls,
+ "release": release_calls,
+ "record": record_calls,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openai/gpt-image-1",
+ usage_type="image_generation",
+ ) as acc:
+ # Simulate a captured cost — the accumulator is fed by the LiteLLM
+ # callback in real life, here we add it manually.
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=37_000,
+ call_kind="image_generation",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ # Free still audits.
+ assert len(spies["record"]) == 1
+ assert spies["record"][0]["usage_type"] == "image_generation"
+ assert spies["record"][0]["cost_micros"] == 37_000
+
+
+@pytest.mark.asyncio
+async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
+ from app.services.billable_calls import (
+ QuotaInsufficientError,
+ billable_call,
+ )
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(
+ allowed=False, used=5_000_000, limit=5_000_000, remaining=0
+ ),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(QuotaInsufficientError) as exc_info:
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ):
+ pytest.fail("body should not run when reserve is denied")
+
+ err = exc_info.value
+ assert err.usage_type == "image_generation"
+ assert err.used_micros == 5_000_000
+ assert err.limit_micros == 5_000_000
+ assert err.remaining_micros == 0
+ # Reserve was attempted, but no finalize/release on a denied reserve
+ # — the reservation never actually held credit.
+ assert len(spies["reserve"]) == 1
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ # Denied premium calls do NOT create an audit row (no work happened).
+ assert spies["record"] == []
+
+
+@pytest.mark.asyncio
+async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ) as acc:
+ # LiteLLM callback would normally fill this — simulate $0.04 image.
+ acc.add(
+ model="openai/gpt-image-1",
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost_micros=40_000,
+ call_kind="image_generation",
+ )
+
+ assert len(spies["reserve"]) == 1
+ assert spies["reserve"][0]["reserve_micros"] == 50_000
+ assert len(spies["finalize"]) == 1
+ assert spies["finalize"][0]["actual_micros"] == 40_000
+ assert spies["finalize"][0]["reserved_micros"] == 50_000
+ assert spies["release"] == []
+ # And audit row written with the actual debited cost.
+ assert spies["record"][0]["cost_micros"] == 40_000
+ # Each quota op opened its OWN session — proves session isolation.
+ assert len(_SESSIONS_USED) >= 3
+    # finalize/reserve go through TokenQuotaService.*, which we stub — they
+    # never call commit on our fake sessions, but the audit write does, so
+    # at least one session must have committed.
+ assert any(s.committed for s in _SESSIONS_USED)
+
+
+@pytest.mark.asyncio
+async def test_premium_failure_releases_reservation(monkeypatch):
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ class _ProviderError(Exception):
+ pass
+
+ with pytest.raises(_ProviderError):
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="openai/gpt-image-1",
+ quota_reserve_micros_override=50_000,
+ usage_type="image_generation",
+ ):
+ raise _ProviderError("OpenRouter 503")
+
+ assert len(spies["reserve"]) == 1
+ assert spies["finalize"] == []
+ # Failure path: release the held reservation.
+ assert len(spies["release"]) == 1
+ assert spies["release"][0]["reserved_micros"] == 50_000
+
+
+@pytest.mark.asyncio
+async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
+ """When ``quota_reserve_micros_override`` is None we fall back to
+ ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
+ Vision LLM calls take this path (token-priced models).
+ """
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+
+ captured_estimator_calls: list[dict[str, Any]] = []
+
+ def _fake_estimate(*, base_model, quota_reserve_tokens):
+ captured_estimator_calls.append(
+ {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
+ )
+ return 12_345
+
+ monkeypatch.setattr(
+ "app.services.billable_calls.estimate_call_reserve_micros",
+ _fake_estimate,
+ raising=False,
+ )
+
+ user_id = uuid4()
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ usage_type="vision_extraction",
+ ):
+ pass
+
+ assert captured_estimator_calls == [
+ {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
+ ]
+ assert spies["reserve"][0]["reserve_micros"] == 12_345
+
+
+# ---------------------------------------------------------------------------
+# Podcast / video-presentation usage_type coverage
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
+ """Free podcast configs must skip reserve/finalize but still emit a
+ ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
+ have full audit coverage of free-tier agent runs."""
+ from app.services.billable_calls import billable_call
+
+ spies = _patch_isolation_layer(
+ monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
+ )
+ user_id = uuid4()
+
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="free",
+ base_model="openrouter/some-free-model",
+ quota_reserve_micros_override=200_000,
+ usage_type="podcast_generation",
+ thread_id=99,
+ call_details={"podcast_id": 7, "title": "Test Podcast"},
+ ) as acc:
+ # Two transcript LLM calls aggregated into one accumulator.
+ acc.add(
+ model="openrouter/some-free-model",
+ prompt_tokens=1500,
+ completion_tokens=8000,
+ total_tokens=9500,
+ cost_micros=0,
+ call_kind="chat",
+ )
+
+ assert spies["reserve"] == []
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+
+ assert len(spies["record"]) == 1
+ row = spies["record"][0]
+ assert row["usage_type"] == "podcast_generation"
+ assert row["thread_id"] == 99
+ assert row["search_space_id"] == 42
+ assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+
+
+@pytest.mark.asyncio
+async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
+ """Premium video-presentation runs that hit a denied reservation must
+ raise ``QuotaInsufficientError`` *before* the graph runs and must not
+ emit an audit row (no work happened)."""
+ from app.services.billable_calls import (
+ QuotaInsufficientError,
+ billable_call,
+ )
+
+ spies = _patch_isolation_layer(
+ monkeypatch,
+ reserve_result=_FakeQuotaResult(
+ allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
+ ),
+ )
+ user_id = uuid4()
+
+ with pytest.raises(QuotaInsufficientError) as exc_info:
+ async with billable_call(
+ user_id=user_id,
+ search_space_id=42,
+ billing_tier="premium",
+ base_model="gpt-5.4",
+ quota_reserve_micros_override=1_000_000,
+ usage_type="video_presentation_generation",
+ thread_id=99,
+ call_details={"video_presentation_id": 12, "title": "Test Video"},
+ ):
+ pytest.fail("body should not run when reserve is denied")
+
+ err = exc_info.value
+ assert err.usage_type == "video_presentation_generation"
+ assert err.remaining_micros == 500_000
+ assert spies["reserve"][0]["reserve_micros"] == 1_000_000
+ assert spies["finalize"] == []
+ assert spies["release"] == []
+ assert spies["record"] == []
diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
index 085740032..b635b4fe8 100644
--- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
+++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py
@@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names
+
+
+# ---------------------------------------------------------------------------
+# _generate_image_gen_configs / _generate_vision_llm_configs
+# ---------------------------------------------------------------------------
+
+
+def test_generate_image_gen_configs_filters_by_image_output():
+ """Only models with ``output_modalities`` containing ``image`` are emitted.
+ Tool-calling and context filters are intentionally NOT applied — image
+ generation has nothing to do with tool calls and context windows.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_image_gen_configs,
+ )
+
+ raw = [
+ # Pure image-gen model (small context, no tools — should still emit).
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {"output_modalities": ["image"]},
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ },
+ # Multi-modal: text+image output (should still emit).
+ {
+ "id": "google/gemini-2.5-flash-image",
+ "architecture": {"output_modalities": ["text", "image"]},
+ "context_length": 1_000_000,
+ "pricing": {"prompt": "0.000001", "completion": "0.000004"},
+ },
+ # Pure text model — must NOT emit.
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {"output_modalities": ["text"]},
+ "context_length": 128_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ ]
+
+ cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
+ model_names = {c["model_name"] for c in cfgs}
+ assert "openai/gpt-image-1" in model_names
+ assert "google/gemini-2.5-flash-image" in model_names
+ assert "openai/gpt-4o" not in model_names
+
+ # Each config must carry ``billing_tier`` for routing in image_generation_routes.
+ for c in cfgs:
+ assert c["billing_tier"] in {"free", "premium"}
+ assert c["provider"] == "OPENROUTER"
+ assert c[_OPENROUTER_DYNAMIC_MARKER] is True
+
+
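+# For orientation: the modality predicates the two generators are assumed to
+# apply, sketched under our own names (the real helpers live inside
+# openrouter_integration_service and may differ in detail).
+def _sketch_emits_image_gen(model: dict) -> bool:
+    """True when the model can output images (image-gen catalogue)."""
+    return "image" in model.get("architecture", {}).get("output_modalities", [])
+
+
+def _sketch_emits_vision_llm(model: dict) -> bool:
+    """True when the model accepts image input AND emits text (vision LLM)."""
+    arch = model.get("architecture", {})
+    return "image" in arch.get("input_modalities", []) and "text" in arch.get(
+        "output_modalities", []
+    )
+
+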
+def test_generate_image_gen_configs_assigns_image_id_offset():
+ """Image configs use a different id_offset (-20000) so their negative
+ IDs don't collide with chat configs (-10000) or vision configs (-30000).
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_image_gen_configs,
+ )
+
+ raw = [
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {"output_modalities": ["image"]},
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ }
+ ]
+ # Don't pass image_id_offset → use the module default (-20000).
+ cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
+    # IDs must land in the image band: (-30_000, -20_000].
+    assert all(c["id"] <= -20_000 for c in cfgs)
+    assert all(c["id"] > -30_000 for c in cfgs)
+
+
+def test_generate_vision_llm_configs_filters_by_image_input_text_output():
+ """Vision LLMs must accept image input AND emit text — pure image-gen
+ (no text out) and text-only (no image in) models are excluded.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_vision_llm_configs,
+ )
+
+ raw = [
+ # GPT-4o: vision LLM (image in, text out) — must emit.
+ {
+ "id": "openai/gpt-4o",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "context_length": 128_000,
+ "pricing": {"prompt": "0.000005", "completion": "0.000015"},
+ },
+ # Pure image generator — image *output*, no text out. Must NOT emit.
+ {
+ "id": "openai/gpt-image-1",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["image"],
+ },
+ "context_length": 4_000,
+ "pricing": {"prompt": "0", "completion": "0"},
+ },
+ # Pure text model (no image in). Must NOT emit.
+ {
+ "id": "anthropic/claude-3-haiku",
+ "architecture": {
+ "input_modalities": ["text"],
+ "output_modalities": ["text"],
+ },
+ "context_length": 200_000,
+ "pricing": {"prompt": "0.000001", "completion": "0.000005"},
+ },
+ ]
+
+ cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
+ names = {c["model_name"] for c in cfgs}
+ assert names == {"openai/gpt-4o"}
+
+ cfg = cfgs[0]
+ assert cfg["billing_tier"] == "premium"
+ # Pricing carried inline so pricing_registration can register vision
+ # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
+ # is cleared.
+ assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
+ assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
+ assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
+
+
+def test_generate_vision_llm_configs_drops_chat_only_filters():
+ """A small-context vision model that doesn't advertise tool calling is
+ still a valid vision LLM for "describe this image" prompts. The chat
+ filters (``supports_tool_calling``, ``has_sufficient_context``) must
+ NOT be applied to vision emission.
+ """
+ from app.services.openrouter_integration_service import (
+ _generate_vision_llm_configs,
+ )
+
+ raw = [
+ {
+ "id": "tiny/vision-mini",
+ "architecture": {
+ "input_modalities": ["text", "image"],
+ "output_modalities": ["text"],
+ },
+ "supported_parameters": [], # no tools
+ "context_length": 4_000, # well below MIN_CONTEXT_LENGTH
+ "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
+ }
+ ]
+
+ cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
+ assert len(cfgs) == 1
+ assert cfgs[0]["model_name"] == "tiny/vision-mini"
diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py
new file mode 100644
index 000000000..e97250ff2
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py
@@ -0,0 +1,447 @@
+"""Pricing registration unit tests.
+
+The pricing-registration module is what makes ``response_cost`` populate
+correctly for OpenRouter dynamic models and operator-defined Azure
+deployments — both of which LiteLLM doesn't natively know about. The tests
+exercise:
+
+* The alias generators emit every shape that LiteLLM's cost-callback might
+ use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
+ ``provider/base_model``, ``provider/model_name``, plus the special
+ ``azure_openai`` → ``azure`` normalisation).
+* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
+ with the right alias set and pricing values per provider.
+* Configs without a resolvable pair of cost values are skipped — never
+ registered as zero, since that would override pricing LiteLLM might
+ already know natively.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+pytestmark = pytest.mark.unit
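+
+
+# For orientation: the payload shape these tests assert against, sketched
+# under our own name (an assumption about the module under test, not its
+# actual code; the real function may batch or order calls differently).
+def _sketch_register_payload(
+    aliases: list[str],
+    input_cost: float,
+    output_cost: float,
+    provider: str,
+) -> dict[str, dict[str, Any]]:
+    """One ``litellm.register_model``-style payload covering ``aliases``."""
+    return {
+        alias: {
+            "input_cost_per_token": input_cost,
+            "output_cost_per_token": output_cost,
+            "litellm_provider": provider,
+            "mode": "chat",
+        }
+        for alias in aliases
+    }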
+
+
+# ---------------------------------------------------------------------------
+# Alias generators
+# ---------------------------------------------------------------------------
+
+
+def test_openrouter_alias_set_includes_prefixed_and_bare():
+ from app.services.pricing_registration import _alias_set_for_openrouter
+
+ aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
+ assert aliases == [
+ "openrouter/anthropic/claude-3-5-sonnet",
+ "anthropic/claude-3-5-sonnet",
+ ]
+
+
+def test_openrouter_alias_set_dedupes():
+ """If the model id is already prefixed with ``openrouter/``, the alias
+ set must not contain duplicates that would re-register the same key
+ twice.
+ """
+ from app.services.pricing_registration import _alias_set_for_openrouter
+
+ aliases = _alias_set_for_openrouter("openrouter/foo")
+ # The bare and prefixed variants compute to the same string here, so we
+ # at minimum require uniqueness.
+ assert len(aliases) == len(set(aliases))
+
+
+def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
+ """``azure_openai`` (our YAML provider slug) must register under
+ ``azure/`` so the LiteLLM Router's deployment-resolution path
+ (which uses provider ``azure``) finds the pricing too.
+ """
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="AZURE_OPENAI",
+ model_name="gpt-5.4",
+ base_model="gpt-5.4",
+ )
+ assert "gpt-5.4" in aliases
+ assert "azure_openai/gpt-5.4" in aliases
+ assert "azure/gpt-5.4" in aliases
+
+
+def test_yaml_alias_set_distinguishes_model_name_and_base_model():
+ """When ``model_name`` differs from ``base_model`` (operator labelled a
+ deployment), both must appear in the alias set since either may surface
+ in callbacks depending on the call path.
+ """
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="OPENAI",
+ model_name="my-deployment-label",
+ base_model="gpt-4o",
+ )
+ assert "gpt-4o" in aliases
+ assert "openai/gpt-4o" in aliases
+ assert "my-deployment-label" in aliases
+ assert "openai/my-deployment-label" in aliases
+
+
+def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
+ from app.services.pricing_registration import _alias_set_for_yaml
+
+ aliases = _alias_set_for_yaml(
+ provider="",
+ model_name="foo",
+ base_model="bar",
+ )
+ assert "bar" in aliases
+ assert "foo" in aliases
+ assert all("/" not in a for a in aliases)
+
+
+# ---------------------------------------------------------------------------
+# register_pricing_from_global_configs
+# ---------------------------------------------------------------------------
+
+
+class _RegistrationSpy:
+ """Captures the dicts passed to ``litellm.register_model``.
+
+ Many calls may go through; we just record them all and let tests assert
+ against the union.
+ """
+
+ def __init__(self) -> None:
+ self.calls: list[dict[str, Any]] = []
+
+ def __call__(self, payload: dict[str, Any]) -> None:
+ self.calls.append(payload)
+
+ @property
+ def all_keys(self) -> set[str]:
+ keys: set[str] = set()
+ for payload in self.calls:
+ keys.update(payload.keys())
+ return keys
+
+
+def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
+ spy = _RegistrationSpy()
+ monkeypatch.setattr(
+ "app.services.pricing_registration.litellm.register_model",
+ spy,
+ raising=False,
+ )
+ return spy
+
+
+def _patch_openrouter_pricing(
+ monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
+) -> None:
+ """Pretend the OpenRouter integration is initialised with ``mapping``."""
+
+ class _Stub:
+ def get_raw_pricing(self) -> dict[str, dict[str, str]]:
+ return mapping
+
+ class _StubService:
+ @classmethod
+ def is_initialized(cls) -> bool:
+ return True
+
+ @classmethod
+ def get_instance(cls) -> _Stub:
+ return _Stub()
+
+ monkeypatch.setattr(
+ "app.services.openrouter_integration_service.OpenRouterIntegrationService",
+ _StubService,
+ raising=False,
+ )
+
+
+def test_openrouter_models_register_under_aliases(monkeypatch):
+ """An OpenRouter config whose ``model_name`` is in the cached raw
+ pricing map is registered under both ``openrouter/X`` and bare ``X``.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {
+ "anthropic/claude-3-5-sonnet": {
+ "prompt": "0.000003",
+ "completion": "0.000015",
+ }
+ },
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
+ assert "anthropic/claude-3-5-sonnet" in spy.all_keys
+ # Costs are float-converted from the raw OpenRouter strings.
+ payload = spy.calls[0]
+ assert payload["openrouter/anthropic/claude-3-5-sonnet"][
+ "input_cost_per_token"
+ ] == pytest.approx(3e-6)
+ assert payload["openrouter/anthropic/claude-3-5-sonnet"][
+ "output_cost_per_token"
+ ] == pytest.approx(15e-6)
+ assert (
+ payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
+ == "openrouter"
+ )
+
+
+def test_yaml_override_registers_under_alias_set(monkeypatch):
+ """Operator-declared ``input_cost_per_token`` /
+ ``output_cost_per_token`` on a YAML config registers under every
+ alias the YAML alias generator produces — including the ``azure/``
+ normalisation for ``azure_openai`` providers.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "AZURE_OPENAI",
+ "model_name": "gpt-5.4",
+ "litellm_params": {
+ "base_model": "gpt-5.4",
+ "input_cost_per_token": 2e-6,
+ "output_cost_per_token": 8e-6,
+ },
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ keys = spy.all_keys
+ assert "gpt-5.4" in keys
+ assert "azure_openai/gpt-5.4" in keys
+ assert "azure/gpt-5.4" in keys
+
+ payload = spy.calls[0]
+ entry = payload["gpt-5.4"]
+ assert entry["input_cost_per_token"] == pytest.approx(2e-6)
+ assert entry["output_cost_per_token"] == pytest.approx(8e-6)
+ assert entry["litellm_provider"] == "azure"
+
+
+def test_no_override_means_no_registration(monkeypatch):
+ """A YAML config that *omits* both pricing fields must NOT be registered
+ — registering as zero would override LiteLLM's native pricing for the
+ ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
+ bill drop to $0. Fail-safe is "skip and warn", not "register zero".
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENAI",
+ "model_name": "gpt-4o",
+ "litellm_params": {"base_model": "gpt-4o"},
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert spy.calls == []
+
+
+def test_openrouter_skipped_when_pricing_missing(monkeypatch):
+ """If the OpenRouter raw-pricing cache doesn't carry an entry for a
+ configured model (network blip during refresh, model added later, etc.),
+ we skip it rather than registering zero pricing.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert spy.calls == []
+
+
+def test_register_continues_after_individual_failure(monkeypatch, caplog):
+ """A single bad ``register_model`` call (e.g. raising LiteLLM error)
+ must not abort registration of the remaining configs.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
+ successful_calls: list[dict[str, Any]] = []
+
+ def _maybe_fail(payload: dict[str, Any]) -> None:
+ if any(k in failing_keys for k in payload):
+ raise RuntimeError("boom")
+ successful_calls.append(payload)
+
+ monkeypatch.setattr(
+ "app.services.pricing_registration.litellm.register_model",
+ _maybe_fail,
+ raising=False,
+ )
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {
+ "anthropic/claude-3-5-sonnet": {
+ "prompt": "0.000003",
+ "completion": "0.000015",
+ }
+ },
+ )
+
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_LLM_CONFIGS",
+ [
+ {
+ "id": 1,
+ "provider": "OPENROUTER",
+ "model_name": "anthropic/claude-3-5-sonnet",
+ },
+ {
+ "id": 2,
+ "provider": "OPENAI",
+ "model_name": "custom-deployment",
+ "litellm_params": {
+ "base_model": "custom-deployment",
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 2e-6,
+ },
+ },
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ # The good config still registered.
+ assert any("custom-deployment" in payload for payload in successful_calls)
+
+
+def test_vision_configs_registered_with_chat_shape(monkeypatch):
+ """``register_pricing_from_global_configs`` walks
+ ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
+ calls (during indexing) bill correctly. Vision configs use the same
+ chat-shape token prices, but image-gen pricing is intentionally NOT
+ registered here (handled via ``response_cost`` in LiteLLM).
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(
+ monkeypatch,
+ {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
+ )
+
+ # No chat configs — only vision. Proves the vision walk is a separate
+ # iteration, not piggy-backed on the chat list.
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_VISION_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "openai/gpt-4o",
+ "billing_tier": "premium",
+ "input_cost_per_token": 5e-6,
+ "output_cost_per_token": 15e-6,
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/openai/gpt-4o" in spy.all_keys
+ payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
+ assert payload_value["mode"] == "chat"
+ assert payload_value["litellm_provider"] == "openrouter"
+ assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
+ assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
+
+
+def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
+ """If the OpenRouter pricing cache misses a vision model (different
+ catalogue surface), the vision walk falls back to inline
+ ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
+ """
+ from app.config import config
+ from app.services.pricing_registration import register_pricing_from_global_configs
+
+ spy = _patch_register(monkeypatch)
+ _patch_openrouter_pricing(monkeypatch, {})
+
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
+ monkeypatch.setattr(
+ config,
+ "GLOBAL_VISION_LLM_CONFIGS",
+ [
+ {
+ "id": -1,
+ "provider": "OPENROUTER",
+ "model_name": "google/gemini-2.5-flash",
+ "billing_tier": "premium",
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 4e-6,
+ }
+ ],
+ )
+
+ register_pricing_from_global_configs()
+
+ assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
new file mode 100644
index 000000000..9e35b6f9c
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py
@@ -0,0 +1,157 @@
+"""Unit tests for ``QuotaCheckedVisionLLM``.
+
+Validates that:
+
+* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
+ enforcement) and forwards the inner LLM's response on success.
+* The wrapper proxies non-overridden attributes to the inner LLM
+ (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
+ still work without quota gating (they're not used in indexing today).
+* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
+ bubbles it up — the ETL pipeline catches that and falls back to OCR.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
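+
+
+# A minimal sketch of the wrapper shape these tests rely on (an assumption,
+# not the real class): ``ainvoke`` runs inside ``billable_call`` while
+# ``__getattr__`` proxies everything else to the inner LLM untouched.
+class _SketchQuotaCheckedVisionLLM:
+    def __init__(self, inner: Any, **billing_kwargs: Any) -> None:
+        self._inner = inner
+        self._billing_kwargs = billing_kwargs
+
+    async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+        from app.services.billable_calls import billable_call
+
+        async with billable_call(
+            usage_type="vision_extraction", **self._billing_kwargs
+        ):
+            return await self._inner.ainvoke(input, *args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        return getattr(self._inner, name)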
+
+
+class _FakeInnerLLM:
+ """Stand-in for ``langchain_litellm.ChatLiteLLM``."""
+
+ def __init__(self, response: Any = "OCR'd content") -> None:
+ self._response = response
+ self.ainvoke_calls: list[Any] = []
+
+ async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+ self.ainvoke_calls.append(input)
+ return self._response
+
+ def some_other_method(self, x: int) -> int:
+ return x * 2
+
+
+@contextlib.asynccontextmanager
+async def _passthrough_billable_call(**_kwargs):
+ """Stand-in for billable_call that always allows the call to run."""
+
+ class _Acc:
+ total_cost_micros = 0
+ total_prompt_tokens = 0
+ total_completion_tokens = 0
+ grand_total = 0
+ calls: list[Any] = []
+
+ def per_message_summary(self) -> dict[str, dict[str, int]]:
+ return {}
+
+ yield _Acc()
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_routes_through_billable_call(monkeypatch):
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ captured_kwargs: list[dict[str, Any]] = []
+
+ @contextlib.asynccontextmanager
+ async def _spy_billable_call(**kwargs):
+ captured_kwargs.append(kwargs)
+ async with _passthrough_billable_call() as acc:
+ yield acc
+
+ monkeypatch.setattr(
+ "app.services.quota_checked_vision_llm.billable_call",
+ _spy_billable_call,
+ raising=False,
+ )
+
+ inner = _FakeInnerLLM(response="A red apple on a white table")
+ user_id = uuid4()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=user_id,
+ search_space_id=99,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ result = await wrapper.ainvoke([{"text": "what is this?"}])
+ assert result == "A red apple on a white table"
+ assert len(inner.ainvoke_calls) == 1
+ assert len(captured_kwargs) == 1
+ bc_kwargs = captured_kwargs[0]
+ assert bc_kwargs["user_id"] == user_id
+ assert bc_kwargs["search_space_id"] == 99
+ assert bc_kwargs["billing_tier"] == "premium"
+ assert bc_kwargs["base_model"] == "openai/gpt-4o"
+ assert bc_kwargs["quota_reserve_tokens"] == 4000
+ assert bc_kwargs["usage_type"] == "vision_extraction"
+
+
+@pytest.mark.asyncio
+async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
+ from app.services.billable_calls import QuotaInsufficientError
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ @contextlib.asynccontextmanager
+ async def _denying_billable_call(**_kwargs):
+ raise QuotaInsufficientError(
+ usage_type="vision_extraction",
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+        yield  # unreachable; the yield makes this an async generator
+
+ monkeypatch.setattr(
+ "app.services.quota_checked_vision_llm.billable_call",
+ _denying_billable_call,
+ raising=False,
+ )
+
+ inner = _FakeInnerLLM()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=uuid4(),
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ with pytest.raises(QuotaInsufficientError):
+ await wrapper.ainvoke([{"text": "x"}])
+
+ # Inner LLM never ran on a denied reservation.
+ assert inner.ainvoke_calls == []
+
+
+@pytest.mark.asyncio
+async def test_proxies_non_overridden_attributes_to_inner():
+ """``__getattr__`` forwards anything not on the proxy itself, so any
+ method we didn't explicitly override (``invoke``, ``astream``,
+ ``with_structured_output``, etc.) still works — just without quota
+ gating, which is fine because the indexer only ever calls ainvoke.
+ """
+ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
+
+ inner = _FakeInnerLLM()
+ wrapper = QuotaCheckedVisionLLM(
+ inner,
+ user_id=uuid4(),
+ search_space_id=1,
+ billing_tier="premium",
+ base_model="openai/gpt-4o",
+ quota_reserve_tokens=4000,
+ )
+
+ # ``some_other_method`` is on the inner only.
+ assert wrapper.some_other_method(7) == 14
diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
new file mode 100644
index 000000000..63681828d
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py
@@ -0,0 +1,515 @@
+"""Cost-based premium quota unit tests.
+
+Covers the USD-micro behaviour added in migration 140:
+
+* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
+ calls in a turn — used as the debit amount when ``agent_config.is_premium``
+ is true, regardless of which underlying model produced each call. This
+ preserves the prior "premium turn → all calls in turn count" rule from the
+ token-based system.
+* ``estimate_call_reserve_micros`` scales linearly with model pricing,
+ clamps to a sane floor when pricing is unknown, and respects the
+ ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
+ can't lock the whole balance on one call.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+pytestmark = pytest.mark.unit
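+
+
+# The arithmetic the estimator tests below pin down, as a standalone sketch
+# (an assumption about the helper's internals, not its actual code): reserve
+# the worst case of every reserved token billed at input+output price, then
+# clamp between a small floor and the operator ceiling.
+def _sketch_estimate_reserve_micros(
+    tokens: int,
+    input_cost_per_token: float,
+    output_cost_per_token: float,
+    floor_micros: int = 100,
+    ceiling_micros: int = 1_000_000,
+) -> int:
+    usd = tokens * (input_cost_per_token + output_cost_per_token)
+    micros = int(usd * 1_000_000)  # 1 USD == 1_000_000 micro-USD
+    return max(floor_micros, min(micros, ceiling_micros))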
+
+
+# ---------------------------------------------------------------------------
+# TurnTokenAccumulator — premium-turn debit semantics
+# ---------------------------------------------------------------------------
+
+
+def test_total_cost_micros_sums_premium_and_free_calls():
+ """A premium turn that also called a free sub-agent debits the union.
+
+ The plan deliberately preserved the existing "premium turn → all calls
+ count" behaviour because per-call premium filtering relied on
+ ``LLMRouterService._premium_model_strings`` which only covers router-pool
+ deployments. ``total_cost_micros`` therefore must include free-model
+ calls (whose ``cost_micros`` is typically ``0``) as well as the premium
+ call's actual provider cost.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ # Premium model (e.g. claude-opus): non-zero cost.
+ acc.add(
+ model="anthropic/claude-3-5-sonnet",
+ prompt_tokens=1200,
+ completion_tokens=400,
+ total_tokens=1600,
+ cost_micros=12_345,
+ )
+ # Free sub-agent (e.g. title-gen on a free model): zero cost.
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=120,
+ completion_tokens=20,
+ total_tokens=140,
+ cost_micros=0,
+ )
+ # A second premium-priced call within the same turn.
+ acc.add(
+ model="anthropic/claude-3-5-sonnet",
+ prompt_tokens=800,
+ completion_tokens=200,
+ total_tokens=1000,
+ cost_micros=7_500,
+ )
+
+ assert acc.total_cost_micros == 12_345 + 0 + 7_500
+ # Token totals stay correct so the FE display path still works.
+ assert acc.grand_total == 1600 + 140 + 1000
+
+
+def test_total_cost_micros_zero_when_no_calls():
+ """An empty accumulator must report zero cost (no division-by-zero, no None)."""
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ assert acc.total_cost_micros == 0
+ assert acc.grand_total == 0
+
+
+def test_per_message_summary_groups_cost_by_model():
+ """``per_message_summary`` must accumulate ``cost_micros`` per model so the
+ SSE ``model_breakdown`` payload reports actual USD spend per provider.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="claude-3-5-sonnet",
+ prompt_tokens=100,
+ completion_tokens=50,
+ total_tokens=150,
+ cost_micros=4_000,
+ )
+ acc.add(
+ model="claude-3-5-sonnet",
+ prompt_tokens=200,
+ completion_tokens=100,
+ total_tokens=300,
+ cost_micros=8_000,
+ )
+ acc.add(
+ model="gpt-4o-mini",
+ prompt_tokens=50,
+ completion_tokens=10,
+ total_tokens=60,
+ cost_micros=200,
+ )
+
+ summary = acc.per_message_summary()
+ assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
+ assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
+ assert summary["gpt-4o-mini"]["cost_micros"] == 200
+
+
+def test_serialized_calls_includes_cost_micros():
+ """``serialized_calls`` is what flows into the SSE ``call_details``
+ payload; cost_micros must be present on each entry so the FE message-info
+ dropdown can render per-call USD.
+ """
+ from app.services.token_tracking_service import TurnTokenAccumulator
+
+ acc = TurnTokenAccumulator()
+ acc.add(
+ model="m",
+ prompt_tokens=1,
+ completion_tokens=1,
+ total_tokens=2,
+ cost_micros=42,
+ )
+ serialized = acc.serialized_calls()
+ assert serialized == [
+ {
+ "model": "m",
+ "prompt_tokens": 1,
+ "completion_tokens": 1,
+ "total_tokens": 2,
+ "cost_micros": 42,
+ "call_kind": "chat",
+ }
+ ]
+
+
+# ---------------------------------------------------------------------------
+# estimate_call_reserve_micros — sizing and clamping
+# ---------------------------------------------------------------------------
+
+
+def test_reserve_returns_floor_when_model_unknown(monkeypatch):
+ """If LiteLLM doesn't know the model, ``get_model_info`` raises and the
+ helper falls back to the 100-micro floor — small enough that a user with
+ $0.0001 left can still send a tiny request, but non-zero so we still gate
+ against an empty balance.
+ """
+ import litellm
+
+ from app.services import token_quota_service
+
+ def _raise(_name):
+ raise KeyError("unknown")
+
+ monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="nonexistent-model",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
+ assert micros == 100
+
+
+def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
+ """LiteLLM may *return* a model with both cost-per-token fields at 0
+ (pricing not yet registered). The helper must not multiply 0 x tokens
+ and end up reserving 0 — it must clamp to the floor.
+ """
+ import litellm
+
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
+ raising=False,
+ )
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="some-pending-model",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
+
+
+def test_reserve_scales_with_model_cost(monkeypatch):
+ """Claude-Opus-priced model with 4000 reserve_tokens reserves
+ ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
+ some small artificial cap — that was the bug the plan called out.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 15e-6,
+ "output_cost_per_token": 75e-6,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="claude-3-opus",
+ quota_reserve_tokens=4000,
+ )
+ # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
+ assert micros == 360_000
+
+
+def test_reserve_clamps_to_max_ceiling(monkeypatch):
+ """A misconfigured "$1000 / M" model with 4000 reserve_tokens would
+ nominally compute to $4 = 4_000_000 micros. The ceiling
+ ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
+ can't lock the user's whole balance on one call.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 1e-3,
+ "output_cost_per_token": 0,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ micros = token_quota_service.estimate_call_reserve_micros(
+ base_model="oops-misconfigured",
+ quota_reserve_tokens=4000,
+ )
+ assert micros == 1_000_000
+
+
+def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
+ """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
+ zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
+ so anonymous-style configs still reserve the operator-tunable default.
+ """
+ import litellm
+
+ from app.config import config
+ from app.services import token_quota_service
+
+ monkeypatch.setattr(
+ litellm,
+ "get_model_info",
+ lambda _name: {
+ "input_cost_per_token": 1e-6,
+ "output_cost_per_token": 1e-6,
+ },
+ raising=False,
+ )
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
+ monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+ # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
+ assert (
+ token_quota_service.estimate_call_reserve_micros(
+ base_model="cheap", quota_reserve_tokens=None
+ )
+ == 4000
+ )
+ assert (
+ token_quota_service.estimate_call_reserve_micros(
+ base_model="cheap", quota_reserve_tokens=0
+ )
+ == 4000
+ )
+
+
+# ---------------------------------------------------------------------------
+# TokenTrackingCallback — image vs chat usage shape
+# ---------------------------------------------------------------------------
+
+
+class _FakeImageUsage:
+ """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
+
+ def __init__(
+ self,
+ input_tokens: int = 0,
+ output_tokens: int = 0,
+ total_tokens: int | None = None,
+ ) -> None:
+ self.input_tokens = input_tokens
+ self.output_tokens = output_tokens
+ if total_tokens is not None:
+ self.total_tokens = total_tokens
+
+
+class _FakeImageResponse:
+ """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
+ ``type(...).__name__`` probe routes to the image branch.
+ """
+
+ def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
+ self.usage = usage
+ if response_cost is not None:
+ self._hidden_params = {"response_cost": response_cost}
+
+
+# Re-tag the helper class as ``ImageResponse`` so the callback's
+# ``type(...).__name__`` probe routes it to the image branch. The module-level
+# name stays ``_FakeImageResponse`` so test code and tracebacks make it
+# obvious this is a stand-in rather than LiteLLM's real class.
+_FakeImageResponse.__name__ = "ImageResponse"
+
+
+class _FakeChatUsage:
+ def __init__(self, prompt: int, completion: int):
+ self.prompt_tokens = prompt
+ self.completion_tokens = completion
+ self.total_tokens = prompt + completion
+
+
+class _FakeChatResponse:
+ def __init__(self, usage: _FakeChatUsage):
+ self.usage = usage
+
+
+@pytest.mark.asyncio
+async def test_callback_reads_image_usage_input_output_tokens():
+ """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
+ for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
+ prompt_tokens/completion_tokens which is the chat shape.
+ """
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeImageResponse(
+ usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
+ response_cost=0.04, # $0.04 per image
+ )
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+ assert len(acc.calls) == 1
+ call = acc.calls[0]
+ assert call.prompt_tokens == 42
+ assert call.completion_tokens == 8
+ assert call.total_tokens == 50
+ # 0.04 USD = 40_000 micros
+ assert call.cost_micros == 40_000
+ assert call.call_kind == "image_generation"
+
+
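+# How the callback's shape dispatch is assumed to work (a sketch under our
+# own name, grounded in the ``type(...).__name__`` probe noted above):
+def _sketch_usage_tokens(response_obj) -> tuple[int, int]:
+    usage = response_obj.usage
+    if type(response_obj).__name__ == "ImageResponse":
+        return usage.input_tokens, usage.output_tokens  # ImageUsage shape
+    return usage.prompt_tokens, usage.completion_tokens  # chat Usage shape
+
+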
+@pytest.mark.asyncio
+async def test_callback_chat_path_unchanged():
+ """Chat responses must still read prompt_tokens/completion_tokens."""
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={
+ "model": "openrouter/anthropic/claude-3-5-sonnet",
+ "response_cost": 0.0036,
+ },
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+ assert len(acc.calls) == 1
+ call = acc.calls[0]
+ assert call.prompt_tokens == 120
+ assert call.completion_tokens == 30
+ assert call.total_tokens == 150
+ assert call.cost_micros == 3_600
+ assert call.call_kind == "chat"
+
+
+@pytest.mark.asyncio
+async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
+ """When OpenRouter omits ``usage.cost`` LiteLLM's
+ ``default_image_cost_calculator`` raises. The defensive image branch in
+ ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
+ chat-shaped and would raise too) — it returns 0 with a WARNING log.
+ """
+ import litellm
+
+ from app.services.token_tracking_service import (
+ TokenTrackingCallback,
+ scoped_turn,
+ )
+
+ # Force completion_cost to raise the same way OpenRouter image-gen fails.
+ def _boom(*_args, **_kwargs):
+ raise ValueError("model_cost: missing entry for openrouter image model")
+
+ monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
+
+ # And make sure cost_per_token is NEVER called for the image path —
+ # if it were, our ``is_image=True`` branch is broken.
+ cost_per_token_calls: list = []
+
+ def _record_cost_per_token(**kwargs):
+ cost_per_token_calls.append(kwargs)
+ return (0.0, 0.0)
+
+ monkeypatch.setattr(
+ litellm, "cost_per_token", _record_cost_per_token, raising=False
+ )
+
+ cb = TokenTrackingCallback()
+ response = _FakeImageResponse(
+ usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
+ )
+
+ async with scoped_turn() as acc:
+ await cb.async_log_success_event(
+ kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
+ response_obj=response,
+ start_time=None,
+ end_time=None,
+ )
+
+ assert len(acc.calls) == 1
+ assert acc.calls[0].cost_micros == 0
+ assert acc.calls[0].call_kind == "image_generation"
+ # The image branch must short-circuit before cost_per_token.
+ assert cost_per_token_calls == []
+
+
+# ---------------------------------------------------------------------------
+# scoped_turn — ContextVar reset semantics (issue B)
+# ---------------------------------------------------------------------------
+
+
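+
+# The ContextVar discipline these tests pin down, sketched as the async
+# generator that ``contextlib.asynccontextmanager`` would wrap (assumed
+# internals, not the service's actual code): install a fresh accumulator,
+# then always reset the ContextVar to its previous value on exit.
+async def _sketch_scoped_turn_gen():
+    from app.services.token_tracking_service import (
+        TurnTokenAccumulator,
+        _turn_accumulator,
+    )
+
+    acc = TurnTokenAccumulator()
+    token = _turn_accumulator.set(acc)  # remembers the outer accumulator
+    try:
+        yield acc
+    finally:
+        _turn_accumulator.reset(token)  # restores the outer value (or None)
+
+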
+@pytest.mark.asyncio
+async def test_scoped_turn_restores_outer_accumulator():
+ """``scoped_turn`` must restore the previous ContextVar value on exit
+ so a per-call wrapper inside an outer chat turn doesn't leak its
+ accumulator outward (which would cause double-debit at chat-turn exit).
+ """
+ from app.services.token_tracking_service import (
+ get_current_accumulator,
+ scoped_turn,
+ start_turn,
+ )
+
+ outer = start_turn()
+ assert get_current_accumulator() is outer
+
+ async with scoped_turn() as inner:
+ assert get_current_accumulator() is inner
+ assert inner is not outer
+ inner.add(
+ model="x",
+ prompt_tokens=1,
+ completion_tokens=1,
+ total_tokens=2,
+ cost_micros=5,
+ )
+
+ # After exit the outer accumulator is restored unchanged.
+ assert get_current_accumulator() is outer
+ assert outer.total_cost_micros == 0
+ assert len(outer.calls) == 0
+ # The inner accumulator captured the call but didn't bleed into outer.
+ assert inner.total_cost_micros == 5
+
+
+@pytest.mark.asyncio
+async def test_scoped_turn_resets_to_none_when_no_outer():
+ """Running ``scoped_turn`` outside any chat turn (e.g. a background
+ indexing job) must leave the ContextVar at ``None`` on exit so the
+ next *unrelated* request starts clean.
+ """
+ from app.services.token_tracking_service import (
+ _turn_accumulator,
+ get_current_accumulator,
+ scoped_turn,
+ )
+
+    # The ContextVar default is None in a fresh, isolated test context; we
+    # simulate "no outer" explicitly to be robust against test ordering.
+ token = _turn_accumulator.set(None)
+ try:
+ assert get_current_accumulator() is None
+ async with scoped_turn() as acc:
+ assert get_current_accumulator() is acc
+ assert get_current_accumulator() is None
+ finally:
+ _turn_accumulator.reset(token)
diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
new file mode 100644
index 000000000..38d6ba2ca
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py
@@ -0,0 +1,325 @@
+"""Unit tests for podcast Celery task billing integration.
+
+Validates ``_generate_content_podcast`` correctly wraps
+``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
+search-space owner's billing decision, and degrades cleanly when the
+resolver fails or premium credit is exhausted.
+
+Coverage:
+
+* Happy-path free config: resolver → ``billable_call`` enters with
+ ``usage_type='podcast_generation'`` and the configured reserve override,
+ graph runs, podcast row flips to ``READY``.
+* Happy-path premium config: same wiring with ``billing_tier='premium'``.
+* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
+ graph is *not* invoked, podcast row flips to ``FAILED``, return dict
+ carries ``reason='premium_quota_exhausted'``.
+* Resolver failure: ``ValueError`` from the resolver → podcast row flips
+ to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from types import SimpleNamespace
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
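+
+
+# The envelope shape exercised below, as a hedged sketch (our own names; the
+# task module's actual code differs in detail): resolve the owner's billing,
+# reserve before the graph runs, and let a denial short-circuit into FAILED.
+async def _sketch_billing_envelope(
+    resolver: Any, session: Any, search_space_id: int
+) -> dict[str, Any]:
+    from app.services.billable_calls import QuotaInsufficientError, billable_call
+
+    user_id, tier, base_model = await resolver(session, search_space_id)
+    try:
+        async with billable_call(
+            user_id=user_id,
+            search_space_id=search_space_id,
+            billing_tier=tier,
+            base_model=base_model,
+            usage_type="podcast_generation",
+        ):
+            return {"status": "ready"}  # podcaster_graph.ainvoke would run here
+    except QuotaInsufficientError:
+        return {"status": "failed", "reason": "premium_quota_exhausted"}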
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+ def filter(self, *_args, **_kwargs):
+ return self
+
+
+class _FakeSession:
+ def __init__(self, podcast):
+ self._podcast = podcast
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self._podcast)
+
+ async def commit(self):
+ self.commit_count += 1
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ return None
+
+
+class _FakeSessionMaker:
+ def __init__(self, session: _FakeSession):
+ self._session = session
+
+ def __call__(self):
+ return self._session
+
+
+def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
+    """Stand-in for a ``Podcast`` row. Tests import ``PodcastStatus`` lazily,
+    so this stand-in starts with ``status=None`` and lets the task flip it."""
+ return SimpleNamespace(
+ id=podcast_id,
+ title="Test Podcast",
+ thread_id=thread_id,
+ status=None,
+ podcast_transcript=None,
+ file_location=None,
+ )
+
+
+_CALL_LOG: list[dict[str, Any]] = []
+
+
+@contextlib.asynccontextmanager
+async def _ok_billable_call(**kwargs):
+    """Stand-in for ``billable_call`` that records its kwargs and yields a
+    no-op accumulator-shaped object."""
+    _CALL_LOG.append(kwargs)
+    yield SimpleNamespace()
+
+
+@contextlib.asynccontextmanager
+async def _denying_billable_call(**kwargs):
+ from app.services.billable_calls import QuotaInsufficientError
+
+ _CALL_LOG.append(kwargs)
+ raise QuotaInsufficientError(
+ usage_type=kwargs.get("usage_type", "?"),
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+    yield SimpleNamespace()  # pragma: no cover (unreachable; makes this a generator)
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_call_log():
+ _CALL_LOG.clear()
+ yield
+ _CALL_LOG.clear()
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
+ """Happy path: free billing tier still wraps the graph call so the
+ audit row is recorded. Verifies kwargs threading."""
+ from app.config import config as app_config
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=7, thread_id=99)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ assert search_space_id == 555
+ assert thread_id == 99
+ return user_id, "free", "openrouter/some-free-model"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {
+ "podcast_transcript": [
+ SimpleNamespace(speaker_id=0, dialog="Hi"),
+ SimpleNamespace(speaker_id=1, dialog="Hello"),
+ ],
+ "final_podcast_file_path": "/tmp/podcast.wav",
+ }
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=7,
+ source_content="hello world",
+ search_space_id=555,
+ user_prompt="make it short",
+ )
+
+ assert result["status"] == "ready"
+ assert result["podcast_id"] == 7
+ assert podcast.status == PodcastStatus.READY
+ assert podcast.file_location == "/tmp/podcast.wav"
+
+ assert len(_CALL_LOG) == 1
+ call = _CALL_LOG[0]
+ assert call["user_id"] == user_id
+ assert call["search_space_id"] == 555
+ assert call["billing_tier"] == "free"
+ assert call["base_model"] == "openrouter/some-free-model"
+ assert call["usage_type"] == "podcast_generation"
+ assert (
+ call["quota_reserve_micros_override"]
+ == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
+ )
+ assert call["thread_id"] == 99
+ assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_premium_tier(monkeypatch):
+ """Premium resolution flows through to ``billable_call`` so the
+ reserve/finalize path triggers."""
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast()
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return user_id, "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ await podcast_tasks._generate_content_podcast(
+ podcast_id=7,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert _CALL_LOG[0]["billing_tier"] == "premium"
+ assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
+ """When ``billable_call`` denies the reservation, the graph never
+ runs and the podcast row flips to FAILED with the documented reason
+ code."""
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=8)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
+ )
+ monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=8,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 8,
+ "reason": "premium_quota_exhausted",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+ assert graph_invoked == [] # Graph never ran on denied reservation.
+
+
+@pytest.mark.asyncio
+async def test_resolver_failure_marks_podcast_failed(monkeypatch):
+ """If the resolver raises (e.g. search-space deleted), the task fails
+ cleanly without invoking the graph."""
+ from app.db import PodcastStatus
+ from app.tasks.celery_tasks import podcast_tasks
+
+ podcast = _make_podcast(podcast_id=9)
+ session = _FakeSession(podcast)
+ monkeypatch.setattr(
+ podcast_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _failing_resolver(sess, search_space_id, *, thread_id=None):
+ raise ValueError("Search space 555 not found")
+
+ monkeypatch.setattr(
+ podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
+
+ result = await podcast_tasks._generate_content_podcast(
+ podcast_id=9,
+ source_content="hi",
+ search_space_id=555,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "podcast_id": 9,
+ "reason": "billing_resolution_failed",
+ }
+ assert podcast.status == PodcastStatus.FAILED
+ assert graph_invoked == []
diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
new file mode 100644
index 000000000..671f57ae4
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py
@@ -0,0 +1,330 @@
+"""Unit tests for video-presentation Celery task billing integration.
+
+Mirrors ``test_podcast_billing.py`` for the video-presentation task.
+Validates the same wrap-graph-in-billable_call pattern and ensures the
+larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
+threaded through.
+
+Coverage:
+
+* Free config: graph runs, ``billable_call`` invoked with the video
+ reserve override.
+* Premium config: same wiring with ``billing_tier='premium'``.
+* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
+* Resolver failure: row → FAILED with ``billing_resolution_failed``.
+"""
+
+from __future__ import annotations
+
+import contextlib
+from types import SimpleNamespace
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class _FakeExecResult:
+ def __init__(self, obj):
+ self._obj = obj
+
+ def scalars(self):
+ return self
+
+ def first(self):
+ return self._obj
+
+ def filter(self, *_args, **_kwargs):
+ return self
+
+
+class _FakeSession:
+ def __init__(self, video):
+ self._video = video
+ self.commit_count = 0
+
+ async def execute(self, _stmt):
+ return _FakeExecResult(self._video)
+
+ async def commit(self):
+ self.commit_count += 1
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ return None
+
+
+class _FakeSessionMaker:
+ def __init__(self, session: _FakeSession):
+ self._session = session
+
+ def __call__(self):
+ return self._session
+
+
+def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
+ return SimpleNamespace(
+ id=video_id,
+ title="Test Presentation",
+ thread_id=thread_id,
+ status=None,
+ slides=None,
+ scene_codes=None,
+ )
+
+
+_CALL_LOG: list[dict[str, Any]] = []
+
+
+@contextlib.asynccontextmanager
+async def _ok_billable_call(**kwargs):
+ _CALL_LOG.append(kwargs)
+ yield SimpleNamespace()
+
+
+@contextlib.asynccontextmanager
+async def _denying_billable_call(**kwargs):
+ from app.services.billable_calls import QuotaInsufficientError
+
+ _CALL_LOG.append(kwargs)
+ raise QuotaInsufficientError(
+ usage_type=kwargs.get("usage_type", "?"),
+ used_micros=5_000_000,
+ limit_micros=5_000_000,
+ remaining_micros=0,
+ )
+ yield SimpleNamespace() # pragma: no cover
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_call_log():
+ _CALL_LOG.clear()
+ yield
+ _CALL_LOG.clear()
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
+ from app.config import config as app_config
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=11, thread_id=99)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ assert search_space_id == 777
+ assert thread_id == 99
+ return user_id, "free", "openrouter/some-free-model"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=11,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result["status"] == "ready"
+ assert result["video_presentation_id"] == 11
+ assert video.status == VideoPresentationStatus.READY
+
+ assert len(_CALL_LOG) == 1
+ call = _CALL_LOG[0]
+ assert call["user_id"] == user_id
+ assert call["search_space_id"] == 777
+ assert call["billing_tier"] == "free"
+ assert call["base_model"] == "openrouter/some-free-model"
+ assert call["usage_type"] == "video_presentation_generation"
+ assert (
+ call["quota_reserve_micros_override"]
+ == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
+ )
+ assert call["thread_id"] == 99
+ assert call["call_details"] == {
+ "video_presentation_id": 11,
+ "title": "Test Presentation",
+ }
+
+
+@pytest.mark.asyncio
+async def test_billable_call_invoked_with_premium_tier(monkeypatch):
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video()
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ user_id = uuid4()
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return user_id, "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
+
+ async def _fake_graph_invoke(state, config):
+ return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=11,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert _CALL_LOG[0]["billing_tier"] == "premium"
+ assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=12)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _fake_resolver(sess, search_space_id, *, thread_id=None):
+ return uuid4(), "premium", "gpt-5.4"
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _fake_resolver,
+ )
+ monkeypatch.setattr(
+ video_presentation_tasks, "billable_call", _denying_billable_call
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=12,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 12,
+ "reason": "premium_quota_exhausted",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+ assert graph_invoked == []
+
+
+@pytest.mark.asyncio
+async def test_resolver_failure_marks_video_failed(monkeypatch):
+ from app.db import VideoPresentationStatus
+ from app.tasks.celery_tasks import video_presentation_tasks
+
+ video = _make_video(video_id=13)
+ session = _FakeSession(video)
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "get_celery_session_maker",
+ lambda: _FakeSessionMaker(session),
+ )
+
+ async def _failing_resolver(sess, search_space_id, *, thread_id=None):
+ raise ValueError("Search space 777 not found")
+
+ monkeypatch.setattr(
+ video_presentation_tasks,
+ "_resolve_agent_billing_for_search_space",
+ _failing_resolver,
+ )
+
+ graph_invoked = []
+
+ async def _fake_graph_invoke(state, config):
+ graph_invoked.append(True)
+ return {}
+
+ monkeypatch.setattr(
+ video_presentation_tasks.video_presentation_graph,
+ "ainvoke",
+ _fake_graph_invoke,
+ )
+
+ result = await video_presentation_tasks._generate_video_presentation(
+ video_presentation_id=13,
+ source_content="content",
+ search_space_id=777,
+ user_prompt=None,
+ )
+
+ assert result == {
+ "status": "failed",
+ "video_presentation_id": 13,
+ "reason": "billing_resolution_failed",
+ }
+ assert video.status == VideoPresentationStatus.FAILED
+ assert graph_invoked == []
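+
+
+# Flow exercised above (summary, not extra assertions): the task resolves
+# (user_id, billing_tier, base_model) for the search space, wraps the graph
+# run in billable_call with quota_reserve_micros_override set to
+# app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, and, when the
+# reservation is denied or the resolver raises, marks the row FAILED without
+# ever calling video_presentation_graph.ainvoke.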
diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx
index 8d9ed5cb1..3ddd5195f 100644
--- a/surfsense_web/app/(home)/free/page.tsx
+++ b/surfsense_web/app/(home)/free/page.tsx
@@ -127,7 +127,7 @@ const FAQ_ITEMS = [
{
question: "What happens after I use my free tokens?",
answer:
- "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.",
+ "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.",
},
{
question: "Is Claude AI available without login?",
@@ -329,7 +329,7 @@ export default async function FreeHubPage() {
Want More Features?
- Create a free SurfSense account to unlock 3 million tokens, document uploads with
+ Create a free SurfSense account to unlock $5 of premium credit, document uploads with
citations, team collaboration, and integrations with Slack, Google Drive, Notion, and
30+ more tools.
diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx
index 6ad9435bf..6f332be70 100644
--- a/surfsense_web/app/(home)/pricing/page.tsx
+++ b/surfsense_web/app/(home)/pricing/page.tsx
@@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav";
export const metadata: Metadata = {
title: "Pricing | SurfSense - Free AI Search Plans",
description:
- "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.",
+ "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.",
alternates: {
canonical: "https://surfsense.com/pricing",
},
diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
index 3017160e1..0c5662712 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx
@@ -8,7 +8,7 @@ import { cn } from "@/lib/utils";
const TABS = [
{ id: "pages", label: "Pages" },
- { id: "tokens", label: "Premium Tokens" },
+ { id: "tokens", label: "Premium Credit" },
] as const;
type TabId = (typeof TABS)[number]["id"];
diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx
index 2b7422f80..cf73b5eba 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx
@@ -28,6 +28,12 @@ type UnifiedPurchase = {
kind: PurchaseKind;
created_at: string;
status: PagePurchaseStatus;
+ /**
+ * Granted units. Interpretation depends on ``kind``:
+ * - ``"pages"`` — integer number of indexed pages.
+ * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00).
+ * The ``Granted`` column formats accordingly.
+ */
granted: number;
amount_total: number | null;
currency: string | null;
@@ -58,7 +64,7 @@ const KIND_META: Record<
iconClass: "text-sky-500",
},
tokens: {
- label: "Premium Tokens",
+ label: "Premium Credit",
icon: Coins,
iconClass: "text-amber-500",
},
@@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase {
kind: "tokens",
created_at: p.created_at,
status: p.status,
- granted: p.tokens_granted,
+ granted: p.credit_micros_granted,
amount_total: p.amount_total,
currency: p.currency,
};
}
+function formatGranted(p: UnifiedPurchase): string {
+ if (p.kind === "tokens") {
+ const dollars = p.granted / 1_000_000;
+ // Premium credit packs are always whole dollars at the moment, but
+ // future fractional grants (refunds, partial top-ups) shouldn't
+ // silently round to "$0".
+ if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`;
+ if (dollars >= 0.001) return `$${dollars.toFixed(3)} of credit`;
+ if (dollars > 0) return "<$0.001 of credit";
+ return "$0 of credit";
+ }
+ return p.granted.toLocaleString();
+}
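+// Illustrative expectations (hypothetical purchases, not real fixtures;
+// the pages rendering is locale-dependent, en-US shown):
+//   { kind: "tokens", granted: 5_000_000, ... } → "$5.00 of credit"
+//   { kind: "tokens", granted: 250_000, ... }   → "$0.250 of credit"
+//   { kind: "pages",  granted: 1000, ... }      → "1,000"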
+
export function PurchaseHistoryContent() {
const results = useQueries({
queries: [
@@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {
No purchases yet
- Your page and premium token purchases will appear here after checkout.
+ Your page and premium credit purchases will appear here after checkout.
);
@@ -177,7 +196,7 @@ export function PurchaseHistoryContent() {
- {p.granted.toLocaleString()}
+ {formatGranted(p)}
{formatAmount(p.amount_total, p.currency)}
diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts
index a59811324..4b6717440 100644
--- a/surfsense_web/atoms/user/user-query.atoms.ts
+++ b/surfsense_web/atoms/user/user-query.atoms.ts
@@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe();
export const currentUserAtom = atomWithQuery(() => {
return {
queryKey: USER_QUERY_KEY,
- // Live-changing numeric fields (pages_*, premium_tokens_*) are now
- // pushed via Zero (queries.user.me()), so /users/me only needs to
- // fire once per session for the static profile fields.
+ // Live-changing numeric fields (pages_*, premium_credit_micros_*)
+ // are now pushed via Zero (queries.user.me()), so /users/me only
+ // needs to fire once per session for the static profile fields.
staleTime: Infinity,
enabled: !!getBearerToken(),
queryFn: userQueryFn,
diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx
index 711bb2fe2..ffb0e4dc8 100644
--- a/surfsense_web/components/assistant-ui/assistant-message.tsx
+++ b/surfsense_web/components/assistant-ui/assistant-message.tsx
@@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string {
});
}
+/**
+ * Format provider USD cost (in micro-USD) for inline display next to a
+ * token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent
+ * costs so a real-but-tiny figure doesn't render as ``$0.000``.
+ */
+function formatTurnCost(micros: number): string {
+ const dollars = micros / 1_000_000;
+ if (dollars >= 1) return `$${dollars.toFixed(2)}`;
+ if (dollars >= 0.001) return `$${dollars.toFixed(3)}`;
+ if (dollars > 0) return "<$0.001";
+ return "$0";
+}
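+// Sample renderings (illustrative micro-USD inputs): 2_500_000 → "$2.50",
+// 12_345 → "$0.012", 500 → "<$0.001", 0 → "$0".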
+
const MessageInfoDropdown: FC = () => {
const messageId = useAuiState(({ message }) => message?.id);
const createdAt = useAuiState(({ message }) => message?.createdAt);
@@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => {
{models.length > 0 ? (
models.map(([model, counts]) => {
const { name, icon } = resolveModel(model);
+ const costMicros = counts.cost_micros;
return (
{
{counts.total_tokens.toLocaleString()} tokens
+ {costMicros && costMicros > 0
+ ? ` · ${formatTurnCost(costMicros)}`
+ : ""}
);
@@ -474,6 +491,9 @@ const MessageInfoDropdown: FC = () => {
>
{usage.total_tokens.toLocaleString()} tokens
+ {usage.cost_micros && usage.cost_micros > 0
+ ? ` · ${formatTurnCost(usage.cost_micros)}`
+ : ""}
)}
diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx
index b3f71ab21..dd80bcac3 100644
--- a/surfsense_web/components/assistant-ui/token-usage-context.tsx
+++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx
@@ -13,13 +13,30 @@ export interface TokenUsageData {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
+ /**
+ * Total provider USD cost for this assistant turn, in micro-USD
+ * (1_000_000 = $1.00). Populated from LiteLLM's response_cost on
+ * the backend. Optional because messages persisted before the
+ * cost-credits migration won't have it.
+ */
+ cost_micros?: number;
usage?: Record<
string,
- { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+ {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ cost_micros?: number;
+ }
>;
model_breakdown?: Record<
string,
- { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+ {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ cost_micros?: number;
+ }
>;
}
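+
+// Illustrative TokenUsageData payload for a single-model turn (hypothetical
+// numbers; the model key is illustrative). 4200 micro-USD is $0.0042 of
+// provider cost:
+//   {
+//     prompt_tokens: 1200,
+//     completion_tokens: 300,
+//     total_tokens: 1500,
+//     cost_micros: 4200,
+//     model_breakdown: {
+//       "openai/gpt-4o-mini": {
+//         prompt_tokens: 1200,
+//         completion_tokens: 300,
+//         total_tokens: 1500,
+//         cost_micros: 4200,
+//       },
+//     },
+//   }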
diff --git a/surfsense_web/components/free-chat/quota-warning-banner.tsx b/surfsense_web/components/free-chat/quota-warning-banner.tsx
index 3bfedf1b3..e013a64a8 100644
--- a/surfsense_web/components/free-chat/quota-warning-banner.tsx
+++ b/surfsense_web/components/free-chat/quota-warning-banner.tsx
@@ -40,7 +40,7 @@ export function QuotaWarningBanner({
You've used all {limit.toLocaleString()} free tokens. Create a free account to
- get 3 million tokens and access to all models.
+ get $5 of premium credit and access to all models.
Create an account
{" "}
- for 5M free tokens.
+ for $5 of premium credit.