diff --git a/docker/.env.example b/docker/.env.example index 95de0cf85..c2e87a619 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_BATCH_SIZE=100 -# Premium token purchases ($1 per 1M tokens for premium-tier models) +# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of +# credit; premium turns debit the actual per-call provider cost +# reported by LiteLLM, so cheap and expensive models bill proportionally) # STRIPE_TOKEN_BUYING_ENABLED=FALSE # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -# STRIPE_TOKENS_PER_UNIT=1000000 +# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000 # ------------------------------------------------------------------------------ # TTS & STT (Text-to-Speech / Speech-to-Text) @@ -315,9 +318,24 @@ STT_SERVICE=local/base # Pages limit per user for ETL (default: unlimited) # PAGES_LIMIT=500 -# Premium token quota per registered user (default: 5M) -# Only applies to models with billing_tier=premium in global_llm_config.yaml -# PREMIUM_TOKEN_LIMIT=5000000 +# Premium credit quota per registered user, in micro-USD (default: $5). +# Premium turns are debited at the actual per-call provider cost reported +# by LiteLLM. Only applies to models with billing_tier=premium. +# PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default). +# QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default). +# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation for the podcast Celery task ($0.20 default). +# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation for the video Celery task ($1.00 default). +# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — public users can chat without an account # Set TRUE to enable /free pages and anonymous chat API diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index a793f33d1..1b1478ae6 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000 # Set FALSE to disable new checkout session creation temporarily STRIPE_PAGE_BUYING_ENABLED=TRUE -# Premium token purchases via Stripe (for premium-tier model usage) -# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens) +# Premium credit purchases via Stripe (for premium-tier model usage). +# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit +# (default 1_000_000 = $1.00). Premium turns are billed at the actual +# per-call provider cost reported by LiteLLM. STRIPE_TOKEN_BUYING_ENABLED=FALSE STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... 
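To make the unit concrete: every balance, reservation, and debit below is stored as integer micro-USD, so the conversions are plain integer math. A minimal sketch (constants mirror the defaults above; the helper names are illustrative, not from the codebase):

```python
# Illustrative only: helper names are not from the codebase.
MICROS_PER_USD = 1_000_000  # STRIPE_CREDIT_MICROS_PER_UNIT default


def usd_to_micros(usd: float) -> int:
    """$1.00 -> 1_000_000 micro-USD (one default Stripe credit pack)."""
    return round(usd * MICROS_PER_USD)


def micros_to_usd(micros: int) -> float:
    return micros / MICROS_PER_USD


# A premium turn LiteLLM prices at $0.0234 debits 23_400 micros, so the
# default PREMIUM_CREDIT_MICROS_LIMIT=5000000 ($5.00) covers ~213 such
# turns, while a $0.0009 turn on a cheap model debits only 900.
assert usd_to_micros(0.0234) == 23_400
assert micros_to_usd(5_000_000) == 5.0
```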
-STRIPE_TOKENS_PER_UNIT=1000000 +STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping): +# STRIPE_TOKENS_PER_UNIT=1000000 # Periodic Stripe safety net for purchases left in PENDING (minutes old) STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 @@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300 # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) PAGES_LIMIT=500 -# Premium token quota per registered user (default: 3,000,000) -# Applies only to models with billing_tier=premium in global_llm_config.yaml -PREMIUM_TOKEN_LIMIT=3000000 +# Premium credit quota per registered user, in micro-USD +# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the +# actual per-call provider cost reported by LiteLLM, so cheap and expensive +# models bill proportionally. Applies only to models with +# billing_tier=premium in global_llm_config.yaml. +PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping): +# PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD. +# stream_new_chat estimates an upper-bound cost from the model's +# litellm-published per-token rates × the config's quota_reserve_tokens +# and clamps to this value so a misconfigured model can't lock the +# user's whole balance on one call. Default $1.00. +QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation (in micro-USD) for the POST /image-generations +# endpoint. Bypassed for free configs. Default $0.05. +QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation (in micro-USD) used by the podcast Celery task. +# Single envelope covers one transcript-generation LLM call. Default $0.20. +QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation (in micro-USD) used by the video +# presentation Celery task. Covers worst-case fan-out of N slide-scene +# generations + refines. Default $1.00. NOTE: tasks using the override +# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — allows public users to chat without an account # Set TRUE to enable /free pages and anonymous chat API diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py new file mode 100644 index 000000000..64aa699e8 --- /dev/null +++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py @@ -0,0 +1,291 @@ +"""rename premium token columns to credit micros and add cost_micros to token_usage + +Migrates the premium quota system from a flat token counter to a USD-cost +based credit system, where 1 credit = 1 micro-USD ($0.000001). 
+ +Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe +price means every existing value is already correct in the new unit, no +data transformation needed): + + user.premium_tokens_limit -> premium_credit_micros_limit + user.premium_tokens_used -> premium_credit_micros_used + user.premium_tokens_reserved -> premium_credit_micros_reserved + + premium_token_purchases.tokens_granted -> credit_micros_granted + +New column for cost auditing per turn: + + token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0) + +The "user" table is in zero_publication's column list (added in 139), so +this migration must drop and recreate the publication with the renamed +column names, otherwise zero-cache will replicate stale column names and +the FE Zero schema will fail to bind. + +IMPORTANT - deployment runbook (zero-cache must be stopped before the upgrade and reset afterwards): + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on +"user". Skipping the data-volume reset will leave IndexedDB clients seeing column-not-found errors from a stale catalog snapshot. + +Revision ID: 140 +Revises: 139 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "140" +down_revision: str | None = "139" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Replicates 139's document column list verbatim — must stay in sync. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Same five live-meter fields as 139, with the renamed column names.
+USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_credit_micros_limit", + "premium_credit_micros_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _column_exists(conn, table: str, column: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = :col" + ), + {"tbl": table, "col": column}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl( + user_cols: list[str], + *, + documents_has_zero_ver: bool, + user_has_zero_ver: bool, +) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_col_list_with_meta = user_cols + ( + ['"_0_version"'] if user_has_zero_ver else [] + ) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_col_list_with_meta) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def upgrade() -> None: + conn = op.get_bind() + + # ------------------------------------------------------------------ + # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in + # dev environments are safe. + # ------------------------------------------------------------------ + if not _column_exists(conn, "token_usage", "cost_micros"): + op.add_column( + "token_usage", + sa.Column( + "cost_micros", + sa.BigInteger(), + nullable=False, + server_default="0", + ), + ) + + # ------------------------------------------------------------------ + # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted. + # ------------------------------------------------------------------ + if _column_exists( + conn, "premium_token_purchases", "tokens_granted" + ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"): + op.alter_column( + "premium_token_purchases", + "tokens_granted", + new_column_name="credit_micros_granted", + ) + + # ------------------------------------------------------------------ + # 3. Rename user.premium_tokens_* -> premium_credit_micros_*. + # + # We must drop the publication first (it references the old column + # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE + # in a transaction block; alembic's outer transaction already holds + # one, but a SAVEPOINT keeps the LOCK + DDL atomic. 
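For orientation, the publication DDL this migration recreates, hand-expanded from `_build_publication_ddl` above under the assumption that both `documents` and `user` already carry zero-cache's `_0_version` column, comes out as:

```python
# Hand-expanded output of _build_publication_ddl(USER_COLS, ...) with
# documents_has_zero_ver=True and user_has_zero_ver=True:
EXPECTED_DDL = (
    "CREATE PUBLICATION zero_publication FOR TABLE "
    "notifications, "
    "documents (id, title, document_type, search_space_id, folder_id, "
    'created_by_id, status, created_at, updated_at, "_0_version"), '
    "folders, "
    "search_source_connectors, "
    "new_chat_messages, "
    "chat_comments, "
    "chat_session_state, "
    '"user" (id, pages_limit, pages_used, premium_credit_micros_limit, '
    'premium_credit_micros_used, "_0_version")'
)
```

Note that only the two live-meter credit columns are published; `premium_credit_micros_reserved` is deliberately absent from `USER_COLS` and is never replicated to clients.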
+ # ------------------------------------------------------------------ + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list + # publications require at least the PK to be in the column list, + # which is true for both the old and new shape. + conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + # Drop the publication BEFORE renaming columns, otherwise Postgres + # rejects the rename: "cannot drop column ... referenced by + # publication". + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for old, new in ( + ("premium_tokens_limit", "premium_credit_micros_limit"), + ("premium_tokens_used", "premium_credit_micros_used"), + ("premium_tokens_reserved", "premium_credit_micros_reserved"), + ): + if _column_exists(conn, "user", old) and not _column_exists( + conn, "user", new + ): + op.alter_column("user", old, new_column_name=new) + + # Update the server_default on the renamed limit column so newly + # inserted users get $5 of credit (== 5_000_000 micros) by + # default. Existing rows are unaffected. + op.alter_column( + "user", + "premium_credit_micros_limit", + server_default="5000000", + ) + + # Recreate the publication with the new column names. + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + USER_COLS, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + +def downgrade() -> None: + """Revert the rename and drop ``cost_micros``. + + Mirrors ``upgrade``: drop the publication, rename columns back, drop + the new column, recreate the publication with the old column list. + Same zero-cache stop/reset runbook applies in reverse. 
+ """ + conn = op.get_bind() + + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for new, old in ( + ("premium_credit_micros_limit", "premium_tokens_limit"), + ("premium_credit_micros_used", "premium_tokens_used"), + ("premium_credit_micros_reserved", "premium_tokens_reserved"), + ): + if _column_exists(conn, "user", new) and not _column_exists( + conn, "user", old + ): + op.alter_column("user", new, new_column_name=old) + + op.alter_column( + "user", + "premium_tokens_limit", + server_default="5000000", + ) + + legacy_user_cols = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", + ] + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + legacy_user_cols, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + if _column_exists( + conn, "premium_token_purchases", "credit_micros_granted" + ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"): + op.alter_column( + "premium_token_purchases", + "credit_micros_granted", + new_column_name="tokens_granted", + ) + + if _column_exists(conn, "token_usage", "cost_micros"): + op.drop_column("token_usage", "cost_micros") diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 016c2de42..14d7f4d23 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -31,6 +31,7 @@ from app.config import ( initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session @@ -432,6 +433,7 @@ async def lifespan(app: FastAPI): await setup_checkpointer_tables() initialize_openrouter_integration() _start_openrouter_background_refresh() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 58a8b0f39..74710d5e1 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -22,10 +22,12 @@ def init_worker(**kwargs): initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) initialize_openrouter_integration() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 675b05d2c..2aeeafb34 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -138,7 +138,11 @@ def load_global_image_gen_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_image_generation_configs", []) + configs = data.get("global_image_generation_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global image generation configs: 
{e}") return [] @@ -153,7 +157,11 @@ def load_global_vision_llm_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_vision_llm_configs", []) + configs = data.get("global_vision_llm_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global vision LLM configs: {e}") return [] @@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None: "anonymous_enabled_free", settings["anonymous_enabled"] ) + # Image generation + vision LLM emission are opt-in (issue L). + # OpenRouter's catalogue contains hundreds of image / vision + # capable models; auto-injecting all of them into every + # deployment would explode the model selector and surprise + # operators upgrading from prior versions. Default to False so + # admins must explicitly turn them on. + settings.setdefault("image_generation_enabled", False) + settings.setdefault("vision_enabled", False) + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") @@ -296,10 +313,60 @@ def initialize_openrouter_integration(): ) else: print("Info: OpenRouter integration enabled but no models fetched") + + # Image generation + vision LLM emissions are opt-in (issue L). + # Both reuse the catalogue already cached by ``service.initialize`` + # so we don't make additional network calls here. + if settings.get("image_generation_enabled"): + try: + image_configs = service.get_image_generation_configs() + if image_configs: + config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs) + print( + f"Info: OpenRouter integration added {len(image_configs)} " + f"image-generation models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") + + if settings.get("vision_enabled"): + try: + vision_configs = service.get_vision_llm_configs() + if vision_configs: + config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) + print( + f"Info: OpenRouter integration added {len(vision_configs)} " + f"vision LLM models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") +def initialize_pricing_registration(): + """ + Teach LiteLLM the per-token cost of every deployment in + ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled + from the OpenRouter catalogue + any operator-declared YAML pricing). + + Must run AFTER ``initialize_openrouter_integration()`` so the + OpenRouter catalogue is populated and BEFORE the first LLM call so + ``response_cost`` is available in ``TokenTrackingCallback``. + + Failures are logged but never raised — startup must not be blocked + by a missing pricing entry; the worst-case is the model debits 0. + """ + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as e: + print(f"Warning: Failed to register LiteLLM pricing: {e}") + + def initialize_llm_router(): """ Initialize the LLM Router service for Auto mode. @@ -444,14 +511,54 @@ class Config: os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") ) - # Premium token quota settings - PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000")) + # Premium credit (micro-USD) quota settings. 
+ # + # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy + # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are + # still honoured for one release as fall-back values — the prior + # $1-per-1M-tokens Stripe price means every existing value maps 1:1 + # to micros, so operators upgrading without changing their .env still + # get correct behaviour. A startup deprecation warning fires below if + # they're set. + PREMIUM_CREDIT_MICROS_LIMIT = int( + os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") + or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000") + ) STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") - STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")) + STRIPE_CREDIT_MICROS_PER_UNIT = int( + os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT") + or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000") + ) STRIPE_TOKEN_BUYING_ENABLED = ( os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" ) + # Safety ceiling on the per-call premium reservation. ``stream_new_chat`` + # estimates an upper-bound cost from ``litellm.get_model_info`` x the + # config's ``quota_reserve_tokens`` and clamps the result to this value + # so a misconfigured "$1000/M" model can't lock the user's whole balance + # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K + # reserve_tokens ≈ $0.36) with headroom. + QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000")) + + if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv( + "PREMIUM_CREDIT_MICROS_LIMIT" + ): + print( + "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to " + "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the " + "current Stripe price). The old key will be removed in a " + "future release." + ) + if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv( + "STRIPE_CREDIT_MICROS_PER_UNIT" + ): + print( + "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to " + "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). " + "The old key will be removed in a future release." + ) + # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000")) @@ -464,6 +571,35 @@ class Config: # Default quota reserve tokens when not specified per-model QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000")) + # Per-image reservation (in micro-USD) used by ``billable_call`` for the + # ``POST /image-generations`` endpoint when the global config does not + # override it. $0.05 covers realistic worst-cases for current OpenAI / + # OpenRouter image-gen pricing. Bypassed entirely for free configs. + QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") + ) + + # Per-podcast reservation (in micro-USD). One agent LLM call generating + # a transcript, typically 5k-20k completion tokens. $0.20 covers a long + # premium-model run. Tune via env. + QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000") + ) + + # Per-video-presentation reservation (in micro-USD). Fan-out of N + # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``) + # plus refine retries; can produce many premium completions. $1.00 + # covers worst-case. Tune via env. + # + # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of + # 1_000_000. 
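As a sketch, the clamp described above amounts to the following. This is a reconstruction from the comments here, not the real `estimate_call_reserve_micros` in `token_quota_service.py`; in particular, taking the dearer of the two published rates as the per-token upper bound is an editorial guess:

```python
# Reconstruction from the surrounding comments, not the actual
# implementation in app/services/token_quota_service.py.
import litellm

from app.config import config


def estimate_call_reserve_micros(
    base_model: str, quota_reserve_tokens: int | None = None
) -> int:
    tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
    try:
        info = litellm.get_model_info(base_model)
        # Worst case: every reserved token billed at the dearer of the
        # two published rates (USD/token -> micro-USD via * 1_000_000).
        per_token_usd = max(
            info.get("input_cost_per_token") or 0.0,
            info.get("output_cost_per_token") or 0.0,
        )
        estimate = int(per_token_usd * tokens * 1_000_000)
    except Exception:
        # Unknown model: fall back to the ceiling rather than guess low.
        estimate = config.QUOTA_MAX_RESERVE_MICROS
    # Clamp so a misconfigured "$1000/M" price can't lock the balance.
    return max(1, min(estimate, config.QUOTA_MAX_RESERVE_MICROS))
```

Under this reading, an Opus-class premium model with a 4K-token reserve stays well under the $1.00 ceiling, consistent with the roughly $0.36 worst case quoted above.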
The override path in ``billable_call`` bypasses the + # per-call clamp in ``estimate_call_reserve_micros``, so this is the + # *actual* hold — raising it via env is fine but means a single video + # task can lock $1+ of credit. + QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000") + ) + # Abuse prevention: concurrent stream cap and CAPTCHA ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2")) ANON_CAPTCHA_REQUEST_THRESHOLD = int( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 79cbe1e51..d92640c8d 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -19,6 +19,24 @@ # Structure matches NewLLMConfig: # - Model configuration (provider, model_name, api_key, etc.) # - Prompt configuration (system_instructions, citations_enabled) +# +# COST-BASED PREMIUM CREDITS: +# Each premium config bills the user's USD-credit balance based on the +# actual provider cost reported by LiteLLM. For models LiteLLM already +# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything. +# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment) +# or any model LiteLLM doesn't have in its built-in pricing table, declare +# per-token costs inline so they bill correctly: +# +# litellm_params: +# base_model: "my-custom-azure-deploy" +# # USD per token; e.g. 0.000003 == $3.00 per million input tokens +# input_cost_per_token: 0.000003 +# output_cost_per_token: 0.000015 +# +# OpenRouter dynamic models pull pricing automatically from OpenRouter's +# API — no inline declaration needed. Models without resolvable pricing +# debit $0 from the user's balance and log a WARNING. # Router Settings for Auto Mode # These settings control how the LiteLLM Router distributes requests across models @@ -292,6 +310,17 @@ openrouter_integration: free_rpm: 20 free_tpm: 100000 + # Image generation + vision LLM emissions are OPT-IN. OpenRouter's catalogue + # contains hundreds of image- and vision-capable models; turning these on + # injects them into the global Image-Generation / Vision-LLM model + # selectors alongside any static configs. Tier (free/premium) is derived + # per model the same way it is for chat (`:free` suffix or zero pricing). + # When a user picks a premium image/vision model the call debits the + # shared $5 USD-cost-based premium credit pool — so leaving these off + # avoids surprise quota burn on existing deployments. Default: false.
+ image_generation_enabled: false + vision_enabled: false + litellm_params: max_tokens: 16384 system_instructions: "" diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 2fe478d9b..aef959ec9 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin): prompt_tokens = Column(Integer, nullable=False, default=0) completion_tokens = Column(Integer, nullable=False, default=0) total_tokens = Column(Integer, nullable=False, default=0) + cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0") model_breakdown = Column(JSONB, nullable=True) call_details = Column(JSONB, nullable=True) @@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin): class PremiumTokenPurchase(Base, TimestampMixin): - """Tracks Stripe checkout sessions used to grant additional premium token credits.""" + """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units). + + Note: the table name is preserved (``premium_token_purchases``) for + operational continuity even though the unit is now USD micro-credits + instead of raw tokens. The ``credit_micros_granted`` column replaced + the legacy ``tokens_granted`` in migration 140; the stored values + were not transformed because the prior $1 = 1M tokens Stripe price + makes the unit conversion 1:1 numerically. + """ __tablename__ = "premium_token_purchases" __allow_unmapped__ = True @@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin): ) stripe_payment_intent_id = Column(String(255), nullable=True, index=True) quantity = Column(Integer, nullable=False) - tokens_granted = Column(BigInteger, nullable=False) + credit_micros_granted = Column(BigInteger, nullable=False) amount_total = Column(Integer, nullable=True) currency = Column(String(10), nullable=True) status = Column( @@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE": ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) @@ -2241,16 +2250,16 @@ else: ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py index 4bb38b7b0..d45bd780c 100644 --- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py +++ 
b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py @@ -68,12 +68,25 @@ class EtlPipelineService: etl_service="VISION_LLM", content_type="image", ) - except Exception: - logging.warning( - "Vision LLM failed for %s, falling back to document parser", - request.filename, - exc_info=True, - ) + except Exception as exc: + # Special-case quota exhaustion so we log a clearer message + # — the vision LLM didn't "fail", the user just ran out of + # premium credit. Falling through to the document parser + # is a graceful degradation: OCR/Unstructured still + # extracts text from the image without burning credit. + from app.services.billable_calls import QuotaInsufficientError + + if isinstance(exc, QuotaInsufficientError): + logging.info( + "Vision LLM quota exhausted for %s; falling back to document parser", + request.filename, + ) + else: + logging.warning( + "Vision LLM failed for %s, falling back to document parser", + request.filename, + exc_info=True, + ) else: logging.info( "No vision LLM provided, falling back to document parser for %s", diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 97a3559b9..34ed80207 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -36,6 +36,11 @@ from app.schemas import ( ImageGenerationListRead, ImageGenerationRead, ) +from app.services.billable_calls import ( + DEFAULT_IMAGE_RESERVE_MICROS, + QuotaInsufficientError, + billable_call, +) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, @@ -92,6 +97,50 @@ def _build_model_string( return f"{prefix}/{model_name}" +async def _resolve_billing_for_image_gen( + session: AsyncSession, + config_id: int | None, + search_space: SearchSpace, +) -> tuple[str, str, int]: + """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request. + + The resolution mirrors ``_execute_image_generation``'s lookup tree but + only extracts the fields needed for billing — we do this *before* + ``billable_call`` so the reservation is correctly sized for the + config that will actually run, and so we don't open an + ``ImageGeneration`` row for a request that's about to 402. + + User-owned (positive ID) BYOK configs are always free — they cost + the user nothing on our side. Auto mode is currently treated as free + because the underlying router can dispatch to either premium or + free YAML configs and we don't surface the resolved deployment up + here yet. Bringing Auto under premium billing would require + threading the chosen deployment back from ``ImageGenRouterService``. + """ + resolved_id = config_id + if resolved_id is None: + resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + + if is_image_gen_auto_mode(resolved_id): + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + + if resolved_id < 0: + cfg = _get_global_image_gen_config(resolved_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + base_model = _build_model_string( + cfg.get("provider", ""), + cfg.get("model_name", ""), + cfg.get("custom_provider"), + ) + reserve_micros = int( + cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + ) + return (billing_tier, base_model, reserve_micros) + + # Positive ID = user-owned BYOK image-gen config — always free.
+ return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) + + async def _execute_image_generation( session: AsyncSession, image_gen: ImageGeneration, @@ -225,6 +274,9 @@ async def get_global_image_gen_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode currently treated as free until per-deployment + # billing-tier surfacing lands (see _resolve_billing_for_image_gen). + "billing_tier": "free", } ) @@ -241,6 +293,8 @@ async def get_global_image_gen_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + "quota_reserve_micros": cfg.get("quota_reserve_micros"), } ) @@ -454,7 +508,26 @@ async def create_image_generation( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create and execute an image generation request.""" + """Create and execute an image generation request. + + Premium configs are gated by the user's shared premium credit pool. + The flow is: + + 1. Permission check + load the search space (cheap, no provider call). + 2. Resolve which config will run so we know its billing tier and the + worst-case reservation size *before* opening any DB rows. + 3. Wrap the entire ImageGeneration row insert + provider call in + ``billable_call``. If quota is denied, ``billable_call`` raises + ``QuotaInsufficientError`` *before* we flush a row, which we + translate to HTTP 402 (no orphaned rows on the user's account, + no inserted error rows for "you ran out of credit"). + 4. On success, the actual ``response_cost`` flows through the + LiteLLM callback into the accumulator, and ``billable_call`` + finalizes the debit at exit. Inner ``try/except`` still catches + provider errors and stores them on ``error_message`` (HTTP 200 + with ``error_message`` set is preserved for failed-but-not-quota + scenarios — clients already know how to surface those). + """ try: await check_permission( session, @@ -471,33 +544,70 @@ async def create_image_generation( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - db_image_gen = ImageGeneration( - prompt=data.prompt, - model=data.model, - n=data.n, - quality=data.quality, - size=data.size, - style=data.style, - response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, - search_space_id=data.search_space_id, - created_by_id=user.id, + billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( + session, data.image_generation_config_id, search_space ) - session.add(db_image_gen) - await session.flush() - try: - await _execute_image_generation(session, db_image_gen, search_space) - except Exception as e: - logger.exception("Image generation call failed") - db_image_gen.error_message = str(e) + # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError + # propagates to the outer ``except QuotaInsufficientError`` handler + # below as HTTP 402 — it is intentionally NOT swallowed into + # ``error_message`` because that would (1) imply a successful row + # exists when none does, and (2) return HTTP 200 to a client + # whose request was actively *denied* (issue K). 
+ async with billable_call( + user_id=search_space.user_id, + search_space_id=data.search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=reserve_micros, + usage_type="image_generation", + call_details={"model": base_model, "prompt": data.prompt[:100]}, + ): + db_image_gen = ImageGeneration( + prompt=data.prompt, + model=data.model, + n=data.n, + quality=data.quality, + size=data.size, + style=data.style, + response_format=data.response_format, + image_generation_config_id=data.image_generation_config_id, + search_space_id=data.search_space_id, + created_by_id=user.id, + ) + session.add(db_image_gen) + await session.flush() - await session.commit() - await session.refresh(db_image_gen) - return db_image_gen + try: + await _execute_image_generation(session, db_image_gen, search_space) + except Exception as e: + logger.exception("Image generation call failed") + db_image_gen.error_message = str(e) + + await session.commit() + await session.refresh(db_image_gen) + return db_image_gen except HTTPException: raise + except QuotaInsufficientError as exc: + # The user's premium credit pool is empty. No DB row is created + # because ``billable_call`` denies before yielding (issue K). + await session.rollback() + raise HTTPException( + status_code=402, + detail={ + "error_code": "premium_quota_exhausted", + "usage_type": exc.usage_type, + "used_micros": exc.used_micros, + "limit_micros": exc.limit_micros, + "remaining_micros": exc.remaining_micros, + "message": ( + "Out of premium credits for image generation. " + "Purchase additional credits or switch to a free model." + ), + }, + ) from exc except SQLAlchemyError: await session.rollback() raise HTTPException( diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 28b197ca2..d3bd51129 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1366,7 +1366,11 @@ async def append_message( # flush assigns the PK/defaults without a round-trip SELECT await session.flush() - # Persist token usage if provided (for assistant messages) + # Persist token usage if provided (for assistant messages). + # ``cost_micros`` is the provider USD cost reported by LiteLLM, + # forwarded by the FE through the appendMessage round-trip so + # the historical TokenUsage row matches the credit debit applied + # at finalize time. 
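Pieced together from the `.get()` calls below, the `token_usage` object the FE round-trips on appendMessage looks roughly like this; the field values and the nested shapes of `usage` and `call_details` are hypothetical:

```python
# Hypothetical payload; the top-level keys are inferred from the .get()
# calls below, the nested contents are illustrative only.
token_usage = {
    "prompt_tokens": 1842,
    "completion_tokens": 512,
    "total_tokens": 2354,
    # provider USD cost from LiteLLM, already converted to micro-USD
    "cost_micros": 23_400,
    # persisted as TokenUsage.model_breakdown
    "usage": {"claude-sonnet-4": {"prompt_tokens": 1842, "completion_tokens": 512}},
    # persisted as TokenUsage.call_details
    "call_details": {"streamed": True},
}
```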
token_usage_data = raw_body.get("token_usage") if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: await record_token_usage( @@ -1377,6 +1381,7 @@ async def append_message( prompt_tokens=token_usage_data.get("prompt_tokens", 0), completion_tokens=token_usage_data.get("completion_tokens", 0), total_tokens=token_usage_data.get("total_tokens", 0), + cost_micros=token_usage_data.get("cost_micros", 0), model_breakdown=token_usage_data.get("usage"), call_details=token_usage_data.get("call_details"), thread_id=thread_id, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 72715ea5b..5ecfb1814 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None @@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index cfdd4b52a..aed74ec8d 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase( metadata = _get_metadata(checkout_session) user_id = metadata.get("user_id") quantity = int(metadata.get("quantity", "0")) - tokens_per_unit = int(metadata.get("tokens_per_unit", "0")) + # Read the new metadata key first, fall back to the legacy one so + # in-flight checkout sessions created before the cost-credits + # release still fulfil correctly (the unit is numerically the + # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens). + credit_micros_per_unit = int( + metadata.get("credit_micros_per_unit") + or metadata.get("tokens_per_unit", "0") + ) - if not user_id or quantity <= 0 or tokens_per_unit <= 0: + if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: logger.error( "Skipping token fulfillment for session %s: incomplete metadata %s", checkout_session_id, @@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase( getattr(checkout_session, "payment_intent", None) ), quantity=quantity, - tokens_granted=quantity * tokens_per_unit, + credit_micros_granted=quantity * credit_micros_per_unit, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase( purchase.stripe_payment_intent_id = _normalize_optional_string( getattr(checkout_session, "payment_intent", None) ) - user.premium_tokens_limit = ( - max(user.premium_tokens_used, user.premium_tokens_limit) - + purchase.tokens_granted + # Top up the user's credit balance by the granted micro-USD amount. 
+ # ``max(used, limit)`` handles the case where past writes left ``used`` + # above ``limit`` (e.g. rounding drift in the legacy token billing), so + # basing the top-up on the larger of the two lets every pack lift the + # spendable balance by its full size instead of first back-filling past + # overuse (used=1_200_000, limit=1_000_000, one $1 pack -> new limit + # 2_200_000, i.e. remaining 1_000_000). + user.premium_credit_micros_limit = ( + max(user.premium_credit_micros_used, user.premium_credit_micros_limit) + + purchase.credit_micros_granted ) await db_session.commit() @@ -532,12 +544,18 @@ async def create_token_checkout_session( user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), ): - """Create a Stripe Checkout Session for buying premium token packs.""" + """Create a Stripe Checkout Session for buying premium credit packs. + + Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of + credit (default 1_000_000 = $1.00). The user's balance is debited + at the actual provider cost reported by LiteLLM at finalize time, + so $1 of credit always buys $1 worth of provider usage at cost. + """ _ensure_token_buying_enabled() stripe_client = get_stripe_client() price_id = _get_required_token_price_id() success_url, cancel_url = _get_token_checkout_urls(body.search_space_id) - tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT + credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( @@ -556,8 +574,8 @@ async def create_token_checkout_session( "metadata": { "user_id": str(user.id), "quantity": str(body.quantity), - "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT), - "purchase_type": "premium_tokens", + "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), + "purchase_type": "premium_credit", }, } ) @@ -583,7 +601,7 @@ async def create_token_checkout_session( getattr(checkout_session, "payment_intent", None) ), quantity=body.quantity, - tokens_granted=tokens_granted, + credit_micros_granted=credit_micros_granted, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -598,14 +616,19 @@ async def get_token_status( user: User = Depends(current_active_user), ): - """Return token-buying availability and current premium quota for frontend.""" - used = user.premium_tokens_used - limit = user.premium_tokens_limit + """Return token-buying availability and current premium credit quota for frontend. + + Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M + when displaying. The route name is preserved for back-compat with + pinned client deployments.
+ """ + used = user.premium_credit_micros_used + limit = user.premium_credit_micros_limit return TokenStripeStatusResponse( token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED, - premium_tokens_used=used, - premium_tokens_limit=limit, - premium_tokens_remaining=max(0, limit - used), + premium_credit_micros_used=used, + premium_credit_micros_limit=limit, + premium_credit_micros_remaining=max(0, limit - used), ) diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 315c7c9fe..4f7e9f725 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -82,6 +82,9 @@ async def get_global_vision_llm_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode treated as free until per-deployment billing-tier + # surfacing lands; see ``get_vision_llm`` for parity. + "billing_tier": "free", } ) @@ -98,6 +101,10 @@ async def get_global_vision_llm_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "input_cost_per_token": cfg.get("input_cost_per_token"), + "output_cost_per_token": cfg.get("output_cost_per_token"), } ) diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 69f534e20..facca7b86 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel): Schema for reading global image generation configs from YAML. Global configs have negative IDs. API key is hidden. ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. """ id: int = Field( @@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + quota_reserve_micros: int | None = Field( + default=None, + description=( + "Optional override for the reservation amount (in micro-USD) used when " + "this image generation is premium. Falls back to " + "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." 
+ ), + ) diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index ec5eefc07..892ff9693 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel): prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 + cost_micros: int = 0 model_breakdown: dict | None = None model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py index 3edd3e9e4..57265ec8e 100644 --- a/surfsense_backend/app/schemas/stripe.py +++ b/surfsense_backend/app/schemas/stripe.py @@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel): class TokenPurchaseRead(BaseModel): - """Serialized premium token purchase record.""" + """Serialized premium credit purchase record. + + ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The + schema keeps the ``Token`` name for API back-compat with pinned clients. + """ id: uuid.UUID stripe_checkout_session_id: str stripe_payment_intent_id: str | None = None quantity: int - tokens_granted: int + credit_micros_granted: int amount_total: int | None = None currency: str | None = None status: str @@ -87,15 +91,19 @@ class TokenPurchaseHistoryResponse(BaseModel): - """Response containing the user's premium token purchases.""" + """Response containing the user's premium credit purchases.""" purchases: list[TokenPurchaseRead] class TokenStripeStatusResponse(BaseModel): - """Response describing token-buying availability and current quota.""" + """Response describing premium-credit-buying availability and balance. + + All ``premium_credit_micros_*`` fields are in micro-USD; the FE + divides by 1_000_000 to display USD. + """ token_buying_enabled: bool - premium_tokens_used: int = 0 - premium_tokens_limit: int = 0 - premium_tokens_remaining: int = 0 + premium_credit_micros_used: int = 0 + premium_credit_micros_limit: int = 0 + premium_credit_micros_remaining: int = 0 diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py index ab2e609dc..e55333a9d 100644 --- a/surfsense_backend/app/schemas/vision_llm.py +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel): class GlobalVisionLLMConfigRead(BaseModel): + """Schema for reading global vision LLM configs from YAML. + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + id: int = Field(...) name: str description: str | None = None @@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + quota_reserve_tokens: int | None = Field( + default=None, + description=( + "Optional override for the per-call reservation in *tokens* — " + "converted to micro-USD via the model's input/output prices at " + "reservation time. Falls back to QUOTA_MAX_RESERVE_PER_CALL."
+ ), + ) + input_cost_per_token: float | None = Field( + default=None, + description=( + "Optional input price in USD/token. Used by pricing_registration to " + "register custom Azure / OpenRouter aliases with LiteLLM at startup." + ), + ) + output_cost_per_token: float | None = Field( + default=None, + description="Optional output price in USD/token. Pair with input_cost_per_token.", + ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py new file mode 100644 index 000000000..f5ca9818e --- /dev/null +++ b/surfsense_backend/app/services/billable_calls.py @@ -0,0 +1,430 @@ +""" +Per-call billable wrapper for image generation, vision LLM extraction, and +any other short-lived premium operation that must charge against the user's +shared premium credit pool. + +The ``billable_call`` async context manager encapsulates the standard +"reserve → execute → finalize / release → record audit row" lifecycle in a +single primitive so callers (the image-generation REST route and the +vision-LLM wrapper used during indexing) don't have to re-implement it. + +KEY DESIGN POINTS (issue A, B): + +1. **Session isolation.** ``billable_call`` takes *no* ``db_session`` + argument. All ``TokenQuotaService.premium_*`` calls and the audit-row + insert each run inside their own ``shielded_async_session()``. This + guarantees that a quota commit/rollback can never accidentally flush or + roll back rows the caller has staged in the request's main session + (e.g. a freshly-created ``ImageGeneration`` row). + +2. **ContextVar safety.** The accumulator is scoped via + :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a + nested ``billable_call`` inside an outer chat turn cannot corrupt the + chat turn's accumulator. + +3. **Free configs are still audited.** Free calls bypass the reserve / + finalize dance entirely but still record a ``TokenUsage`` audit row with + the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution + pipeline complete for analytics even when nothing is debited. + +4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is + responsible for translating that into HTTP 402. We *do not* catch the + denial inside ``billable_call`` — letting it propagate also prevents + the image-generation route from creating an ``ImageGeneration`` row + for a request that never actually ran. +""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import shielded_async_session +from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, +) +from app.services.token_tracking_service import ( + TurnTokenAccumulator, + record_token_usage, + scoped_turn, +) + +logger = logging.getLogger(__name__) + + +class QuotaInsufficientError(Exception): + """Raised when ``TokenQuotaService.premium_reserve`` denies a billable + call because the user has exhausted their premium credit pool. + + The route handler should catch this and return HTTP 402 Payment + Required (or the equivalent for the surface area). Outside of the HTTP + layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing) + callers may catch this and degrade gracefully — e.g. fall back to OCR + when vision is unavailable. 
+ """ + + def __init__( + self, + *, + usage_type: str, + used_micros: int, + limit_micros: int, + remaining_micros: int, + ) -> None: + self.usage_type = usage_type + self.used_micros = used_micros + self.limit_micros = limit_micros + self.remaining_micros = remaining_micros + super().__init__( + f"Premium credit exhausted for {usage_type}: " + f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" + ) + + +@asynccontextmanager +async def billable_call( + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None = None, + quota_reserve_micros_override: int | None = None, + usage_type: str, + thread_id: int | None = None, + message_id: int | None = None, + call_details: dict[str, Any] | None = None, +) -> AsyncIterator[TurnTokenAccumulator]: + """Wrap a single billable LLM/image call. + + Args: + user_id: Owner of the credit pool to debit. For vision-LLM during + indexing this is the *search-space owner* (issue M), not the + triggering user. + search_space_id: Required — recorded on the ``TokenUsage`` audit row. + billing_tier: ``"premium"`` debits; anything else (``"free"``) skips + the reserve/finalize dance but still records an audit row with + the captured cost. + base_model: Used by :func:`estimate_call_reserve_micros` to compute + a worst-case reservation from LiteLLM's pricing table. + quota_reserve_tokens: Optional per-config override for the chat-style + reserve estimator (vision LLM uses this). + quota_reserve_micros_override: Optional flat micro-USD reservation + (image generation uses this — its cost shape is per-image, not + per-token). + usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc. + Recorded on the ``TokenUsage`` row. + thread_id, message_id: Optional FK columns on ``TokenUsage``. + call_details: Optional per-call metadata (model name, parameters) + forwarded to ``record_token_usage``. + + Yields: + The ``TurnTokenAccumulator`` scoped to this call. The caller invokes + the underlying LLM/image API while inside the ``async with``; the + ``TokenTrackingCallback`` populates the accumulator automatically. + + Raises: + QuotaInsufficientError: when premium and ``premium_reserve`` denies. + """ + is_premium = billing_tier == "premium" + + async with scoped_turn() as acc: + # ---------- Free path: just audit ------------------------------- + if not is_premium: + try: + yield acc + finally: + # Always audit, even on exception, so we capture cost when + # provider returns successfully but the caller raises later. 
+ try: + async with shielded_async_session() as audit_session: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=acc.total_cost_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + ) + await audit_session.commit() + except Exception: + logger.exception( + "[billable_call] free-path audit insert failed for " + "usage_type=%s user_id=%s", + usage_type, + user_id, + ) + return + + # ---------- Premium path: reserve → execute → finalize ---------- + if quota_reserve_micros_override is not None: + reserve_micros = max(1, int(quota_reserve_micros_override)) + else: + reserve_micros = estimate_call_reserve_micros( + base_model=base_model or "", + quota_reserve_tokens=quota_reserve_tokens, + ) + + request_id = str(uuid4()) + + async with shielded_async_session() as quota_session: + reserve_result = await TokenQuotaService.premium_reserve( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + reserve_micros=reserve_micros, + ) + + if not reserve_result.allowed: + logger.info( + "[billable_call] reserve DENIED user=%s usage_type=%s " + "reserve=%d used=%d limit=%d remaining=%d", + user_id, + usage_type, + reserve_micros, + reserve_result.used, + reserve_result.limit, + reserve_result.remaining, + ) + raise QuotaInsufficientError( + usage_type=usage_type, + used_micros=reserve_result.used, + limit_micros=reserve_result.limit, + remaining_micros=reserve_result.remaining, + ) + + logger.info( + "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d " + "(remaining=%d)", + user_id, + usage_type, + reserve_micros, + reserve_result.remaining, + ) + + try: + yield acc + except BaseException: + # Release on any failure (including QuotaInsufficientError raised + # from a downstream call, asyncio cancellation, etc.). We use + # BaseException so cancellation also releases. + try: + async with shielded_async_session() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] premium_release failed for user=%s " + "reserve_micros=%d (reservation will be GC'd by quota " + "reconciliation if/when implemented)", + user_id, + reserve_micros, + ) + raise + + # ---------- Success: finalize + audit ---------------------------- + actual_micros = acc.total_cost_micros + try: + async with shielded_async_session() as quota_session: + final_result = await TokenQuotaService.premium_finalize( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + actual_micros=actual_micros, + reserved_micros=reserve_micros, + ) + logger.info( + "[billable_call] finalize user=%s usage_type=%s actual=%d " + "reserved=%d → used=%d/%d (remaining=%d)", + user_id, + usage_type, + actual_micros, + reserve_micros, + final_result.used, + final_result.limit, + final_result.remaining, + ) + except Exception: + # Last-ditch: if finalize itself fails, we must at least release + # so the reservation doesn't leak. 
+ logger.exception( + "[billable_call] premium_finalize failed for user=%s; " + "attempting release", + user_id, + ) + try: + async with shielded_async_session() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] release after finalize failure ALSO failed " + "for user=%s", + user_id, + ) + + try: + async with shielded_async_session() as audit_session: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=actual_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + ) + await audit_session.commit() + except Exception: + logger.exception( + "[billable_call] premium-path audit insert failed for " + "usage_type=%s user_id=%s (debit was applied)", + usage_type, + user_id, + ) + + +async def _resolve_agent_billing_for_search_space( + session: AsyncSession, + search_space_id: int, + *, + thread_id: int | None = None, +) -> tuple[UUID, str, str]: + """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space + agent LLM. + + Used by Celery tasks (podcast generation, video presentation) to bill the + search-space owner's premium credit pool when the agent LLM is premium. + + Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + + - Search space not found / no ``agent_llm_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + * ``thread_id`` is set: delegate to + ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and + recurse into the resolved id. Reuses chat's existing pin if present + so the same model bills for chat + downstream podcast/video. If the + user is not premium-eligible, the pin service auto-restricts to free + deployments — denial only happens later in + ``billable_call.premium_reserve`` if the pin really is premium and + credit ran out mid-flow. + * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat + for any future direct-API path; today both Celery tasks always pass + ``thread_id``. + - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]`` + (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), + ``base_model = litellm_params.get("base_model") or model_name`` — + NOT provider-prefixed, matching chat's cost-map lookup convention. + - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches + ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); + ``base_model`` from ``litellm_params`` or ``model_name``. + + Note on imports: ``llm_service``, ``auto_model_pin_service``, and + ``llm_router_service`` are imported lazily inside the function body to + avoid hoisting litellm side-effects (``litellm.callbacks = + [token_tracker]``, ``litellm.drop_params``, etc.) into + ``billable_calls.py``'s module load path. 
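The resolution rules above feed directly into ``billable_call``; a sketch of the Celery-side pattern they imply, assuming a podcast-shaped task (the task name, ``usage_type`` string, and reserve value are illustrative):

```python
from app.db import shielded_async_session
from app.services.billable_calls import (
    _resolve_agent_billing_for_search_space,
    billable_call,
)


async def run_podcast_generation(search_space_id: int, thread_id: int) -> None:
    # Resolve who pays and at which tier, using the same rules as chat.
    async with shielded_async_session() as session:
        owner_id, tier, base_model = await _resolve_agent_billing_for_search_space(
            session, search_space_id, thread_id=thread_id
        )

    async with billable_call(
        user_id=owner_id,
        search_space_id=search_space_id,
        billing_tier=tier,
        base_model=base_model,
        quota_reserve_micros_override=200_000,  # $0.20 podcast default
        usage_type="podcast_generation",        # illustrative usage_type
        thread_id=thread_id,
    ):
        ...  # one transcript-generation LLM call; the callback captures cost
```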
+ """ + from sqlalchemy import select + + from app.db import NewLLMConfig, SearchSpace + + result = await session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if search_space is None: + raise ValueError(f"Search space {search_space_id} not found") + + agent_llm_id = search_space.agent_llm_id + if agent_llm_id is None: + raise ValueError( + f"Search space {search_space_id} has no agent_llm_id configured" + ) + + owner_user_id: UUID = search_space.user_id + + from app.services.auto_model_pin_service import ( + AUTO_FASTEST_ID, + resolve_or_get_pinned_llm_config_id, + ) + + if agent_llm_id == AUTO_FASTEST_ID: + if thread_id is None: + return owner_user_id, "free", "auto" + try: + resolution = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=str(owner_user_id), + selected_llm_config_id=AUTO_FASTEST_ID, + ) + except ValueError: + logger.warning( + "[agent_billing] Auto-mode pin resolution failed for " + "search_space=%s thread=%s; falling back to free", + search_space_id, + thread_id, + exc_info=True, + ) + return owner_user_id, "free", "auto" + agent_llm_id = resolution.resolved_llm_config_id + + if agent_llm_id < 0: + from app.services.llm_service import get_global_llm_config + + cfg = get_global_llm_config(agent_llm_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + litellm_params = cfg.get("litellm_params") or {} + base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" + return owner_user_id, billing_tier, base_model + + nlc_result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == agent_llm_id, + NewLLMConfig.search_space_id == search_space_id, + ) + ) + nlc = nlc_result.scalars().first() + base_model = "" + if nlc is not None: + litellm_params = nlc.litellm_params or {} + base_model = litellm_params.get("base_model") or nlc.model_name or "" + return owner_user_id, "free", base_model + + +__all__ = [ + "QuotaInsufficientError", + "_resolve_agent_billing_for_search_space", + "billable_call", +] + + +# Re-export the config knob so callers don't have to import config just for +# the default image reserve. +DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 8a7b2919a..1e9d235c8 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -134,42 +134,16 @@ PROVIDER_MAP = { } -# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when -# a global LLM config does *not* specify ``api_base``: without this, LiteLLM -# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``, -# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku`` -# request to an Azure endpoint, which then 404s with ``Resource not found``. -# Only providers with a well-known, stable public base URL are listed here — -# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, -# huggingface, databricks, cloudflare, replicate) are intentionally omitted -# so their existing config-driven behaviour is preserved. 
-PROVIDER_DEFAULT_API_BASE = { - "openrouter": "https://openrouter.ai/api/v1", - "groq": "https://api.groq.com/openai/v1", - "mistral": "https://api.mistral.ai/v1", - "perplexity": "https://api.perplexity.ai", - "xai": "https://api.x.ai/v1", - "cerebras": "https://api.cerebras.ai/v1", - "deepinfra": "https://api.deepinfra.com/v1/openai", - "fireworks_ai": "https://api.fireworks.ai/inference/v1", - "together_ai": "https://api.together.xyz/v1", - "anyscale": "https://api.endpoints.anyscale.com/v1", - "cometapi": "https://api.cometapi.com/v1", - "sambanova": "https://api.sambanova.ai/v1", -} - - -# Canonical provider → base URL when a config uses a generic ``openai``-style -# prefix but the ``provider`` field tells us which API it really is -# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but -# each has its own base URL). -PROVIDER_KEY_DEFAULT_API_BASE = { - "DEEPSEEK": "https://api.deepseek.com/v1", - "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - "MOONSHOT": "https://api.moonshot.ai/v1", - "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", - "MINIMAX": "https://api.minimax.io/v1", -} +# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were +# hoisted to ``app.services.provider_api_base`` so vision and image-gen +# call sites can share the exact same defense (OpenRouter / Groq / etc. +# 404-ing against an inherited Azure endpoint). Re-exported here for +# backward compatibility with any external import. +from app.services.provider_api_base import ( # noqa: E402 + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) class LLMRouterService: @@ -466,14 +440,14 @@ class LLMRouterService: # Resolve ``api_base``. Config value wins; otherwise apply a # provider-aware default so the deployment does not silently # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE`` + # requests to the wrong endpoint. See ``provider_api_base`` # docstring for the motivating bug (OpenRouter models 404-ing # against an Azure endpoint). - api_base = config.get("api_base") - if not api_base: - api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider) - if not api_base: - api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix) + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) if api_base: litellm_params["api_base"] = api_base diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 942a9b7af..72c10035d 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -496,8 +496,14 @@ async def get_vision_llm( - Auto mode (ID 0): VisionLLMRouterService - Global (negative ID): YAML configs - DB (positive ID): VisionLLMConfig table + + Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` + so each ``ainvoke`` debits the search-space owner's premium credit + pool. User-owned BYOK configs and free global configs are returned + unwrapped — they don't consume premium credit (issue M). 
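From the caller's side the wrapping is invisible until credit runs out. A minimal sketch of the graceful-degradation branch the docstrings describe; the function name and fallback behaviour are illustrative:

```python
from app.services.billable_calls import QuotaInsufficientError


async def extract_image_text(llm, messages) -> str | None:
    try:
        # Debits the search-space owner's credit if `llm` is the premium
        # QuotaCheckedVisionLLM wrapper; free configs pass straight through.
        response = await llm.ainvoke(messages)
        return response.content
    except QuotaInsufficientError:
        # Owner's premium pool is exhausted: fall back to the non-vision
        # document parser instead of failing the whole indexing run.
        return None
```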
""" from app.db import VisionLLMConfig + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( VISION_PROVIDER_MAP, VisionLLMRouterService, @@ -519,6 +525,8 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None + owner_user_id = search_space.user_id + if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -526,6 +534,13 @@ async def get_vision_llm( ) return None try: + # Auto mode is currently treated as free at the wrapper + # level — the underlying router can dispatch to either + # premium or free YAML configs but routing decisions are + # opaque. If/when we want to bill Auto-routed vision + # calls we'd need to thread the resolved deployment's + # billing_tier back from the router. For now we keep + # parity with chat Auto, which also doesn't pre-classify. return ChatLiteLLMRouter( router=VisionLLMRouterService.get_router(), streaming=True, @@ -562,8 +577,21 @@ async def get_vision_llm( from app.agents.new_chat.llm_config import SanitizedChatLiteLLM - return SanitizedChatLiteLLM(**litellm_kwargs) + inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) + billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + if billing_tier == "premium": + return QuotaCheckedVisionLLM( + inner_llm, + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=model_string, + quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + ) + return inner_llm + + # User-owned (positive ID) BYOK configs — always free. result = await session.execute( select(VisionLLMConfig).where( VisionLLMConfig.id == config_id, diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 7e856d015..0d030f04f 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool: return output_mods == ["text"] +def _is_image_output_model(model: dict) -> bool: + """Return True if the model can produce image output. + + OpenRouter's ``architecture.output_modalities`` is a list (e.g. + ``["image"]`` for pure image generators, ``["text", "image"]`` for + multi-modal generators that also emit captions). We accept any model + that can output images; the call site decides whether to use the + image-generation API or chat completion. + """ + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def _is_vision_input_model(model: dict) -> bool: + """Return True if the model can ingest an image AND emit text. + + OpenRouter's ``architecture.input_modalities`` lists what the model + accepts; ``output_modalities`` lists what it produces. A vision LLM + is a model that takes images in and produces text out — i.e. it can + answer questions about a screenshot or extract content from an + image. Pure image-to-image models (e.g. style transfer) and + text-only models are excluded. 
+ """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + output_mods = arch.get("output_modalities", []) or [] + return "image" in input_mods and "text" in output_mods + + def _supports_tool_calling(model: dict) -> bool: """Return True if the model supports function/tool calling.""" supported = model.get("supported_parameters") or [] @@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None: return None +def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]: + """Return a ``{model_id: {"prompt": str, "completion": str}}`` map. + + Pricing values are kept as the raw OpenRouter strings (e.g. + ``"0.000003"``); ``pricing_registration`` converts them to floats + when registering with LiteLLM. Models with missing or malformed + pricing are simply omitted — operator-side risk if any of those are + premium. + """ + pricing: dict[str, dict[str, str]] = {} + for model in raw_models: + model_id = str(model.get("id") or "").strip() + if not model_id: + continue + p = model.get("pricing") or {} + prompt = p.get("prompt") + completion = p.get("completion") + if prompt is None and completion is None: + continue + pricing[model_id] = { + "prompt": str(prompt) if prompt is not None else "", + "completion": str(completion) if completion is not None else "", + } + return pricing + + def _generate_configs( raw_models: list[dict], settings: dict[str, Any], @@ -282,6 +337,162 @@ def _generate_configs( return configs +# ID-offset bands used to keep dynamic OpenRouter configs in their own +# namespace per surface. Image / vision get separate bands so a single +# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. +_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 +_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 + + +def _generate_image_gen_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter image-generation models into global image-gen + config dicts (matches the YAML shape consumed by ``image_generation_routes``). + + Filter: + - architecture.output_modalities contains "image" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Notably we *drop* the chat-only filters (``_supports_tool_calling`` and + ``_has_sufficient_context``) because tool calls and context windows are + irrelevant for the ``aimage_generation`` API. ``billing_tier`` is + derived per model the same way as chat (``_openrouter_tier``). + + Cost is intentionally *not* registered with LiteLLM at startup + (``pricing_registration`` skips image gen): OpenRouter image-gen + models are not in LiteLLM's native cost map and OpenRouter populates + ``response_cost`` directly from the response header. A defensive + branch in ``_extract_cost_usd`` handles the rare case where + ``usage.cost`` is missing — see ``token_tracking_service``. 
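``_stable_config_id`` itself is not part of this hunk. Purely as a sketch of what the band offsets imply, one plausible implementation hashes the model id into the band and probes downward on collision:

```python
import zlib


def stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
    """Hypothetical banding helper, not the diff's actual implementation."""
    candidate = offset - (zlib.crc32(model_id.encode("utf-8")) % 10_000)
    while candidate in taken:  # deterministic downward probe on collision
        candidate -= 1
    taken.add(candidate)
    return candidate


taken: set[int] = set()
image_id = stable_config_id("openai/gpt-image-1", -20_000, taken)
assert -30_000 < image_id <= -20_000  # stays inside the image band
```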
+ """ + id_offset: int = int( + settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + free_rpm: int = settings.get("free_rpm", 20) + litellm_params: dict = settings.get("litellm_params") or {} + + image_models = [ + m + for m in raw_models + if _is_image_output_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in image_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (image generation)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + "api_base": "", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + +def _generate_vision_llm_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter vision-capable LLMs into global vision-LLM config + dicts (matches the YAML shape consumed by ``vision_llm_routes``). + + Filter: + - architecture.input_modalities contains "image" + - architecture.output_modalities contains "text" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Vision-LLM is invoked from the indexer (image extraction during + document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so + the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` + filters do not apply: a small-context vision model that doesn't + advertise tool-calling is still perfectly viable for "describe this + image" prompts. + """ + id_offset: int = int( + settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) + litellm_params: dict = settings.get("litellm_params") or {} + + vision_models = [ + m + for m in raw_models + if _is_vision_input_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in vision_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + pricing = model.get("pricing") or {} + + # Capture per-token prices so ``pricing_registration`` can + # register them with LiteLLM at startup (and so the cost + # estimator in ``estimate_call_reserve_micros`` can resolve + # them at reserve time). 
+ try: + input_cost = float(pricing.get("prompt", 0) or 0) + except (TypeError, ValueError): + input_cost = 0.0 + try: + output_cost = float(pricing.get("completion", 0) or 0) + except (TypeError, ValueError): + output_cost = 0.0 + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (vision)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + "api_base": "", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + "quota_reserve_tokens": quota_reserve_tokens, + "input_cost_per_token": input_cost or None, + "output_cost_per_token": output_cost or None, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -300,6 +511,19 @@ class OpenRouterIntegrationService: # Shape: {model_name: {"gated": bool, "score": float | None}} self._health_cache: dict[str, dict[str, Any]] = {} self._enrich_task: asyncio.Task | None = None + # Raw OpenRouter pricing per model_id, captured at the same time + # we generate configs. Consumed by ``pricing_registration`` to + # teach LiteLLM the per-token cost of every dynamic deployment so + # the success-callback can populate ``response_cost`` correctly. + self._raw_pricing: dict[str, dict[str, str]] = {} + # Cached raw catalogue from the most recent fetch. Image / vision + # emitters reuse this to avoid a second network call per surface. + self._raw_models: list[dict] = [] + # Image / vision config caches (only populated when the matching + # opt-in flag is true on initialize). Refreshed in lockstep with + # the chat catalogue. + self._image_configs: list[dict] = [] + self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -329,8 +553,32 @@ class OpenRouterIntegrationService: self._initialized = True return [] + self._raw_models = raw_models self._configs = _generate_configs(raw_models, settings) self._configs_by_id = {c["id"]: c for c in self._configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + + # Populate image / vision caches when their opt-in flag is set. + # Empty otherwise so the accessors return [] without re-running + # filters every refresh. 
+ if settings.get("image_generation_enabled"): + self._image_configs = _generate_image_gen_configs(raw_models, settings) + logger.info( + "OpenRouter integration: image-gen emission ON (%d models)", + len(self._image_configs), + ) + else: + self._image_configs = [] + + if settings.get("vision_enabled"): + self._vision_configs = _generate_vision_llm_configs(raw_models, settings) + logger.info( + "OpenRouter integration: vision LLM emission ON (%d models)", + len(self._vision_configs), + ) + else: + self._vision_configs = [] + self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -369,6 +617,8 @@ class OpenRouterIntegrationService: new_configs = _generate_configs(raw_models, self._settings) new_by_id = {c["id"]: c for c in new_configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + self._raw_models = raw_models from app.config import config as app_config @@ -382,6 +632,29 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id + # Image / vision lists are atomic-swapped the same way: filter out + # the previous dynamic entries from the live config list and append + # the freshly generated ones. No-ops when the opt-in flag is off. + if self._settings.get("image_generation_enabled"): + new_image = _generate_image_gen_configs(raw_models, self._settings) + static_image = [ + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image + self._image_configs = new_image + + if self._settings.get("vision_enabled"): + new_vision = _generate_vision_llm_configs(raw_models, self._settings) + static_vision = [ + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision + self._vision_configs = new_vision + # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -407,6 +680,21 @@ class OpenRouterIntegrationService: # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) + # Re-register LiteLLM pricing for the freshly fetched catalogue + # so newly added OR models bill correctly on their first call. + # Runs before the router rebuild because the router may issue + # cost-table lookups during deployment registration. + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as exc: + logger.warning( + "OpenRouter refresh: pricing re-registration skipped (%s)", exc + ) + # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay # out; a refresh also needs to pick up any static-config edits and @@ -635,3 +923,34 @@ class OpenRouterIntegrationService: def get_config_by_id(self, config_id: int) -> dict | None: return self._configs_by_id.get(config_id) + + def get_image_generation_configs(self) -> list[dict]: + """Return the dynamic OpenRouter image-generation configs (empty + list when the ``image_generation_enabled`` flag is off). + + Each entry already has ``billing_tier`` derived per-model from + OpenRouter's signals and is shaped to drop directly into + ``Config.GLOBAL_IMAGE_GEN_CONFIGS``. 
+ """ + return list(self._image_configs) + + def get_vision_llm_configs(self) -> list[dict]: + """Return the dynamic OpenRouter vision-LLM configs (empty list + when the ``vision_enabled`` flag is off). + + Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` + so ``pricing_registration`` can teach LiteLLM the cost of these + models the same way it does for chat — which keeps the billable + wrapper able to debit accurate micro-USD on a vision call. + """ + return list(self._vision_configs) + + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + """Return the cached raw OpenRouter pricing map. + + Shape: ``{model_id: {"prompt": str, "completion": str}}``. The + values are the strings OpenRouter publishes (USD per token), + never converted to floats here so the caller can decide how to + handle malformed or unset entries. + """ + return dict(self._raw_pricing) diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py new file mode 100644 index 000000000..de98e50c2 --- /dev/null +++ b/surfsense_backend/app/services/pricing_registration.py @@ -0,0 +1,274 @@ +""" +Pricing registration with LiteLLM. + +Many models reach our LiteLLM callback without LiteLLM knowing their +per-token cost — namely: + +* The ~300 dynamic OpenRouter deployments (their pricing only lives on + OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published + pricing table). +* Static YAML deployments whose ``base_model`` name is operator-defined + (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore + not in LiteLLM's table either. + +Without registration, ``kwargs["response_cost"]`` is 0 for those calls +and the user gets billed nothing — a fail-safe but wrong answer for a +cost-based credit system. This module runs once at startup, after the +OpenRouter integration has fetched its catalogue, and registers each +known model's pricing with ``litellm.register_model()`` under multiple +plausible alias keys (LiteLLM's cost lookup may use any of them +depending on whether the call went through the Router, ChatLiteLLM, +or a direct ``acompletion``). + +Operators who run a custom Azure deployment whose ``base_model`` name +isn't in LiteLLM's table can declare per-token pricing inline in +``global_llm_config.yaml`` via ``input_cost_per_token`` and +``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without +that declaration the model's calls debit 0 — never overbilled. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import litellm + +logger = logging.getLogger(__name__) + + +def _safe_float(value: Any) -> float: + """Return ``float(value)`` if it parses to a positive number, else 0.0.""" + if value is None: + return 0.0 + try: + f = float(value) + except (TypeError, ValueError): + return 0.0 + return f if f > 0 else 0.0 + + +def _alias_set_for_openrouter(model_id: str) -> list[str]: + """Return the alias keys to register an OpenRouter model under. + + LiteLLM's cost-callback lookup key varies by call path: + - Router with ``model="openrouter/X"`` → kwargs["model"] is + typically ``openrouter/X``. + - LiteLLM's own provider routing may strip the prefix and pass the + bare ``X`` to the cost-table lookup. + Registering under both keeps the lookup hermetic regardless of + which path the call took. 
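Concretely, the dual-alias registration for one model reduces to a payload like this; the model id and prices are examples, while the key names and the ``litellm.register_model`` call mirror ``_register`` defined below:

```python
import litellm

pricing = {"prompt": "0.000003", "completion": "0.000015"}  # raw OR strings
entry = {
    "input_cost_per_token": float(pricing["prompt"]),
    "output_cost_per_token": float(pricing["completion"]),
    "litellm_provider": "openrouter",
    "mode": "chat",
}
litellm.register_model({
    "openrouter/anthropic/claude-3-haiku": entry,  # Router-path lookup key
    "anthropic/claude-3-haiku": entry,             # prefix-stripped key
})
```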
+    """
+    aliases = [f"openrouter/{model_id}", model_id]
+    return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
+    """Return the alias keys to register a static YAML deployment under.
+
+    Same reasoning as the OpenRouter set: cover the bare ``base_model``,
+    the ``{provider}/{base_model}`` form LiteLLM Router constructs, and the
+    bare ``model_name`` because callbacks sometimes see whichever was
+    configured first.
+    """
+    provider_lower = (provider or "").lower()
+    aliases: list[str] = []
+    if base_model:
+        aliases.append(base_model)
+    if provider_lower and base_model:
+        aliases.append(f"{provider_lower}/{base_model}")
+    if model_name and model_name != base_model:
+        aliases.append(model_name)
+    if provider_lower and model_name and model_name != base_model:
+        aliases.append(f"{provider_lower}/{model_name}")
+    # Azure deployments often surface as "azure/{deployment}"; normalise the
+    # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
+    if provider_lower == "azure_openai":
+        if base_model:
+            aliases.append(f"azure/{base_model}")
+        if model_name and model_name != base_model:
+            aliases.append(f"azure/{model_name}")
+    return list(dict.fromkeys(a for a in aliases if a))
+
+
+def _register(
+    aliases: list[str],
+    *,
+    input_cost: float,
+    output_cost: float,
+    provider: str,
+    mode: str = "chat",
+) -> int:
+    """Register a single pricing entry under every alias in ``aliases``.
+
+    Returns the count of aliases successfully registered.
+    """
+    payload: dict[str, dict[str, Any]] = {}
+    for alias in aliases:
+        payload[alias] = {
+            "input_cost_per_token": input_cost,
+            "output_cost_per_token": output_cost,
+            "litellm_provider": provider,
+            "mode": mode,
+        }
+    if not payload:
+        return 0
+    try:
+        litellm.register_model(payload)
+    except Exception as exc:
+        logger.warning(
+            "[PricingRegistration] register_model failed for aliases=%s: %s",
+            aliases,
+            exc,
+        )
+        return 0
+    return len(payload)
+
+
+def _register_chat_shape_configs(
+    configs: list[dict],
+    *,
+    or_pricing: dict[str, dict[str, str]],
+    label: str,
+) -> tuple[int, int, int, list[str]]:
+    """Common loop that registers per-token pricing for a list of "chat-shape"
+    configs (chat or vision LLM — both use ``input_cost_per_token`` /
+    ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
+
+    Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
+    """
+    registered_models = 0
+    registered_aliases = 0
+    skipped_no_pricing = 0
+    sample_keys: list[str] = []
+
+    for cfg in configs:
+        provider = str(cfg.get("provider") or "").upper()
+        model_name = str(cfg.get("model_name") or "").strip()
+        litellm_params = cfg.get("litellm_params") or {}
+        base_model = str(litellm_params.get("base_model") or model_name).strip()
+
+        if provider == "OPENROUTER":
+            entry = or_pricing.get(model_name)
+            if entry:
+                input_cost = _safe_float(entry.get("prompt"))
+                output_cost = _safe_float(entry.get("completion"))
+            else:
+                # Vision configs from ``_generate_vision_llm_configs``
+                # carry their pricing inline because the OpenRouter
+                # raw-pricing cache is keyed by chat-catalogue model_id;
+                # vision flows pick up the inline values here.
+ input_cost = _safe_float(cfg.get("input_cost_per_token")) + output_cost = _safe_float(cfg.get("output_cost_per_token")) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_openrouter(model_name) + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider="openrouter", + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + continue + + input_cost = _safe_float( + cfg.get("input_cost_per_token") + or litellm_params.get("input_cost_per_token") + ) + output_cost = _safe_float( + cfg.get("output_cost_per_token") + or litellm_params.get("output_cost_per_token") + ) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_yaml(provider, model_name, base_model) + provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider=provider_slug, + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + + logger.info( + "[PricingRegistration:%s] registered pricing for %d models (%d aliases); " + "%d configs had no pricing data; sample registered keys=%s", + label, + registered_models, + registered_aliases, + skipped_no_pricing, + sample_keys, + ) + return registered_models, registered_aliases, skipped_no_pricing, sample_keys + + +def register_pricing_from_global_configs() -> None: + """Register pricing for every known LLM deployment with LiteLLM. + + Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` + so vision calls (during indexing) can resolve cost the same way chat + calls do — namely: + + 1. ``OPENROUTER``: pulls the cached raw pricing from + ``OpenRouterIntegrationService`` (populated during its own + startup fetch) and converts the per-token strings to floats. For + vision configs that carry pricing inline (``input_cost_per_token`` / + ``output_cost_per_token`` set on the cfg itself) we fall back to + those values when the OR cache misses the model. + 2. Anything else: looks for operator-declared + ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML + config block (top-level or nested under ``litellm_params``). + + **Image generation is intentionally NOT registered here.** The cost + shape for image-gen is per-image (``output_cost_per_image``), not + per-token, and LiteLLM's ``register_model`` doesn't accept those + keys via the chat-cost path. OpenRouter image-gen models populate + ``response_cost`` directly from their response header instead, and + Azure-native image-gen models are already in LiteLLM's cost map. + + Calls without a resolved pair of costs are skipped, not registered + with zeros — operators who forget pricing get a "$0 debit" warning + in ``TokenTrackingCallback`` rather than silently overwriting any + pricing LiteLLM might know natively. 
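As a worked example of the operator-declared path, a config dict in the shape this walker consumes, mirroring a hypothetical ``global_llm_config.yaml`` entry, plus the debit it yields:

```python
cfg = {
    "provider": "AZURE_OPENAI",
    "model_name": "my-gpt5-deployment",            # hypothetical deployment
    "litellm_params": {"base_model": "gpt-5.4"},   # not in LiteLLM's table
    "input_cost_per_token": 0.000002,              # $2 per 1M input tokens
    "output_cost_per_token": 0.000008,             # $8 per 1M output tokens
}

# A 1,000-in / 500-out call then costs:
usd = 1_000 * cfg["input_cost_per_token"] + 500 * cfg["output_cost_per_token"]
micros = round(usd * 1_000_000)
print(usd, micros)  # 0.006 -> 6_000 micro-USD debited
```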
+    """
+    from app.config import config as app_config
+
+    chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
+    vision_configs: list[dict] = list(
+        getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
+    )
+    if not chat_configs and not vision_configs:
+        logger.info("[PricingRegistration] no global configs to register")
+        return
+
+    or_pricing: dict[str, dict[str, str]] = {}
+    try:
+        from app.services.openrouter_integration_service import (
+            OpenRouterIntegrationService,
+        )
+
+        if OpenRouterIntegrationService.is_initialized():
+            or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
+    except Exception as exc:
+        logger.debug(
+            "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
+        )
+
+    if chat_configs:
+        _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
+    if vision_configs:
+        _register_chat_shape_configs(
+            vision_configs, or_pricing=or_pricing, label="vision"
+        )
diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py
new file mode 100644
index 000000000..979d7d3a1
--- /dev/null
+++ b/surfsense_backend/app/services/provider_api_base.py
@@ -0,0 +1,107 @@
+"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
+
+LiteLLM falls back to the module-global ``litellm.api_base`` when an
+individual call doesn't pass one, which silently inherits provider-agnostic
+env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
+explicit ``api_base``, an ``openrouter/{model}`` request can end up at an
+Azure endpoint and 404 with ``Resource not found`` (real reproducer:
+[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
+``/chat/completions`` to whatever inherited base it gets, regardless of
+provider).
+
+The chat router has had this defense for a while
+(``llm_router_service.py:466-478``). This module hoists the maps + cascade
+into a tiny standalone helper so vision and image-gen can share the same
+source of truth without an inter-service circular import.
+"""
+
+from __future__ import annotations
+
+
+PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
+    "openrouter": "https://openrouter.ai/api/v1",
+    "groq": "https://api.groq.com/openai/v1",
+    "mistral": "https://api.mistral.ai/v1",
+    "perplexity": "https://api.perplexity.ai",
+    "xai": "https://api.x.ai/v1",
+    "cerebras": "https://api.cerebras.ai/v1",
+    "deepinfra": "https://api.deepinfra.com/v1/openai",
+    "fireworks_ai": "https://api.fireworks.ai/inference/v1",
+    "together_ai": "https://api.together.xyz/v1",
+    "anyscale": "https://api.endpoints.anyscale.com/v1",
+    "cometapi": "https://api.cometapi.com/v1",
+    "sambanova": "https://api.sambanova.ai/v1",
+}
+"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
+
+Only providers with a well-known, stable public base URL are listed —
+self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
+huggingface, databricks, cloudflare, replicate) are intentionally omitted
+so their existing config-driven behaviour is preserved."""
+
+
+PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
+    "DEEPSEEK": "https://api.deepseek.com/v1",
+    "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+    "MOONSHOT": "https://api.moonshot.ai/v1",
+    "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
+    "MINIMAX": "https://api.minimax.io/v1",
+}
+"""Canonical provider key (uppercase) → base URL. 
+ +Used when the LiteLLM provider prefix is the generic ``openai`` shim but the +config's ``provider`` field tells us which API it actually is (DeepSeek, +Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each +has its own base URL).""" + + +def resolve_api_base( + *, + provider: str | None, + provider_prefix: str | None, + config_api_base: str | None, +) -> str | None: + """Resolve a non-Azure-leaking ``api_base`` for a deployment. + + Cascade (first non-empty wins): + 1. The config's own ``api_base`` (whitespace-only treated as missing). + 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. + 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. + 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM + provider integration apply its own default (e.g. AzureOpenAI's + deployment-derived URL, custom provider's per-deployment URL). + + Args: + provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, + ``"DEEPSEEK"``). Case-insensitive. + provider_prefix: The LiteLLM model-string prefix the same call + site builds for the model id (e.g. ``"openrouter"``, + ``"groq"``). Case-insensitive. + config_api_base: ``api_base`` from the global YAML / DB row / + OpenRouter dynamic config. Empty / whitespace-only means + "missing" — the resolver still applies the cascade. + + Returns: + A URL string, or ``None`` if no default applies for this provider. + """ + if config_api_base and config_api_base.strip(): + return config_api_base + + if provider: + key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) + if key_default: + return key_default + + if provider_prefix: + prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) + if prefix_default: + return prefix_default + + return None + + +__all__ = [ + "PROVIDER_DEFAULT_API_BASE", + "PROVIDER_KEY_DEFAULT_API_BASE", + "resolve_api_base", +] diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py new file mode 100644 index 000000000..0040e5a5b --- /dev/null +++ b/surfsense_backend/app/services/quota_checked_vision_llm.py @@ -0,0 +1,105 @@ +""" +Vision LLM proxy that enforces premium credit quota on every ``ainvoke``. + +Used by :func:`app.services.llm_service.get_vision_llm` so callers in the +indexing pipeline (file processors, connector indexers, etl pipeline) can +keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)`` +— without threading ``user_id`` through every parser. The wrapper looks like +a chat model from the outside; on the inside it routes each call through +``billable_call`` so the user's premium credit pool is reserved → finalized +or released, and a ``TokenUsage`` audit row is written. + +Free configs are returned unwrapped from ``get_vision_llm`` (they do not +need quota enforcement) so this class only ever wraps premium configs. + +Why a wrapper instead of plumbing ``user_id`` through every caller: + +* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive, + Dropbox, local-folder, file-processor, ETL pipeline) each calling + ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is + invasive, error-prone, and easy for a future indexer to forget. +* Per the design (issue M), we always debit the *search-space owner*, not + the triggering user, so ``user_id`` is fully derivable from the search + space the caller is already operating on. The wrapper captures it once + at construction time. 
+* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each + call run this coroutine"; subclassing isn't safe across versions because + it derives from ``BaseChatModel`` which expects specific Pydantic shapes. + Composition via attribute proxying (``__getattr__``) is robust to + upstream changes — every method other than ``ainvoke`` falls through to + the inner LLM unchanged. +""" + +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from app.services.billable_calls import QuotaInsufficientError, billable_call + +logger = logging.getLogger(__name__) + + +class QuotaCheckedVisionLLM: + """Composition wrapper around a langchain chat model that enforces + premium credit quota on every ``ainvoke``. + + Anything other than ``ainvoke`` is forwarded to the inner model so + ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all + still work — they simply bypass quota enforcement, which is fine + because the indexing pipeline only ever calls ``ainvoke`` today. + """ + + def __init__( + self, + inner_llm: Any, + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None, + usage_type: str = "vision_extraction", + ) -> None: + self._inner = inner_llm + self._user_id = user_id + self._search_space_id = search_space_id + self._billing_tier = billing_tier + self._base_model = base_model + self._quota_reserve_tokens = quota_reserve_tokens + self._usage_type = usage_type + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + """Proxied async invoke that runs the underlying call inside + ``billable_call``. + + Raises: + QuotaInsufficientError: when the user has exhausted their + premium credit pool. Caller (``etl_pipeline_service._extract_image``) + catches this and falls back to the document parser. + """ + async with billable_call( + user_id=self._user_id, + search_space_id=self._search_space_id, + billing_tier=self._billing_tier, + base_model=self._base_model, + quota_reserve_tokens=self._quota_reserve_tokens, + usage_type=self._usage_type, + call_details={"model": self._base_model}, + ): + return await self._inner.ainvoke(input, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Forward everything else (``invoke``, ``astream``, ``bind``, + ``with_structured_output``, …) to the inner model. + + ``__getattr__`` is only consulted when the attribute is *not* + already found on the proxy, which is exactly the contract we + want — methods we override stay on the proxy, the rest fall + through. + """ + return getattr(self._inner, name) + + +__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"] diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py index a3ec7aed0..310c3eb5e 100644 --- a/surfsense_backend/app/services/token_quota_service.py +++ b/surfsense_backend/app/services/token_quota_service.py @@ -22,6 +22,71 @@ from app.config import config logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Per-call reservation estimator (USD micro-units) +# --------------------------------------------------------------------------- + +# Minimum reserve in micros so a user with $0.0001 left can still make a tiny +# request, and so models without registered pricing reserve at least +# something while the call runs (debited 0 at finalize anyway when their +# cost can't be resolved). 
+_QUOTA_MIN_RESERVE_MICROS = 100 + + +def estimate_call_reserve_micros( + *, + base_model: str, + quota_reserve_tokens: int | None, +) -> int: + """Return the number of micro-USD to reserve for one premium call. + + Computes a worst-case upper bound from LiteLLM's per-token pricing + table: + + reserve_usd ≈ reserve_tokens x (input_cost + output_cost) + + so the math scales with model cost — Claude Opus + 4K reserve_tokens + naturally reserves ≈ $0.36, while a cheap model reserves only a few + cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]`` + so a misconfigured "$1000/M" model can't lock the whole balance on + one call. + + If ``litellm.get_model_info`` raises (model unknown) we fall back to + the floor — 100 micros / $0.0001 — which is enough to gate a sane + request without over-reserving for a model whose pricing the + operator hasn't declared yet. + """ + reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL + if reserve_tokens <= 0: + reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL + + try: + from litellm import get_model_info + + info = get_model_info(base_model) if base_model else {} + input_cost = float(info.get("input_cost_per_token") or 0.0) + output_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception as exc: + logger.debug( + "[quota_reserve] cost lookup failed for base_model=%s: %s", + base_model, + exc, + ) + input_cost = 0.0 + output_cost = 0.0 + + if input_cost == 0.0 and output_cost == 0.0: + return _QUOTA_MIN_RESERVE_MICROS + + reserve_usd = reserve_tokens * (input_cost + output_cost) + reserve_micros = round(reserve_usd * 1_000_000) + if reserve_micros < _QUOTA_MIN_RESERVE_MICROS: + reserve_micros = _QUOTA_MIN_RESERVE_MICROS + if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS: + reserve_micros = config.QUOTA_MAX_RESERVE_MICROS + return reserve_micros + + class QuotaScope(StrEnum): ANONYMOUS = "anonymous" PREMIUM = "premium" @@ -444,8 +509,16 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - reserve_tokens: int, + reserve_micros: int, ) -> QuotaResult: + """Reserve ``reserve_micros`` (USD micro-units) from the user's + premium credit balance. + + ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are + all in micro-USD on this code path; callers (chat stream, + token-status route, FE display) convert to dollars by dividing + by 1_000_000. 
+ """ from app.db import User user = ( @@ -465,11 +538,11 @@ class TokenQuotaService: limit=0, ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved - effective = used + reserved + reserve_tokens + effective = used + reserved + reserve_micros if effective > limit: remaining = max(0, limit - used - reserved) await db_session.rollback() @@ -482,10 +555,10 @@ class TokenQuotaService: remaining=remaining, ) - user.premium_tokens_reserved = reserved + reserve_tokens + user.premium_credit_micros_reserved = reserved + reserve_micros await db_session.commit() - new_reserved = reserved + reserve_tokens + new_reserved = reserved + reserve_micros remaining = max(0, limit - used - new_reserved) warning_threshold = int(limit * 0.8) @@ -510,9 +583,12 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - actual_tokens: int, - reserved_tokens: int, + actual_micros: int, + reserved_micros: int, ) -> QuotaResult: + """Settle the reservation: release ``reserved_micros`` and debit + ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD). + """ from app.db import User user = ( @@ -529,16 +605,18 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros + ) + user.premium_credit_micros_used = ( + user.premium_credit_micros_used + actual_micros ) - user.premium_tokens_used = user.premium_tokens_used + actual_tokens await db_session.commit() - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) @@ -562,8 +640,13 @@ class TokenQuotaService: async def premium_release( db_session: AsyncSession, user_id: Any, - reserved_tokens: int, + reserved_micros: int, ) -> None: + """Release ``reserved_micros`` previously held by ``premium_reserve``. + + Used when a request fails before finalize (so the reservation + doesn't leak credit). 
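Settlement in miniature: finalize releases the hold and debits the actual cost, per the field updates above (numbers are illustrative, all values micro-USD):

```python
reserved, actual = 360_000, 42_500            # hold vs. LiteLLM-reported cost
used_before, held_before = 1_000_000, 360_000

held_after = max(0, held_before - reserved)   # 0: reservation released
used_after = used_before + actual             # 1_042_500: actuals debited
print(held_after, used_after)
```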
+ """ from app.db import User user = ( @@ -576,8 +659,8 @@ class TokenQuotaService: .scalar_one_or_none() ) if user is not None: - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros ) await db_session.commit() @@ -598,9 +681,9 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 9aa8c6e70..9406d9be4 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -16,11 +16,14 @@ from __future__ import annotations import dataclasses import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from contextvars import ContextVar from dataclasses import dataclass, field from typing import Any from uuid import UUID +import litellm from litellm.integrations.custom_logger import CustomLogger from sqlalchemy.ext.asyncio import AsyncSession @@ -35,6 +38,8 @@ class TokenCallRecord: prompt_tokens: int completion_tokens: int total_tokens: int + cost_micros: int = 0 + call_kind: str = "chat" @dataclass @@ -49,6 +54,8 @@ class TurnTokenAccumulator: prompt_tokens: int, completion_tokens: int, total_tokens: int, + cost_micros: int = 0, + call_kind: str = "chat", ) -> None: self.calls.append( TokenCallRecord( @@ -56,20 +63,28 @@ class TurnTokenAccumulator: prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) ) def per_message_summary(self) -> dict[str, dict[str, int]]: - """Return token counts grouped by model name.""" + """Return token counts (and cost) grouped by model name.""" by_model: dict[str, dict[str, int]] = {} for c in self.calls: entry = by_model.setdefault( c.model, - {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "cost_micros": 0, + }, ) entry["prompt_tokens"] += c.prompt_tokens entry["completion_tokens"] += c.completion_tokens entry["total_tokens"] += c.total_tokens + entry["cost_micros"] += c.cost_micros return by_model @property @@ -84,6 +99,21 @@ class TurnTokenAccumulator: def total_completion_tokens(self) -> int: return sum(c.completion_tokens for c in self.calls) + @property + def total_cost_micros(self) -> int: + """Sum of per-call ``cost_micros`` across the entire turn. + + Used by ``stream_new_chat`` to debit a premium turn's actual + provider cost (in micro-USD) from the user's premium credit + balance. ``cost_micros`` per call is captured by + ``TokenTrackingCallback.async_log_success_event`` from + ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost), + with multiple fallback paths so OpenRouter dynamic models and + custom Azure deployments still bill correctly when our + ``pricing_registration`` ran at startup. 
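Rounding out the quota side, the balance arithmetic that ``premium_reserve`` / ``premium_finalize`` and the status query all share, as a standalone helper with example numbers (all micro-USD; the 0.8 factor mirrors the ``warning_threshold`` computed above):

```python
def remaining_micros(limit: int, used: int, reserved: int) -> int:
    return max(0, limit - used - reserved)


limit = 5_000_000                  # $5.00 default quota
used, reserved = 3_900_000, 200_000
print(remaining_micros(limit, used, reserved))  # 900_000: $0.90 left
print(used + reserved >= int(limit * 0.8))      # True: past the 80% mark
```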
+ """ + return sum(c.cost_micros for c in self.calls) + def serialized_calls(self) -> list[dict[str, Any]]: return [dataclasses.asdict(c) for c in self.calls] @@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar( def start_turn() -> TurnTokenAccumulator: - """Create a fresh accumulator for the current async context and return it.""" + """Create a fresh accumulator for the current async context and return it. + + NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For + short-lived per-call billable wrappers (image generation REST endpoint, + vision LLM during indexing) prefer :func:`scoped_turn`, which uses a + ContextVar reset token to restore the *previous* accumulator on exit and + avoids leaking call records across reservations (issue B). + """ acc = TurnTokenAccumulator() _turn_accumulator.set(acc) logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc)) @@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() +@asynccontextmanager +async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: + """Async context manager that scopes a fresh ``TurnTokenAccumulator`` + for the duration of the ``async with`` block, then *resets* the + ContextVar to its previous value on exit. + + This is the safe primitive for per-call billable operations + (image generation, vision LLM extraction, podcasts) that may run + inside an outer chat turn or be called sequentially from the same + background worker. Using ``ContextVar.set`` without ``reset`` (as + :func:`start_turn` does) would leak the inner accumulator into the + outer scope, causing the outer chat turn to debit cost twice. + + Usage:: + + async with scoped_turn() as acc: + await llm.ainvoke(...) + # acc.total_cost_micros captures cost from the LiteLLM callback + # Outer accumulator (if any) is restored here. + """ + acc = TurnTokenAccumulator() + token = _turn_accumulator.set(acc) + logger.debug( + "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)", + id(acc), + token, + ) + try: + yield acc + finally: + _turn_accumulator.reset(token) + logger.debug( + "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)", + id(acc), + len(acc.calls), + acc.total_cost_micros, + ) + + +def _extract_cost_usd( + kwargs: dict[str, Any], + response_obj: Any, + model: str, + prompt_tokens: int, + completion_tokens: int, + is_image: bool = False, +) -> float: + """Best-effort USD cost extraction for a single LLM/image call. + + Tries four sources in priority order and returns the first that + yields a positive number; returns 0.0 if all four fail (the call + will then debit nothing from the user's balance — fail-safe). + + Sources: + 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback + field, populated for ``Router.acompletion`` since PR #12500. + 2. ``response_obj._hidden_params["response_cost"]`` — same value + exposed on the response itself. + 3. ``litellm.completion_cost(completion_response=response_obj)`` + — recompute from the response and LiteLLM's pricing table. + 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)`` + — manual fallback for OpenRouter/custom-Azure models that + only resolve via aliases registered by + ``pricing_registration`` at startup. 
**Skipped for image + responses** — ``cost_per_token`` does not support ``ImageResponse`` + and would raise; the cost map for image-gen lives in different + keys (``output_cost_per_image``) handled by ``completion_cost``. + """ + cost = kwargs.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + hidden = getattr(response_obj, "_hidden_params", None) or {} + if isinstance(hidden, dict): + cost = hidden.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + try: + value = float(litellm.completion_cost(completion_response=response_obj)) + if value > 0: + return value + except Exception as exc: + if is_image: + # Image-gen path: OpenRouter's image responses can omit + # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator`` + # then *raises* (no cost map for OpenRouter image models). + # Bail out with a warning rather than falling through to + # cost_per_token (which is also incompatible with ImageResponse). + logger.warning( + "[TokenTracking] completion_cost failed for image model=%s " + "(provider may have omitted usage.cost). Debiting 0. " + "Cause: %s", + model, + exc, + ) + return 0.0 + logger.debug( + "[TokenTracking] completion_cost failed for model=%s: %s", model, exc + ) + + if is_image: + # Never call cost_per_token for ImageResponse — keys mismatch and + # the function is documented chat-only. + return 0.0 + + if model and (prompt_tokens > 0 or completion_tokens > 0): + try: + prompt_cost, completion_cost = litellm.cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + value = float(prompt_cost) + float(completion_cost) + if value > 0: + return value + except Exception as exc: + logger.debug( + "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc + ) + + return 0.0 + + class TokenTrackingCallback(CustomLogger): """LiteLLM callback that captures token usage into the turn accumulator.""" @@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger): ) return + # Detect image generation responses — they have a different usage + # shape (ImageUsage with input_tokens/output_tokens) and require a + # different cost-extraction path. We probe by class name to avoid a + # hard import dependency on litellm internals. + response_cls = type(response_obj).__name__ + is_image = response_cls == "ImageResponse" + usage = getattr(response_obj, "usage", None) if not usage: logger.debug( @@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger): ) return - prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 - completion_tokens = getattr(usage, "completion_tokens", 0) or 0 - total_tokens = getattr(usage, "total_tokens", 0) or 0 + if is_image: + # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens`` + # (not prompt_tokens/completion_tokens). Several providers + # populate only one or neither (e.g. OpenRouter's gpt-image-1 + # passes through `input_tokens` from the prompt but no + # completion); fall through gracefully to 0. 
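+            # Illustrative (hypothetical) shape:
+            #   ImageUsage(input_tokens=37, output_tokens=0, total_tokens=None)
+            # resolves below to prompt=37, completion=0, total=37.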
+ prompt_tokens = getattr(usage, "input_tokens", 0) or 0 + completion_tokens = getattr(usage, "output_tokens", 0) or 0 + total_tokens = ( + getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens + ) + call_kind = "image_generation" + else: + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + total_tokens = getattr(usage, "total_tokens", 0) or 0 + call_kind = "chat" model = kwargs.get("model", "unknown") + cost_usd = _extract_cost_usd( + kwargs=kwargs, + response_obj=response_obj, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_image=is_image, + ) + cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0 + + if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0): + logger.warning( + "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d " + "kind=%s — debiting 0. Register pricing via pricing_registration or YAML " + "input_cost_per_token/output_cost_per_token (or rely on response_cost " + "for image generation).", + model, + prompt_tokens, + completion_tokens, + call_kind, + ) + acc.add( model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) logger.info( - "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)", + "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d " + "cost=$%.6f (%d micros) (accumulator now has %d calls)", model, + call_kind, prompt_tokens, completion_tokens, total_tokens, + cost_usd, + cost_micros, len(acc.calls), ) @@ -168,6 +388,7 @@ async def record_token_usage( prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0, + cost_micros: int = 0, model_breakdown: dict[str, Any] | None = None, call_details: dict[str, Any] | None = None, thread_id: int | None = None, @@ -185,6 +406,7 @@ async def record_token_usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, model_breakdown=model_breakdown, call_details=call_details, thread_id=thread_id, @@ -194,11 +416,12 @@ async def record_token_usage( ) session.add(record) logger.debug( - "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d", + "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d", usage_type, prompt_tokens, completion_tokens, total_tokens, + cost_micros, ) return record except Exception: diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0d782ab2b..ed5de921c 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,6 +3,8 @@ from typing import Any from litellm import Router +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 @@ -108,10 +110,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" else: - provider = config.get("provider", "").upper() provider_prefix = 
VISION_PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{config['model_name']}" @@ -120,8 +123,13 @@ class VisionLLMRouterService: "api_key": config.get("api_key"), } - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base if config.get("api_version"): litellm_params["api_version"] = config["api_version"] diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 953011ecf..937877473 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -9,7 +9,13 @@ from sqlalchemy import select from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState from app.celery_app import celery_app +from app.config import config as app_config from app.db import Podcast, PodcastStatus +from app.services.billable_calls import ( + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) @@ -96,6 +102,31 @@ async def _generate_content_podcast( podcast.status = PodcastStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=podcast.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "Podcast %s: cannot resolve billing for search_space=%s: %s", + podcast.id, + search_space_id, + resolve_err, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "podcast_title": podcast.title, @@ -109,9 +140,39 @@ async def _generate_content_podcast( db_session=session, ) - graph_result = await podcaster_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, + usage_type="podcast_generation", + thread_id=podcast.thread_id, + call_details={ + "podcast_id": podcast.id, + "title": podcast.title, + }, + ): + graph_result = await podcaster_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "Podcast %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + podcast.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "premium_quota_exhausted", + } podcast_transcript = graph_result.get("podcast_transcript", []) file_path = graph_result.get("final_podcast_file_path", "") diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index 7880b385f..4f0c427d9 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py 
@@ -9,7 +9,13 @@ from sqlalchemy import select from app.agents.video_presentation.graph import graph as video_presentation_graph from app.agents.video_presentation.state import State as VideoPresentationState from app.celery_app import celery_app +from app.config import config as app_config from app.db import VideoPresentation, VideoPresentationStatus +from app.services.billable_calls import ( + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) @@ -97,6 +103,32 @@ async def _generate_video_presentation( video_pres.status = VideoPresentationStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=video_pres.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "VideoPresentation %s: cannot resolve billing for " + "search_space=%s: %s", + video_pres.id, + search_space_id, + resolve_err, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "video_title": video_pres.title, @@ -110,9 +142,39 @@ async def _generate_video_presentation( db_session=session, ) - graph_result = await video_presentation_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, + usage_type="video_presentation_generation", + thread_id=video_pres.thread_id, + call_details={ + "video_presentation_id": video_pres.id, + "title": video_pres.title, + }, + ): + graph_result = await video_presentation_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "VideoPresentation %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + video_pres.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "premium_quota_exhausted", + } # Serialize slides (parsed content + audio info merged) slides_raw = graph_result.get("slides", []) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index dbfe9a67b..31c0d7d6d 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2236,8 +2236,10 @@ async def stream_new_chat( accumulator = start_turn() - # Premium quota tracking state - _premium_reserved = 0 + # Premium credit (USD micro-units) tracking state. Stores the + # amount reserved up front so we can release it on cancellation + # and finalize-debit the actual provider cost reported by LiteLLM. 
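+    # Illustrative lifecycle (hypothetical amounts): reserve 1_000_000
+    # micros up front; if the turn's actual cost is 42_500 micros,
+    # finalize debits 42_500 and the remaining 957_500 return to the
+    # user's balance.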
+ _premium_reserved_micros = 0 _premium_request_id: str | None = None _emit_stream_error = partial( @@ -2331,23 +2333,28 @@ async def stream_new_chat( if _needs_premium_quota: import uuid as _uuid - from app.config import config as _app_config - from app.services.token_quota_service import TokenQuotaService + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) _premium_request_id = _uuid.uuid4().hex[:16] - reserve_amount = min( - agent_config.quota_reserve_tokens - or _app_config.QUOTA_MAX_RESERVE_PER_CALL, - _app_config.QUOTA_MAX_RESERVE_PER_CALL, + _agent_litellm_params = agent_config.litellm_params or {} + _agent_base_model = ( + _agent_litellm_params.get("base_model") or agent_config.model_name or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=_agent_base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - reserve_tokens=reserve_amount, + reserve_micros=reserve_amount_micros, ) - _premium_reserved = reserve_amount + _premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: if requested_llm_config_id == 0: try: @@ -2382,7 +2389,7 @@ async def stream_new_chat( yield streaming_service.format_done() return _premium_request_id = None - _premium_reserved = 0 + _premium_reserved_micros = 0 _log_chat_stream_error( flow=flow, error_kind="premium_quota_exhausted", @@ -3020,9 +3027,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3033,6 +3041,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3060,7 +3069,11 @@ async def stream_new_chat( chat_id, generated_title ) - # Finalize premium quota with actual tokens. + # Finalize premium credit debit with the actual provider cost + # reported by LiteLLM, summed across every call in the turn. + # Mirrors the pre-cost behaviour of "premium turn → all calls + # count" so free sub-agent calls during a premium turn still + # contribute to the bill (they're $0 in practice anyway). 
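+            # e.g. (hypothetical) per-call costs of [3_100, 0, 700] micros
+            # debit 3_800 in total; the $0 free sub-agent call adds nothing
+            # but is still counted.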
if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3070,11 +3083,11 @@ async def stream_new_chat( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=accumulator.grand_total, - reserved_tokens=_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_premium_reserved_micros, ) _premium_request_id = None - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -3084,9 +3097,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal new_chat: calls=%d total=%d summary=%s", + "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3097,6 +3111,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3190,7 +3205,7 @@ async def stream_new_chat( end_turn(str(chat_id)) # Release premium reservation if not finalized - if _premium_request_id and _premium_reserved > 0 and user_id: + if _premium_request_id and _premium_reserved_micros > 0 and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3198,9 +3213,9 @@ async def stream_new_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_premium_reserved, + reserved_micros=_premium_reserved_micros, ) - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s", user_id @@ -3369,8 +3384,8 @@ async def stream_resume_chat( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) - # Premium quota reservation (same logic as stream_new_chat) - _resume_premium_reserved = 0 + # Premium credit reservation (same logic as stream_new_chat). 
+ _resume_premium_reserved_micros = 0 _resume_premium_request_id: str | None = None _resume_needs_premium = ( agent_config is not None and user_id and agent_config.is_premium @@ -3378,23 +3393,30 @@ async def stream_resume_chat( if _resume_needs_premium: import uuid as _uuid - from app.config import config as _app_config - from app.services.token_quota_service import TokenQuotaService + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) _resume_premium_request_id = _uuid.uuid4().hex[:16] - reserve_amount = min( - agent_config.quota_reserve_tokens - or _app_config.QUOTA_MAX_RESERVE_PER_CALL, - _app_config.QUOTA_MAX_RESERVE_PER_CALL, + _resume_litellm_params = agent_config.litellm_params or {} + _resume_base_model = ( + _resume_litellm_params.get("base_model") + or agent_config.model_name + or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=_resume_base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - reserve_tokens=reserve_amount, + reserve_micros=reserve_amount_micros, ) - _resume_premium_reserved = reserve_amount + _resume_premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: if requested_llm_config_id == 0: try: @@ -3429,7 +3451,7 @@ async def stream_resume_chat( yield streaming_service.format_done() return _resume_premium_request_id = None - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 _log_chat_stream_error( flow="resume", error_kind="premium_quota_exhausted", @@ -3746,9 +3768,10 @@ async def stream_resume_chat( if stream_result.is_interrupted: usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3759,6 +3782,7 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3768,7 +3792,9 @@ async def stream_resume_chat( yield streaming_service.format_done() return - # Finalize premium quota for resume path + # Finalize premium credit debit for resume path with the actual + # provider cost reported by LiteLLM (sum of cost across all + # calls in the turn). 
if _resume_premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3778,11 +3804,11 @@ async def stream_resume_chat( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=accumulator.grand_total, - reserved_tokens=_resume_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_resume_premium_reserved_micros, ) _resume_premium_request_id = None - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", @@ -3792,9 +3818,10 @@ async def stream_resume_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal resume_chat: calls=%d total=%d summary=%s", + "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3805,6 +3832,7 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3855,7 +3883,11 @@ async def stream_resume_chat( end_turn(str(chat_id)) # Release premium reservation if not finalized - if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: + if ( + _resume_premium_request_id + and _resume_premium_reserved_micros > 0 + and user_id + ): try: from app.services.token_quota_service import TokenQuotaService @@ -3863,9 +3895,9 @@ async def stream_resume_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_resume_premium_reserved, + reserved_micros=_resume_premium_reserved_micros, ) - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s (resume)", user_id diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py new file mode 100644 index 000000000..636b7de31 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -0,0 +1,138 @@ +"""Unit tests for the image-generation route's billing-resolution helper. + +End-to-end "POST /image-generations returns 402" coverage requires the +integration harness (real DB, real auth) and lives in +``tests/integration/document_upload/`` alongside the other quota tests. +This unit test focuses on the new ``_resolve_billing_for_image_gen`` +helper which: + +* Returns ``free`` for Auto mode, even when premium configs exist + (Auto-mode billing-tier surfacing is a follow-up). +* Returns ``free`` for user-owned BYOK configs (positive IDs). +* Returns the global config's ``billing_tier`` for negative IDs. +* Honours the per-config ``quota_reserve_micros`` override when present. 
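+
+Illustrative return shapes, mirroring the tests below (values
+hypothetical)::
+
+    ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)   # Auto mode
+    ("premium", "openai/gpt-image-1", 75_000)        # global override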
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_resolve_billing_for_auto_mode(monkeypatch): + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, # Not consumed on this code path. + config_id=0, # IMAGE_GEN_AUTO_MODE_ID + search_space=search_space, + ) + assert tier == "free" + assert model == "auto" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_premium_global_config(monkeypatch): + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + "quota_reserve_micros": 75_000, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "billing_tier": "free", + }, + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=None) + + # Premium with override. + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-1, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" + assert reserve == 75_000 + + # Free, no override → falls back to default. + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-2, search_space=search_space + ) + assert tier == "free" + # Provider-prefixed model string for OpenRouter. + assert "google/gemini-2.5-flash-image" in model + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_user_owned_byok_is_free(): + """User-owned BYOK configs (positive IDs) cost the user nothing on + our side — they pay the provider directly. Always free. + """ + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=42, search_space=search_space + ) + assert tier == "free" + assert model == "user_byok" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): + """When the request omits ``image_generation_config_id``, the helper + must consult the search space's default — so a search space pinned + to a premium global config still gates new requests by quota. 
+ """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -7, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + } + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=-7) + ( + tier, + model, + _reserve, + ) = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=None, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py new file mode 100644 index 000000000..fa8819b39 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -0,0 +1,436 @@ +"""Unit tests for ``_resolve_agent_billing_for_search_space``. + +Validates the resolver used by Celery podcast/video tasks to compute +``(owner_user_id, billing_tier, base_model)`` from a search space and its +agent LLM config. The resolver mirrors chat's billing-resolution pattern at +``stream_new_chat.py:2294-2351`` and is the single integration point that +prevents Auto-mode podcast/video from leaking premium credit. + +Coverage: + +* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium + global → returns ``("premium", )``. +* Auto mode + ``thread_id`` set, pin resolves to a negative-id free + global → returns ``("free", )``. +* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config + → always ``"free"``. +* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without + hitting the pin service. +* Negative id (no Auto) → uses ``get_global_llm_config``'s + ``billing_tier``. +* Positive id (user BYOK) → always ``"free"``. +* Search space not found → raises ``ValueError``. +* ``agent_llm_id`` is None → raises ``ValueError``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from uuid import UUID, uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + +class _FakeSession: + """Tiny AsyncSession stub. + + ``responses`` is a list of objects to return from successive + ``execute()`` calls (in order). The resolver makes at most two + ``execute()`` calls (search-space lookup, then optionally NewLLMConfig + lookup), so two queued responses cover the matrix. 
+ """ + + def __init__(self, responses: list): + self._responses = list(responses) + + async def execute(self, _stmt): + if not self._responses: + return _FakeExecResult(None) + return _FakeExecResult(self._responses.pop(0)) + + async def commit(self) -> None: + pass + + +@dataclass +class _FakePinResolution: + resolved_llm_config_id: int + resolved_tier: str = "premium" + from_existing_pin: bool = False + + +def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace( + id=42, + agent_llm_id=agent_llm_id, + user_id=user_id, + ) + + +def _make_byok_config( + *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +) -> SimpleNamespace: + return SimpleNamespace( + id=id_, + model_name=model_name, + litellm_params={"base_model": base_model} if base_model else {}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): + """Auto + thread → pin service resolves to negative-id premium config → + resolver returns ``("premium", )``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + # Mock the pin service to return a concrete premium config id. + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + assert selected_llm_config_id == 0 + assert thread_id == 99 + return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") + + # Mock global config lookup to return a premium entry. + def _fake_get_global(cfg_id): + if cfg_id == -1: + return { + "id": -1, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + return None + + # Lazy imports inside the resolver — patch the *target* modules so the + # imported names resolve to our fakes. + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): + """Auto + thread → pin returns negative-id free config → resolver + returns ``("free", )``. 
Same path the pin service takes for + out-of-credit users (graceful degradation).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") + + def _fake_get_global(cfg_id): + if cfg_id == -3: + return { + "id": -3, + "model_name": "openrouter/free-model", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/free-model"}, + } + return None + + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/free-model" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): + """Auto + thread → pin returns positive-id BYOK config → resolver + returns ``("free", ...)`` (BYOK is always free per + ``AgentConfig.from_new_llm_config``).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=0, user_id=user_id) + byok_cfg = _make_byok_config( + id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + ) + session = _FakeSession([search_space, byok_cfg]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3-haiku" + + +@pytest.mark.asyncio +async def test_auto_mode_without_thread_id_falls_back_to_free(): + """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking + the pin service. 
Forward-compat fallback for any future direct-API + entrypoint that doesn't have a chat thread.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): + """If the pin service raises ``ValueError`` (thread missing / + mismatched search space), the resolver should log and return free + rather than killing the whole task.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin(*args, **kwargs): + raise ValueError("thread missing") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_negative_id_premium_global_returns_premium(monkeypatch): + """Explicit negative agent_llm_id → ``get_global_llm_config`` → + return its ``billing_tier``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_negative_id_free_global_returns_free(monkeypatch): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "openrouter/some-free", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/some-free"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/some-free" + + +@pytest.mark.asyncio +async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): + """When the global config has no ``litellm_params.base_model``, the + resolver falls back to ``model_name`` — matching chat's behavior.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-5, 
user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "fallback-model", + "billing_tier": "premium", + # No litellm_params. + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + _, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert tier == "premium" + assert base_model == "fallback-model" + + +@pytest.mark.asyncio +async def test_positive_id_byok_is_always_free(): + """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, + regardless of underlying provider tier.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=23, user_id=user_id) + byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_cfg]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3.5-sonnet" + + +@pytest.mark.asyncio +async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): + """If the BYOK config row is missing/deleted but the search space still + points at it, the resolver still returns free (no debit) with an empty + base_model — billable_call's premium path is skipped, no harm done.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "" + + +@pytest.mark.asyncio +async def test_search_space_not_found_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + session = _FakeSession([None]) + + with pytest.raises(ValueError, match="Search space"): + await _resolve_agent_billing_for_search_space(session, search_space_id=999) + + +@pytest.mark.asyncio +async def test_agent_llm_id_none_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + + with pytest.raises(ValueError, match="agent_llm_id"): + await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py new file mode 100644 index 000000000..86de5f23d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -0,0 +1,432 @@ +"""Unit tests for the ``billable_call`` async context manager. + +Covers the per-call premium-credit lifecycle for image generation and +vision LLM extraction: + +* Free configs bypass reserve/finalize but still write an audit row. +* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the + route layer). +* Successful premium calls reserve, yield the accumulator, then finalize + with the LiteLLM-reported actual cost — and write an audit row. +* Failed premium calls release the reservation so credit isn't leaked. 
+* All quota DB ops happen inside their OWN ``shielded_async_session``, + isolating them from the caller's transaction (issue A). +""" + +from __future__ import annotations + +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeQuotaResult: + def __init__( + self, + *, + allowed: bool, + used: int = 0, + limit: int = 5_000_000, + remaining: int = 5_000_000, + ) -> None: + self.allowed = allowed + self.used = used + self.limit = limit + self.remaining = remaining + + +class _FakeSession: + """Minimal AsyncSession stub — record commits for assertion.""" + + def __init__(self) -> None: + self.committed = False + self.added: list[Any] = [] + + def add(self, obj: Any) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.committed = True + + async def close(self) -> None: + pass + + +@contextlib.asynccontextmanager +async def _fake_shielded_session(): + s = _FakeSession() + _SESSIONS_USED.append(s) + yield s + + +_SESSIONS_USED: list[_FakeSession] = [] + + +def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None): + """Wire fake reserve/finalize/release/session helpers.""" + _SESSIONS_USED.clear() + reserve_calls: list[dict[str, Any]] = [] + finalize_calls: list[dict[str, Any]] = [] + release_calls: list[dict[str, Any]] = [] + + async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros): + reserve_calls.append( + { + "user_id": user_id, + "reserve_micros": reserve_micros, + "request_id": request_id, + } + ) + return reserve_result + + async def _fake_finalize( + *, db_session, user_id, request_id, actual_micros, reserved_micros + ): + finalize_calls.append( + { + "user_id": user_id, + "actual_micros": actual_micros, + "reserved_micros": reserved_micros, + } + ) + return finalize_result or _FakeQuotaResult(allowed=True) + + async def _fake_release(*, db_session, user_id, reserved_micros): + release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros}) + + record_calls: list[dict[str, Any]] = [] + + async def _fake_record(session, **kwargs): + record_calls.append(kwargs) + return object() + + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_reserve", + _fake_reserve, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_finalize", + _fake_finalize, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_release", + _fake_release, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.shielded_async_session", + _fake_shielded_session, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _fake_record, + raising=False, + ) + + return { + "reserve": reserve_calls, + "finalize": finalize_calls, + "release": release_calls, + "record": record_calls, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + 
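+    # Free tier: expect no reserve/finalize/release calls, but the audit
+    # row must still be written with the captured cost.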
+ async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + ) as acc: + # Simulate a captured cost — the accumulator is fed by the LiteLLM + # callback in real life, here we add it manually. + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + # Free still audits. + assert len(spies["record"]) == 1 + assert spies["record"][0]["usage_type"] == "image_generation" + assert spies["record"][0]["cost_micros"] == 37_000 + + +@pytest.mark.asyncio +async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=5_000_000, limit=5_000_000, remaining=0 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "image_generation" + assert err.used_micros == 5_000_000 + assert err.limit_micros == 5_000_000 + assert err.remaining_micros == 0 + # Reserve was attempted, but no finalize/release on a denied reserve + # — the reservation never actually held credit. + assert len(spies["reserve"]) == 1 + assert spies["finalize"] == [] + assert spies["release"] == [] + # Denied premium calls do NOT create an audit row (no work happened). + assert spies["record"] == [] + + +@pytest.mark.asyncio +async def test_premium_success_finalizes_with_actual_cost(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ) as acc: + # LiteLLM callback would normally fill this — simulate $0.04 image. + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert spies["reserve"][0]["reserve_micros"] == 50_000 + assert len(spies["finalize"]) == 1 + assert spies["finalize"][0]["actual_micros"] == 40_000 + assert spies["finalize"][0]["reserved_micros"] == 50_000 + assert spies["release"] == [] + # And audit row written with the actual debited cost. + assert spies["record"][0]["cost_micros"] == 40_000 + # Each quota op opened its OWN session — proves session isolation. + assert len(_SESSIONS_USED) >= 3 + # Sessions used should each have committed (or be the audit one which commits). + for _s in _SESSIONS_USED: + # finalize/reserve happen via TokenQuotaService.* which we stub — + # they don't actually call commit on our fake session, but the + # audit session does. We just assert >=1 session committed. 
+ pass + assert any(s.committed for s in _SESSIONS_USED) + + +@pytest.mark.asyncio +async def test_premium_failure_releases_reservation(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + class _ProviderError(Exception): + pass + + with pytest.raises(_ProviderError): + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ): + raise _ProviderError("OpenRouter 503") + + assert len(spies["reserve"]) == 1 + assert spies["finalize"] == [] + # Failure path: release the held reservation. + assert len(spies["release"]) == 1 + assert spies["release"][0]["reserved_micros"] == 50_000 + + +@pytest.mark.asyncio +async def test_premium_uses_estimator_when_no_micros_override(monkeypatch): + """When ``quota_reserve_micros_override`` is None we fall back to + ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``. + Vision LLM calls take this path (token-priced models). + """ + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + captured_estimator_calls: list[dict[str, Any]] = [] + + def _fake_estimate(*, base_model, quota_reserve_tokens): + captured_estimator_calls.append( + {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens} + ) + return 12_345 + + monkeypatch.setattr( + "app.services.billable_calls.estimate_call_reserve_micros", + _fake_estimate, + raising=False, + ) + + user_id = uuid4() + async with billable_call( + user_id=user_id, + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + usage_type="vision_extraction", + ): + pass + + assert captured_estimator_calls == [ + {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000} + ] + assert spies["reserve"][0]["reserve_micros"] == 12_345 + + +# --------------------------------------------------------------------------- +# Podcast / video-presentation usage_type coverage +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch): + """Free podcast configs must skip reserve/finalize but still emit a + ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we + have full audit coverage of free-tier agent runs.""" + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openrouter/some-free-model", + quota_reserve_micros_override=200_000, + usage_type="podcast_generation", + thread_id=99, + call_details={"podcast_id": 7, "title": "Test Podcast"}, + ) as acc: + # Two transcript LLM calls aggregated into one accumulator. 
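+        # (Simulated here as a single combined record; cost_micros=0
+        # because the model is free-tier.)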
+ acc.add( + model="openrouter/some-free-model", + prompt_tokens=1500, + completion_tokens=8000, + total_tokens=9500, + cost_micros=0, + call_kind="chat", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + + assert len(spies["record"]) == 1 + row = spies["record"][0] + assert row["usage_type"] == "podcast_generation" + assert row["thread_id"] == 99 + assert row["search_space_id"] == 42 + assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + + +@pytest.mark.asyncio +async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): + """Premium video-presentation runs that hit a denied reservation must + raise ``QuotaInsufficientError`` *before* the graph runs and must not + emit an audit row (no work happened).""" + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="gpt-5.4", + quota_reserve_micros_override=1_000_000, + usage_type="video_presentation_generation", + thread_id=99, + call_details={"video_presentation_id": 12, "title": "Test Video"}, + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "video_presentation_generation" + assert err.remaining_micros == 500_000 + assert spies["reserve"][0]["reserve_micros"] == 1_000_000 + assert spies["finalize"] == [] + assert spies["release"] == [] + assert spies["record"] == [] diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 085740032..b635b4fe8 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/gpt-4o" in model_names assert "openai/dall-e" not in model_names assert "openai/completion-only" not in model_names + + +# --------------------------------------------------------------------------- +# _generate_image_gen_configs / _generate_vision_llm_configs +# --------------------------------------------------------------------------- + + +def test_generate_image_gen_configs_filters_by_image_output(): + """Only models with ``output_modalities`` containing ``image`` are emitted. + Tool-calling and context filters are intentionally NOT applied — image + generation has nothing to do with tool calls and context windows. + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + # Pure image-gen model (small context, no tools — should still emit). + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Multi-modal: text+image output (should still emit). + { + "id": "google/gemini-2.5-flash-image", + "architecture": {"output_modalities": ["text", "image"]}, + "context_length": 1_000_000, + "pricing": {"prompt": "0.000001", "completion": "0.000004"}, + }, + # Pure text model — must NOT emit. 
+ { + "id": "openai/gpt-4o", + "architecture": {"output_modalities": ["text"]}, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + ] + + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openai/gpt-image-1" in model_names + assert "google/gemini-2.5-flash-image" in model_names + assert "openai/gpt-4o" not in model_names + + # Each config must carry ``billing_tier`` for routing in image_generation_routes. + for c in cfgs: + assert c["billing_tier"] in {"free", "premium"} + assert c["provider"] == "OPENROUTER" + assert c[_OPENROUTER_DYNAMIC_MARKER] is True + + +def test_generate_image_gen_configs_assigns_image_id_offset(): + """Image configs use a different id_offset (-20000) so their negative + IDs don't collide with chat configs (-10000) or vision configs (-30000). + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + } + ] + # Don't pass image_id_offset → use the module default (-20000). + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + assert all(c["id"] < -20_000 + 1 for c in cfgs) + assert all(c["id"] > -29_000_000 for c in cfgs) + + +def test_generate_vision_llm_configs_filters_by_image_input_text_output(): + """Vision LLMs must accept image input AND emit text — pure image-gen + (no text out) and text-only (no image in) models are excluded. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + # GPT-4o: vision LLM (image in, text out) — must emit. + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + # Pure image generator — image *output*, no text out. Must NOT emit. + { + "id": "openai/gpt-image-1", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["image"], + }, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Pure text model (no image in). Must NOT emit. + { + "id": "anthropic/claude-3-haiku", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "context_length": 200_000, + "pricing": {"prompt": "0.000001", "completion": "0.000005"}, + }, + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + names = {c["model_name"] for c in cfgs} + assert names == {"openai/gpt-4o"} + + cfg = cfgs[0] + assert cfg["billing_tier"] == "premium" + # Pricing carried inline so pricing_registration can register vision + # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache + # is cleared. + assert cfg["input_cost_per_token"] == pytest.approx(5e-6) + assert cfg["output_cost_per_token"] == pytest.approx(15e-6) + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + + +def test_generate_vision_llm_configs_drops_chat_only_filters(): + """A small-context vision model that doesn't advertise tool calling is + still a valid vision LLM for "describe this image" prompts. The chat + filters (``supports_tool_calling``, ``has_sufficient_context``) must + NOT be applied to vision emission. 
+ """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + { + "id": "tiny/vision-mini", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": [], # no tools + "context_length": 4_000, # well below MIN_CONTEXT_LENGTH + "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, + } + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + assert len(cfgs) == 1 + assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py new file mode 100644 index 000000000..e97250ff2 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -0,0 +1,447 @@ +"""Pricing registration unit tests. + +The pricing-registration module is what makes ``response_cost`` populate +correctly for OpenRouter dynamic models and operator-defined Azure +deployments — both of which LiteLLM doesn't natively know about. The tests +exercise: + +* The alias generators emit every shape that LiteLLM's cost-callback might + use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``, + ``provider/base_model``, ``provider/model_name``, plus the special + ``azure_openai`` → ``azure`` normalisation). +* ``register_pricing_from_global_configs`` calls ``litellm.register_model`` + with the right alias set and pricing values per provider. +* Configs without a resolvable pair of cost values are skipped — never + registered as zero, since that would override pricing LiteLLM might + already know natively. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Alias generators +# --------------------------------------------------------------------------- + + +def test_openrouter_alias_set_includes_prefixed_and_bare(): + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet") + assert aliases == [ + "openrouter/anthropic/claude-3-5-sonnet", + "anthropic/claude-3-5-sonnet", + ] + + +def test_openrouter_alias_set_dedupes(): + """If the model id is already prefixed with ``openrouter/``, the alias + set must not contain duplicates that would re-register the same key + twice. + """ + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("openrouter/foo") + # The bare and prefixed variants compute to the same string here, so we + # at minimum require uniqueness. + assert len(aliases) == len(set(aliases)) + + +def test_yaml_alias_set_for_azure_openai_normalises_to_azure(): + """``azure_openai`` (our YAML provider slug) must register under + ``azure/`` so the LiteLLM Router's deployment-resolution path + (which uses provider ``azure``) finds the pricing too. 
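+
+    Illustrative expansion (ordering is not pinned by this test):
+
+        _alias_set_for_yaml(provider="AZURE_OPENAI",
+                            model_name="gpt-5.4", base_model="gpt-5.4")
+        # -> {"gpt-5.4", "azure_openai/gpt-5.4", "azure/gpt-5.4"}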
+ """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="AZURE_OPENAI", + model_name="gpt-5.4", + base_model="gpt-5.4", + ) + assert "gpt-5.4" in aliases + assert "azure_openai/gpt-5.4" in aliases + assert "azure/gpt-5.4" in aliases + + +def test_yaml_alias_set_distinguishes_model_name_and_base_model(): + """When ``model_name`` differs from ``base_model`` (operator labelled a + deployment), both must appear in the alias set since either may surface + in callbacks depending on the call path. + """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="OPENAI", + model_name="my-deployment-label", + base_model="gpt-4o", + ) + assert "gpt-4o" in aliases + assert "openai/gpt-4o" in aliases + assert "my-deployment-label" in aliases + assert "openai/my-deployment-label" in aliases + + +def test_yaml_alias_set_omits_provider_prefix_when_provider_blank(): + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="", + model_name="foo", + base_model="bar", + ) + assert "bar" in aliases + assert "foo" in aliases + assert all("/" not in a for a in aliases) + + +# --------------------------------------------------------------------------- +# register_pricing_from_global_configs +# --------------------------------------------------------------------------- + + +class _RegistrationSpy: + """Captures the dicts passed to ``litellm.register_model``. + + Many calls may go through; we just record them all and let tests assert + against the union. + """ + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def __call__(self, payload: dict[str, Any]) -> None: + self.calls.append(payload) + + @property + def all_keys(self) -> set[str]: + keys: set[str] = set() + for payload in self.calls: + keys.update(payload.keys()) + return keys + + +def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy: + spy = _RegistrationSpy() + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + spy, + raising=False, + ) + return spy + + +def _patch_openrouter_pricing( + monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]] +) -> None: + """Pretend the OpenRouter integration is initialised with ``mapping``.""" + + class _Stub: + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + return mapping + + class _StubService: + @classmethod + def is_initialized(cls) -> bool: + return True + + @classmethod + def get_instance(cls) -> _Stub: + return _Stub() + + monkeypatch.setattr( + "app.services.openrouter_integration_service.OpenRouterIntegrationService", + _StubService, + raising=False, + ) + + +def test_openrouter_models_register_under_aliases(monkeypatch): + """An OpenRouter config whose ``model_name`` is in the cached raw + pricing map is registered under both ``openrouter/X`` and bare ``X``. 
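+
+    Expected payload shape (a sketch derived from the assertions below;
+    whether both aliases land in a single ``register_model`` payload or
+    one call each is not pinned here):
+
+        litellm.register_model({
+            "openrouter/anthropic/claude-3-5-sonnet": {
+                "input_cost_per_token": 3e-6,
+                "output_cost_per_token": 15e-6,
+                "litellm_provider": "openrouter",
+            },
+            "anthropic/claude-3-5-sonnet": {...},  # same entry, bare alias
+        })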
+ """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys + assert "anthropic/claude-3-5-sonnet" in spy.all_keys + # Costs are float-converted from the raw OpenRouter strings. + payload = spy.calls[0] + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "input_cost_per_token" + ] == pytest.approx(3e-6) + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "output_cost_per_token" + ] == pytest.approx(15e-6) + assert ( + payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"] + == "openrouter" + ) + + +def test_yaml_override_registers_under_alias_set(monkeypatch): + """Operator-declared ``input_cost_per_token`` / + ``output_cost_per_token`` on a YAML config registers under every + alias the YAML alias generator produces — including the ``azure/`` + normalisation for ``azure_openai`` providers. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "litellm_params": { + "base_model": "gpt-5.4", + "input_cost_per_token": 2e-6, + "output_cost_per_token": 8e-6, + }, + } + ], + ) + + register_pricing_from_global_configs() + + keys = spy.all_keys + assert "gpt-5.4" in keys + assert "azure_openai/gpt-5.4" in keys + assert "azure/gpt-5.4" in keys + + payload = spy.calls[0] + entry = payload["gpt-5.4"] + assert entry["input_cost_per_token"] == pytest.approx(2e-6) + assert entry["output_cost_per_token"] == pytest.approx(8e-6) + assert entry["litellm_provider"] == "azure" + + +def test_no_override_means_no_registration(monkeypatch): + """A YAML config that *omits* both pricing fields must NOT be registered + — registering as zero would override LiteLLM's native pricing for the + ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's + bill drop to $0. Fail-safe is "skip and warn", not "register zero". + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENAI", + "model_name": "gpt-4o", + "litellm_params": {"base_model": "gpt-4o"}, + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_openrouter_skipped_when_pricing_missing(monkeypatch): + """If the OpenRouter raw-pricing cache doesn't carry an entry for a + configured model (network blip during refresh, model added later, etc.), + we skip it rather than registering zero pricing. 
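+
+    Assumed guard shape (a sketch only; the log wording is illustrative):
+
+        pricing = raw_pricing.get(cfg["model_name"])
+        if not pricing:
+            logger.warning("no OpenRouter pricing for %s; skipping", cfg["model_name"])
+            continue  # never register a zero-cost entry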
+ """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}} + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_register_continues_after_individual_failure(monkeypatch, caplog): + """A single bad ``register_model`` call (e.g. raising LiteLLM error) + must not abort registration of the remaining configs. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"} + successful_calls: list[dict[str, Any]] = [] + + def _maybe_fail(payload: dict[str, Any]) -> None: + if any(k in failing_keys for k in payload): + raise RuntimeError("boom") + successful_calls.append(payload) + + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + _maybe_fail, + raising=False, + ) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + }, + { + "id": 2, + "provider": "OPENAI", + "model_name": "custom-deployment", + "litellm_params": { + "base_model": "custom-deployment", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 2e-6, + }, + }, + ], + ) + + register_pricing_from_global_configs() + + # The good config still registered. + assert any("custom-deployment" in payload for payload in successful_calls) + + +def test_vision_configs_registered_with_chat_shape(monkeypatch): + """``register_pricing_from_global_configs`` walks + ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision + calls (during indexing) bill correctly. Vision configs use the same + chat-shape token prices, but image-gen pricing is intentionally NOT + registered here (handled via ``response_cost`` in LiteLLM). + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, + ) + + # No chat configs — only vision. Proves the vision walk is a separate + # iteration, not piggy-backed on the chat list. 
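+    # Assumed top-level walk (sketch; the helper name is illustrative):
+    #     for cfg in config.GLOBAL_LLM_CONFIGS:
+    #         _register_config(cfg)
+    #     for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
+    #         _register_config(cfg, mode="chat")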
+ monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "billing_tier": "premium", + "input_cost_per_token": 5e-6, + "output_cost_per_token": 15e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/openai/gpt-4o" in spy.all_keys + payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] + assert payload_value["mode"] == "chat" + assert payload_value["litellm_provider"] == "openrouter" + assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) + assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) + + +def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): + """If the OpenRouter pricing cache misses a vision model (different + catalogue surface), the vision walk falls back to inline + ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash", + "billing_tier": "premium", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 4e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py new file mode 100644 index 000000000..9e35b6f9c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -0,0 +1,157 @@ +"""Unit tests for ``QuotaCheckedVisionLLM``. + +Validates that: + +* Calling ``ainvoke`` routes through ``billable_call`` (premium credit + enforcement) and forwards the inner LLM's response on success. +* The wrapper proxies non-overridden attributes to the inner LLM + (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output`` + still work without quota gating (they're not used in indexing today). +* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper + bubbles it up — the ETL pipeline catches that and falls back to OCR. 
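+
+A minimal sketch of the assumed wrapper shape (illustrative only; details
+these tests don't pin may differ in the real class):
+
+    class QuotaCheckedVisionLLM:
+        def __init__(self, inner, *, user_id, search_space_id,
+                     billing_tier, base_model, quota_reserve_tokens):
+            self._inner = inner
+            self._bc_kwargs = dict(
+                user_id=user_id, search_space_id=search_space_id,
+                billing_tier=billing_tier, base_model=base_model,
+                quota_reserve_tokens=quota_reserve_tokens,
+            )
+
+        async def ainvoke(self, input, *args, **kwargs):
+            async with billable_call(
+                usage_type="vision_extraction", **self._bc_kwargs
+            ):
+                return await self._inner.ainvoke(input, *args, **kwargs)
+
+        def __getattr__(self, name):
+            return getattr(self._inner, name)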
+""" + +from __future__ import annotations + +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +class _FakeInnerLLM: + """Stand-in for ``langchain_litellm.ChatLiteLLM``.""" + + def __init__(self, response: Any = "OCR'd content") -> None: + self._response = response + self.ainvoke_calls: list[Any] = [] + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + self.ainvoke_calls.append(input) + return self._response + + def some_other_method(self, x: int) -> int: + return x * 2 + + +@contextlib.asynccontextmanager +async def _passthrough_billable_call(**_kwargs): + """Stand-in for billable_call that always allows the call to run.""" + + class _Acc: + total_cost_micros = 0 + total_prompt_tokens = 0 + total_completion_tokens = 0 + grand_total = 0 + calls: list[Any] = [] + + def per_message_summary(self) -> dict[str, dict[str, int]]: + return {} + + yield _Acc() + + +@pytest.mark.asyncio +async def test_ainvoke_routes_through_billable_call(monkeypatch): + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + captured_kwargs: list[dict[str, Any]] = [] + + @contextlib.asynccontextmanager + async def _spy_billable_call(**kwargs): + captured_kwargs.append(kwargs) + async with _passthrough_billable_call() as acc: + yield acc + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _spy_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM(response="A red apple on a white table") + user_id = uuid4() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=user_id, + search_space_id=99, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + result = await wrapper.ainvoke([{"text": "what is this?"}]) + assert result == "A red apple on a white table" + assert len(inner.ainvoke_calls) == 1 + assert len(captured_kwargs) == 1 + bc_kwargs = captured_kwargs[0] + assert bc_kwargs["user_id"] == user_id + assert bc_kwargs["search_space_id"] == 99 + assert bc_kwargs["billing_tier"] == "premium" + assert bc_kwargs["base_model"] == "openai/gpt-4o" + assert bc_kwargs["quota_reserve_tokens"] == 4000 + assert bc_kwargs["usage_type"] == "vision_extraction" + + +@pytest.mark.asyncio +async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): + from app.services.billable_calls import QuotaInsufficientError + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + @contextlib.asynccontextmanager + async def _denying_billable_call(**_kwargs): + raise QuotaInsufficientError( + usage_type="vision_extraction", + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield # unreachable but required for asynccontextmanager type + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _denying_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + with pytest.raises(QuotaInsufficientError): + await wrapper.ainvoke([{"text": "x"}]) + + # Inner LLM never ran on a denied reservation. + assert inner.ainvoke_calls == [] + + +@pytest.mark.asyncio +async def test_proxies_non_overridden_attributes_to_inner(): + """``__getattr__`` forwards anything not on the proxy itself, so any + method we didn't explicitly override (``invoke``, ``astream``, + ``with_structured_output``, etc.) 
still works — just without quota + gating, which is fine because the indexer only ever calls ainvoke. + """ + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + # ``some_other_method`` is on the inner only. + assert wrapper.some_other_method(7) == 14 diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py new file mode 100644 index 000000000..63681828d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py @@ -0,0 +1,515 @@ +"""Cost-based premium quota unit tests. + +Covers the USD-micro behaviour added in migration 140: + +* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all + calls in a turn — used as the debit amount when ``agent_config.is_premium`` + is true, regardless of which underlying model produced each call. This + preserves the prior "premium turn → all calls in turn count" rule from the + token-based system. +* ``estimate_call_reserve_micros`` scales linearly with model pricing, + clamps to a sane floor when pricing is unknown, and respects the + ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry + can't lock the whole balance on one call. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# TurnTokenAccumulator — premium-turn debit semantics +# --------------------------------------------------------------------------- + + +def test_total_cost_micros_sums_premium_and_free_calls(): + """A premium turn that also called a free sub-agent debits the union. + + The plan deliberately preserved the existing "premium turn → all calls + count" behaviour because per-call premium filtering relied on + ``LLMRouterService._premium_model_strings`` which only covers router-pool + deployments. ``total_cost_micros`` therefore must include free-model + calls (whose ``cost_micros`` is typically ``0``) as well as the premium + call's actual provider cost. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + # Premium model (e.g. claude-opus): non-zero cost. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=1200, + completion_tokens=400, + total_tokens=1600, + cost_micros=12_345, + ) + # Free sub-agent (e.g. title-gen on a free model): zero cost. + acc.add( + model="gpt-4o-mini", + prompt_tokens=120, + completion_tokens=20, + total_tokens=140, + cost_micros=0, + ) + # A second premium-priced call within the same turn. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=800, + completion_tokens=200, + total_tokens=1000, + cost_micros=7_500, + ) + + assert acc.total_cost_micros == 12_345 + 0 + 7_500 + # Token totals stay correct so the FE display path still works. 
+ assert acc.grand_total == 1600 + 140 + 1000 + + +def test_total_cost_micros_zero_when_no_calls(): + """An empty accumulator must report zero cost (no division-by-zero, no None).""" + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + assert acc.total_cost_micros == 0 + assert acc.grand_total == 0 + + +def test_per_message_summary_groups_cost_by_model(): + """``per_message_summary`` must accumulate ``cost_micros`` per model so the + SSE ``model_breakdown`` payload reports actual USD spend per provider. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cost_micros=4_000, + ) + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=200, + completion_tokens=100, + total_tokens=300, + cost_micros=8_000, + ) + acc.add( + model="gpt-4o-mini", + prompt_tokens=50, + completion_tokens=10, + total_tokens=60, + cost_micros=200, + ) + + summary = acc.per_message_summary() + assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000 + assert summary["claude-3-5-sonnet"]["total_tokens"] == 450 + assert summary["gpt-4o-mini"]["cost_micros"] == 200 + + +def test_serialized_calls_includes_cost_micros(): + """``serialized_calls`` is what flows into the SSE ``call_details`` + payload; cost_micros must be present on each entry so the FE message-info + dropdown can render per-call USD. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="m", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=42, + ) + serialized = acc.serialized_calls() + assert serialized == [ + { + "model": "m", + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + "cost_micros": 42, + "call_kind": "chat", + } + ] + + +# --------------------------------------------------------------------------- +# estimate_call_reserve_micros — sizing and clamping +# --------------------------------------------------------------------------- + + +def test_reserve_returns_floor_when_model_unknown(monkeypatch): + """If LiteLLM doesn't know the model, ``get_model_info`` raises and the + helper falls back to the 100-micro floor — small enough that a user with + $0.0001 left can still send a tiny request, but non-zero so we still gate + against an empty balance. + """ + import litellm + + from app.services import token_quota_service + + def _raise(_name): + raise KeyError("unknown") + + monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="nonexistent-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + assert micros == 100 + + +def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch): + """LiteLLM may *return* a model with both cost-per-token fields at 0 + (pricing not yet registered). The helper must not multiply 0 x tokens + and end up reserving 0 — it must clamp to the floor. 
+ """ + import litellm + + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0}, + raising=False, + ) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="some-pending-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + + +def test_reserve_scales_with_model_cost(monkeypatch): + """Claude-Opus-priced model with 4000 reserve_tokens reserves + ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to + some small artificial cap — that was the bug the plan called out. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 15e-6, + "output_cost_per_token": 75e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="claude-3-opus", + quota_reserve_tokens=4000, + ) + # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros. + assert micros == 360_000 + + +def test_reserve_clamps_to_max_ceiling(monkeypatch): + """A misconfigured "$1000 / M" model with 4000 reserve_tokens would + nominally compute to $4 = 4_000_000 micros. The ceiling + ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry + can't lock the user's whole balance on one call. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-3, + "output_cost_per_token": 0, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="oops-misconfigured", + quota_reserve_tokens=4000, + ) + assert micros == 1_000_000 + + +def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch): + """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or + zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL`` + so anonymous-style configs still reserve the operator-tunable default. 
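+
+    Assumed sizing (a sketch consistent with every reserve test in this
+    section; the real helper may structure it differently):
+
+        tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
+        info = litellm.get_model_info(base_model)   # raises -> fall to floor
+        usd = tokens * (info["input_cost_per_token"] + info["output_cost_per_token"])
+        micros = min(max(int(usd * 1_000_000), _QUOTA_MIN_RESERVE_MICROS),
+                     config.QUOTA_MAX_RESERVE_MICROS)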
+    """
+    import litellm
+
+    from app.config import config
+    from app.services import token_quota_service
+
+    monkeypatch.setattr(
+        litellm,
+        "get_model_info",
+        lambda _name: {
+            "input_cost_per_token": 1e-6,
+            "output_cost_per_token": 1e-6,
+        },
+        raising=False,
+    )
+    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
+    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
+
+    # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
+    assert (
+        token_quota_service.estimate_call_reserve_micros(
+            base_model="cheap", quota_reserve_tokens=None
+        )
+        == 4000
+    )
+    assert (
+        token_quota_service.estimate_call_reserve_micros(
+            base_model="cheap", quota_reserve_tokens=0
+        )
+        == 4000
+    )
+
+
+# ---------------------------------------------------------------------------
+# TokenTrackingCallback — image vs chat usage shape
+# ---------------------------------------------------------------------------
+
+
+class _FakeImageUsage:
+    """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
+
+    def __init__(
+        self,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+        total_tokens: int | None = None,
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        if total_tokens is not None:
+            self.total_tokens = total_tokens
+
+
+class _FakeImageResponse:
+    """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
+    ``type(...).__name__`` probe routes to the image branch.
+    """
+
+    def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
+        self.usage = usage
+        if response_cost is not None:
+            self._hidden_params = {"response_cost": response_cost}
+
+
+# Re-tag the helper class as ``ImageResponse`` so the callback's
+# ``type(...).__name__`` probe routes it down the image branch, while the
+# declared name stays distinct from litellm's real ``ImageResponse`` so
+# the fake is unambiguous at its definition site.
+_FakeImageResponse.__name__ = "ImageResponse"
+
+
+class _FakeChatUsage:
+    def __init__(self, prompt: int, completion: int):
+        self.prompt_tokens = prompt
+        self.completion_tokens = completion
+        self.total_tokens = prompt + completion
+
+
+class _FakeChatResponse:
+    def __init__(self, usage: _FakeChatUsage):
+        self.usage = usage
+
+
+@pytest.mark.asyncio
+async def test_callback_reads_image_usage_input_output_tokens():
+    """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
+    for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
+    prompt_tokens/completion_tokens which is the chat shape.
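+
+    Assumed extraction branch (a sketch only; this is the probe the
+    re-tagged fake above exercises):
+
+        if type(response_obj).__name__ == "ImageResponse":
+            prompt = getattr(usage, "input_tokens", 0) or 0
+            completion = getattr(usage, "output_tokens", 0) or 0
+        else:
+            prompt = usage.prompt_tokens
+            completion = usage.completion_tokens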
+ """ + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50), + response_cost=0.04, # $0.04 per image + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04}, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 42 + assert call.completion_tokens == 8 + assert call.total_tokens == 50 + # 0.04 USD = 40_000 micros + assert call.cost_micros == 40_000 + assert call.call_kind == "image_generation" + + +@pytest.mark.asyncio +async def test_callback_chat_path_unchanged(): + """Chat responses must still read prompt_tokens/completion_tokens.""" + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30)) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={ + "model": "openrouter/anthropic/claude-3-5-sonnet", + "response_cost": 0.0036, + }, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 120 + assert call.completion_tokens == 30 + assert call.total_tokens == 150 + assert call.cost_micros == 3_600 + assert call.call_kind == "chat" + + +@pytest.mark.asyncio +async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch): + """When OpenRouter omits ``usage.cost`` LiteLLM's + ``default_image_cost_calculator`` raises. The defensive image branch in + ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is + chat-shaped and would raise too) — it returns 0 with a WARNING log. + """ + import litellm + + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + # Force completion_cost to raise the same way OpenRouter image-gen fails. + def _boom(*_args, **_kwargs): + raise ValueError("model_cost: missing entry for openrouter image model") + + monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False) + + # And make sure cost_per_token is NEVER called for the image path — + # if it were, our ``is_image=True`` branch is broken. + cost_per_token_calls: list = [] + + def _record_cost_per_token(**kwargs): + cost_per_token_calls.append(kwargs) + return (0.0, 0.0) + + monkeypatch.setattr( + litellm, "cost_per_token", _record_cost_per_token, raising=False + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=7, output_tokens=0) + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openrouter/google/gemini-2.5-flash-image"}, + response_obj=response, + start_time=None, + end_time=None, + ) + + assert len(acc.calls) == 1 + assert acc.calls[0].cost_micros == 0 + assert acc.calls[0].call_kind == "image_generation" + # The image branch must short-circuit before cost_per_token. 
+ assert cost_per_token_calls == [] + + +# --------------------------------------------------------------------------- +# scoped_turn — ContextVar reset semantics (issue B) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scoped_turn_restores_outer_accumulator(): + """``scoped_turn`` must restore the previous ContextVar value on exit + so a per-call wrapper inside an outer chat turn doesn't leak its + accumulator outward (which would cause double-debit at chat-turn exit). + """ + from app.services.token_tracking_service import ( + get_current_accumulator, + scoped_turn, + start_turn, + ) + + outer = start_turn() + assert get_current_accumulator() is outer + + async with scoped_turn() as inner: + assert get_current_accumulator() is inner + assert inner is not outer + inner.add( + model="x", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=5, + ) + + # After exit the outer accumulator is restored unchanged. + assert get_current_accumulator() is outer + assert outer.total_cost_micros == 0 + assert len(outer.calls) == 0 + # The inner accumulator captured the call but didn't bleed into outer. + assert inner.total_cost_micros == 5 + + +@pytest.mark.asyncio +async def test_scoped_turn_resets_to_none_when_no_outer(): + """Running ``scoped_turn`` outside any chat turn (e.g. a background + indexing job) must leave the ContextVar at ``None`` on exit so the + next *unrelated* request starts clean. + """ + from app.services.token_tracking_service import ( + _turn_accumulator, + get_current_accumulator, + scoped_turn, + ) + + # ContextVar default is None for a fresh test isolated context. We + # simulate "no outer" explicitly to be robust against test order. + token = _turn_accumulator.set(None) + try: + assert get_current_accumulator() is None + async with scoped_turn() as acc: + assert get_current_accumulator() is acc + assert get_current_accumulator() is None + finally: + _turn_accumulator.reset(token) diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py new file mode 100644 index 000000000..38d6ba2ca --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -0,0 +1,325 @@ +"""Unit tests for podcast Celery task billing integration. + +Validates ``_generate_content_podcast`` correctly wraps +``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the +search-space owner's billing decision, and degrades cleanly when the +resolver fails or premium credit is exhausted. + +Coverage: + +* Happy-path free config: resolver → ``billable_call`` enters with + ``usage_type='podcast_generation'`` and the configured reserve override, + graph runs, podcast row flips to ``READY``. +* Happy-path premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → + graph is *not* invoked, podcast row flips to ``FAILED``, return dict + carries ``reason='premium_quota_exhausted'``. +* Resolver failure: ``ValueError`` from the resolver → podcast row flips + to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. 
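+
+Assumed envelope wiring (a sketch reconstructed from the assertions below;
+the production helper may differ in details these tests don't pin):
+
+    owner_id, tier, base_model = await _resolve_agent_billing_for_search_space(
+        session, search_space_id, thread_id=podcast.thread_id
+    )
+    async with billable_call(
+        user_id=owner_id,
+        search_space_id=search_space_id,
+        billing_tier=tier,
+        base_model=base_model,
+        quota_reserve_micros_override=config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
+        usage_type="podcast_generation",
+        thread_id=podcast.thread_id,
+        call_details={"podcast_id": podcast.id, "title": podcast.title},
+    ):
+        result = await podcaster_graph.ainvoke(state, graph_config)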
+""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, podcast): + self._podcast = podcast + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._podcast) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: + """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily + inside helpers keeps this fixture cheap.""" + return SimpleNamespace( + id=podcast_id, + title="Test Podcast", + thread_id=thread_id, + status=None, + podcast_transcript=None, + file_location=None, + ) + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + """Stand-in for ``billable_call`` that records its kwargs and yields a + no-op accumulator-shaped object.""" + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover — for grammar only + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + """Happy path: free billing tier still wraps the graph call so the + audit row is recorded. 
Verifies kwargs threading.""" + from app.config import config as app_config + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=7, thread_id=99) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 555 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return { + "podcast_transcript": [ + SimpleNamespace(speaker_id=0, dialog="Hi"), + SimpleNamespace(speaker_id=1, dialog="Hello"), + ], + "final_podcast_file_path": "/tmp/podcast.wav", + } + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hello world", + search_space_id=555, + user_prompt="make it short", + ) + + assert result["status"] == "ready" + assert result["podcast_id"] == 7 + assert podcast.status == PodcastStatus.READY + assert podcast.file_location == "/tmp/podcast.wav" + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 555 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "podcast_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS + ) + assert call["thread_id"] == 99 + assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + """Premium resolution flows through to ``billable_call`` so the + reserve/finalize path triggers.""" + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast() + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): + """When ``billable_call`` denies the reservation, the graph never + runs and the podcast row flips to FAILED with the documented reason + code.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=8) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, 
+ "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=8, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 8, + "reason": "premium_quota_exhausted", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] # Graph never ran on denied reservation. + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_podcast_failed(monkeypatch): + """If the resolver raises (e.g. search-space deleted), the task fails + cleanly without invoking the graph.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=9) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 555 not found") + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=9, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 9, + "reason": "billing_resolution_failed", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py new file mode 100644 index 000000000..671f57ae4 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -0,0 +1,330 @@ +"""Unit tests for video-presentation Celery task billing integration. + +Mirrors ``test_podcast_billing.py`` for the video-presentation task. +Validates the same wrap-graph-in-billable_call pattern and ensures the +larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is +threaded through. + +Coverage: + +* Free config: graph runs, ``billable_call`` invoked with the video + reserve override. +* Premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: graph not invoked, row → FAILED, reason code surfaced. +* Resolver failure: row → FAILED with ``billing_resolution_failed``. 
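+
+Assumed failure handling (a sketch matching the reason codes pinned below):
+
+    try:
+        async with billable_call(...):  # video reserve override threaded in
+            result = await video_presentation_graph.ainvoke(state, graph_config)
+    except QuotaInsufficientError:
+        video.status = VideoPresentationStatus.FAILED
+        return {
+            "status": "failed",
+            "video_presentation_id": video.id,
+            "reason": "premium_quota_exhausted",
+        }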
+""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, video): + self._video = video + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._video) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace: + return SimpleNamespace( + id=video_id, + title="Test Presentation", + thread_id=thread_id, + status=None, + slides=None, + scene_codes=None, + ) + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + from app.config import config as app_config + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=11, thread_id=99) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 777 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result["status"] == "ready" + assert result["video_presentation_id"] == 11 + assert video.status == VideoPresentationStatus.READY + + assert 
len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 777 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "video_presentation_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS + ) + assert call["thread_id"] == 99 + assert call["call_details"] == { + "video_presentation_id": 11, + "title": "Test Presentation", + } + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video() + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=12) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, "billable_call", _denying_billable_call + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=12, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 12, + "reason": "premium_quota_exhausted", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=13) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise 
ValueError("Search space 777 not found") + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _failing_resolver, + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=13, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 13, + "reason": "billing_resolution_failed", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index 8d9ed5cb1..3ddd5195f 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -127,7 +127,7 @@ const FAQ_ITEMS = [ { question: "What happens after I use my free tokens?", answer: - "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.", + "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.", }, { question: "Is Claude AI available without login?", @@ -329,7 +329,7 @@ export default async function FreeHubPage() {

Want More Features?

- Create a free SurfSense account to unlock 3 million tokens, document uploads with + Create a free SurfSense account to unlock $5 of premium credit, document uploads with citations, team collaboration, and integrations with Slack, Google Drive, Notion, and 30+ more tools.

diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx index 6ad9435bf..6f332be70 100644 --- a/surfsense_web/app/(home)/pricing/page.tsx +++ b/surfsense_web/app/(home)/pricing/page.tsx @@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav"; export const metadata: Metadata = { title: "Pricing | SurfSense - Free AI Search Plans", description: - "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.", + "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.", alternates: { canonical: "https://surfsense.com/pricing", }, diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx index 3017160e1..0c5662712 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx @@ -8,7 +8,7 @@ import { cn } from "@/lib/utils"; const TABS = [ { id: "pages", label: "Pages" }, - { id: "tokens", label: "Premium Tokens" }, + { id: "tokens", label: "Premium Credit" }, ] as const; type TabId = (typeof TABS)[number]["id"]; diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx index 2b7422f80..cf73b5eba 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx @@ -28,6 +28,12 @@ type UnifiedPurchase = { kind: PurchaseKind; created_at: string; status: PagePurchaseStatus; + /** + * Granted units. Interpretation depends on ``kind``: + * - ``"pages"`` — integer number of indexed pages. + * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00). + * The ``Granted`` column formats accordingly. + */ granted: number; amount_total: number | null; currency: string | null; @@ -58,7 +64,7 @@ const KIND_META: Record< iconClass: "text-sky-500", }, tokens: { - label: "Premium Tokens", + label: "Premium Credit", icon: Coins, iconClass: "text-amber-500", }, @@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase { kind: "tokens", created_at: p.created_at, status: p.status, - granted: p.tokens_granted, + granted: p.credit_micros_granted, amount_total: p.amount_total, currency: p.currency, }; } +function formatGranted(p: UnifiedPurchase): string { + if (p.kind === "tokens") { + const dollars = p.granted / 1_000_000; + // Premium credit packs are always whole dollars at the moment, but + // future fractional grants (refunds, partial top-ups) shouldn't + // silently round to "$0". + if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`; + if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; + return "$0 of credit"; + } + return p.granted.toLocaleString(); +} + export function PurchaseHistoryContent() { const results = useQueries({ queries: [ @@ -143,7 +162,7 @@ export function PurchaseHistoryContent() {

 				No purchases yet
-				Your page and premium token purchases will appear here after checkout.
+				Your page and premium credit purchases will appear here after checkout.
 			);
@@ -177,7 +196,7 @@ export function PurchaseHistoryContent() {
-							{p.granted.toLocaleString()}
+							{formatGranted(p)}
 							{formatAmount(p.amount_total, p.currency)}
diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts
index a59811324..4b6717440 100644
--- a/surfsense_web/atoms/user/user-query.atoms.ts
+++ b/surfsense_web/atoms/user/user-query.atoms.ts
@@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe();
 export const currentUserAtom = atomWithQuery(() => {
 	return {
 		queryKey: USER_QUERY_KEY,
-		// Live-changing numeric fields (pages_*, premium_tokens_*) are now
-		// pushed via Zero (queries.user.me()), so /users/me only needs to
-		// fire once per session for the static profile fields.
+		// Live-changing numeric fields (pages_*, premium_credit_micros_*)
+		// are now pushed via Zero (queries.user.me()), so /users/me only
+		// needs to fire once per session for the static profile fields.
 		staleTime: Infinity,
 		enabled: !!getBearerToken(),
 		queryFn: userQueryFn,
diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx
index 711bb2fe2..ffb0e4dc8 100644
--- a/surfsense_web/components/assistant-ui/assistant-message.tsx
+++ b/surfsense_web/components/assistant-ui/assistant-message.tsx
@@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string {
 	});
 }
 
+/**
+ * Format provider USD cost (in micro-USD) for inline display next to a
+ * token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent
+ * costs so a real-but-tiny figure doesn't render as ``$0.000``.
+ */
+function formatTurnCost(micros: number): string {
+	const dollars = micros / 1_000_000;
+	if (dollars >= 1) return `$${dollars.toFixed(2)}`;
+	if (dollars >= 0.001) return `$${dollars.toFixed(3)}`;
+	if (dollars > 0) return "<$0.001";
+	return "$0";
+}
+
 const MessageInfoDropdown: FC = () => {
 	const messageId = useAuiState(({ message }) => message?.id);
 	const createdAt = useAuiState(({ message }) => message?.createdAt);
@@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => {
 				{models.length > 0 ? (
 					models.map(([model, counts]) => {
 						const { name, icon } = resolveModel(model);
+						const costMicros = counts.cost_micros;
 						return (
 								{counts.total_tokens.toLocaleString()} tokens
+								{costMicros && costMicros > 0
+									? ` · ${formatTurnCost(costMicros)}`
+									: ""}
 						);
@@ -474,6 +491,9 @@ const MessageInfoDropdown: FC = () => {
 							>
 								{usage.total_tokens.toLocaleString()} tokens
+								{usage.cost_micros && usage.cost_micros > 0
+									? ` · ${formatTurnCost(usage.cost_micros)}`
+									: ""}
 						)}
diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx
index b3f71ab21..dd80bcac3 100644
--- a/surfsense_web/components/assistant-ui/token-usage-context.tsx
+++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx
@@ -13,13 +13,30 @@ export interface TokenUsageData {
 	prompt_tokens: number;
 	completion_tokens: number;
 	total_tokens: number;
+	/**
+	 * Total provider USD cost for this assistant turn, in micro-USD
+	 * (1_000_000 = $1.00). Populated from LiteLLM's response_cost on
+	 * the backend. Optional because pre-cost-credits messages persisted
+	 * before the migration won't have it.
+	 */
+	cost_micros?: number;
 	usage?: Record<
 		string,
-		{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
+		{
+			prompt_tokens: number;
+			completion_tokens: number;
+			total_tokens: number;
+			cost_micros?: number;
+		}
 	>;
 	model_breakdown?: Record<
 		string,
-		{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
+		{
+			prompt_tokens: number;
+			completion_tokens: number;
+			total_tokens: number;
+			cost_micros?: number;
+		}
 	>;
 }
diff --git a/surfsense_web/components/free-chat/quota-warning-banner.tsx b/surfsense_web/components/free-chat/quota-warning-banner.tsx
index 3bfedf1b3..e013a64a8 100644
--- a/surfsense_web/components/free-chat/quota-warning-banner.tsx
+++ b/surfsense_web/components/free-chat/quota-warning-banner.tsx
@@ -40,7 +40,7 @@ export function QuotaWarningBanner({

 					You've used all {limit.toLocaleString()} free tokens. Create a free account to
-					get 3 million tokens and access to all models.
+					get $5 of premium credit and access to all models.

 					Create an account{" "}
-					for 5M free tokens.
+					for $5 of premium credit.

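A plausible derivation of the totals rendered by the BuyTokensContent fragments below, assuming one pack costs $1.00 and grants 1_000_000 micro-USD, per the checkout contract in stripe.types.ts further down (the constant and state names here are illustrative, not from the patch):

```ts
// Sketch only: the real component state is not shown in these hunks.
const CREDIT_MICROS_PER_UNIT = 1_000_000; // mirrors STRIPE_CREDIT_MICROS_PER_UNIT
const USD_PER_UNIT = 1;

function checkoutTotals(quantity: number) {
	// createTokenCheckoutSessionRequest clamps quantity to 1..100.
	return {
		totalCreditMicros: quantity * CREDIT_MICROS_PER_UNIT,
		totalPrice: quantity * USD_PER_UNIT,
	};
}

const { totalCreditMicros, totalPrice } = checkoutTotals(5);
// `Buy $${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for $${totalPrice}`
// renders as "Buy $5 of credit for $5".
```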
-							{(totalTokens / 1_000_000).toFixed(0)}M tokens
+							${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
 						))}
-						{(totalTokens / 1_000_000).toFixed(0)}M premium tokens
+						${(totalCreditMicros / 1_000_000).toFixed(0)} of credit
 						${totalPrice}
@@ -149,7 +164,7 @@ export function BuyTokensContent() {
 			) : (
 				<>
-					Buy {(totalTokens / 1_000_000).toFixed(0)}M Tokens for ${totalPrice}
+					Buy ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for ${totalPrice}
 			)}
diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx
index f5f128f80..ced97464e 100644
--- a/surfsense_web/components/settings/image-model-manager.tsx
+++ b/surfsense_web/components/settings/image-model-manager.tsx
@@ -190,7 +190,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
 						? "model"
 						: "models"}
 					{" "}
-					available from your administrator.
+					available from your administrator.{" "}
+					{(() => {
+						const nonAuto = globalConfigs.filter(
+							(g) => !("is_auto_mode" in g && g.is_auto_mode)
+						);
+						const premium = nonAuto.filter(
+							(g) =>
+								"billing_tier" in g &&
+								(g as { billing_tier?: string }).billing_tier === "premium"
+						).length;
+						const free = nonAuto.length - premium;
+						if (premium > 0 && free > 0) {
+							return `${premium} premium, ${free} free.`;
+						}
+						if (premium > 0) {
+							return `All ${premium} premium — debits your shared credit pool.`;
+						}
+						return `All ${free} free.`;
+					})()}

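The tier tally added to image-model-manager above reappears verbatim in vision-model-manager below; a shared helper along these lines (hypothetical, not part of the patch) could serve both call sites:

```ts
// Hypothetical extraction of the tally the two managers inline as an IIFE.
type GlobalConfigLike = { is_auto_mode?: boolean; billing_tier?: string };

function summarizeBillingTiers(globalConfigs: GlobalConfigLike[]): string {
	const nonAuto = globalConfigs.filter((g) => !g.is_auto_mode);
	const premium = nonAuto.filter((g) => g.billing_tier === "premium").length;
	const free = nonAuto.length - premium;
	if (premium > 0 && free > 0) return `${premium} premium, ${free} free.`;
	if (premium > 0) return `All ${premium} premium — debits your shared credit pool.`;
	return `All ${free} free.`;
}
```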
diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx
index e21dc9028..a2eb6a22e 100644
--- a/surfsense_web/components/settings/llm-role-manager.tsx
+++ b/surfsense_web/components/settings/llm-role-manager.tsx
@@ -371,6 +371,17 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
 				{roleGlobalConfigs.map((config) => {
 					const isAuto = "is_auto_mode" in config && config.is_auto_mode;
+					// Read billing_tier from the global config; default to "free"
+					// for legacy YAMLs / Auto stub. Premium gets a purple badge,
+					// free gets an emerald one — same palette as the chat
+					// model selector so the meaning is consistent across
+					// surfaces (issues E, H).
+					const billingTier =
+						("billing_tier" in config &&
+							typeof config.billing_tier === "string" &&
+							config.billing_tier) ||
+						"free";
+					const isPremium = billingTier === "premium";
 					return (
 							{config.name}
-							{isAuto && (
+							{isAuto ? (
 									Recommended
+							) : isPremium ? (
+									Premium
+							) : (
+									Free
 							)}
diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx
index 8abfa4774..886d71008 100644
--- a/surfsense_web/components/settings/vision-model-manager.tsx
+++ b/surfsense_web/components/settings/vision-model-manager.tsx
@@ -191,7 +191,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
 						? "model"
 						: "models"}
 					{" "}
-					available from your administrator.
+					available from your administrator.{" "}
+					{(() => {
+						const nonAuto = globalConfigs.filter(
+							(g) => !("is_auto_mode" in g && g.is_auto_mode)
+						);
+						const premium = nonAuto.filter(
+							(g) =>
+								"billing_tier" in g &&
+								(g as { billing_tier?: string }).billing_tier === "premium"
+						).length;
+						const free = nonAuto.length - premium;
+						if (premium > 0 && free > 0) {
+							return `${premium} premium, ${free} free.`;
+						}
+						if (premium > 0) {
+							return `All ${premium} premium — debits your shared credit pool.`;
+						}
+						return `All ${free} free.`;
+					})()}

diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx
index fad64fa9f..790e5c00e 100644
--- a/surfsense_web/contexts/login-gate.tsx
+++ b/surfsense_web/contexts/login-gate.tsx
@@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
 					Create a free account to {feature}
-					Get 3 million tokens, save chat history, upload documents, use all AI tools, and
-					connect 30+ integrations.
+					Get $5 of premium credit, save chat history, upload documents, use all AI tools,
+					and connect 30+ integrations.
diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts
index ecffc573e..2d6b70eda 100644
--- a/surfsense_web/contracts/types/new-llm-config.types.ts
+++ b/surfsense_web/contracts/types/new-llm-config.types.ts
@@ -258,6 +258,8 @@ export const globalImageGenConfig = z.object({
 	litellm_params: z.record(z.string(), z.any()).nullable().optional(),
 	is_global: z.literal(true),
 	is_auto_mode: z.boolean().optional().default(false),
+	billing_tier: z.string().default("free"),
+	quota_reserve_micros: z.number().nullable().optional(),
 });
 
 export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig);
@@ -338,6 +340,10 @@ export const globalVisionLLMConfig = z.object({
 	litellm_params: z.record(z.string(), z.any()).nullable().optional(),
 	is_global: z.literal(true),
 	is_auto_mode: z.boolean().optional().default(false),
+	billing_tier: z.string().default("free"),
+	quota_reserve_tokens: z.number().nullable().optional(),
+	input_cost_per_token: z.number().nullable().optional(),
+	output_cost_per_token: z.number().nullable().optional(),
 });
 
 export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig);
diff --git a/surfsense_web/contracts/types/stripe.types.ts b/surfsense_web/contracts/types/stripe.types.ts
index c8b017044..251f7a176 100644
--- a/surfsense_web/contracts/types/stripe.types.ts
+++ b/surfsense_web/contracts/types/stripe.types.ts
@@ -32,7 +32,7 @@ export const getPagePurchasesResponse = z.object({
 	purchases: z.array(pagePurchase),
 });
 
-// Premium token purchases
+// Premium credit purchases
 export const createTokenCheckoutSessionRequest = z.object({
 	quantity: z.number().int().min(1).max(100),
 	search_space_id: z.number().int().min(1),
@@ -42,11 +42,16 @@ export const createTokenCheckoutSessionResponse = z.object({
 	checkout_url: z.string(),
 });
 
+// Premium credit balance + purchase records.
+//
+// The unit is integer micro-USD (1_000_000 == $1.00). The schema names
+// kept the ``Token`` prefix for API back-compat with pinned clients;
+// the field names below are authoritative.
 export const tokenStripeStatusResponse = z.object({
 	token_buying_enabled: z.boolean(),
-	premium_tokens_used: z.number().default(0),
-	premium_tokens_limit: z.number().default(0),
-	premium_tokens_remaining: z.number().default(0),
+	premium_credit_micros_used: z.number().default(0),
+	premium_credit_micros_limit: z.number().default(0),
+	premium_credit_micros_remaining: z.number().default(0),
 });
 
 export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum;
@@ -56,7 +61,7 @@ export const tokenPurchase = z.object({
 	id: z.number(),
 	stripe_checkout_session_id: z.string(),
 	stripe_payment_intent_id: z.string().nullable(),
 	quantity: z.number(),
-	tokens_granted: z.number(),
+	credit_micros_granted: z.number(),
 	amount_total: z.number().nullable(),
 	currency: z.string().nullable(),
 	status: tokenPurchaseStatusEnum,
diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts
index 95d9848f2..1c67d59a1 100644
--- a/surfsense_web/lib/chat/chat-error-classifier.ts
+++ b/surfsense_web/lib/chat/chat-error-classifier.ts
@@ -41,7 +41,7 @@ export interface RawChatErrorInput {
 }
 
 export const PREMIUM_QUOTA_ASSISTANT_MESSAGE =
-	"I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue.";
+	"I can’t continue with the current premium model because your premium credit is exhausted. Switch to a free model or top up your credit to continue.";
 
 function getErrorMessage(error: unknown): string {
 	if (error instanceof Error) return error.message;
diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts
index 80e7bffbe..6df56f0ce 100644
--- a/surfsense_web/lib/chat/streaming-state.ts
+++ b/surfsense_web/lib/chat/streaming-state.ts
@@ -541,16 +541,23 @@ export type SSEEvent =
 			data: {
 				usage: Record<
 					string,
-					{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
+					{
+						prompt_tokens: number;
+						completion_tokens: number;
+						total_tokens: number;
+						cost_micros?: number;
+					}
 				>;
 				prompt_tokens: number;
 				completion_tokens: number;
 				total_tokens: number;
+				cost_micros?: number;
 				call_details: Array<{
 					model: string;
 					prompt_tokens: number;
 					completion_tokens: number;
 					total_tokens: number;
+					cost_micros?: number;
 				}>;
 			};
 	  }
diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts
index fc970c26e..7fec60a23 100644
--- a/surfsense_web/lib/chat/thread-persistence.ts
+++ b/surfsense_web/lib/chat/thread-persistence.ts
@@ -30,9 +30,20 @@ export interface TokenUsageSummary {
 	prompt_tokens: number;
 	completion_tokens: number;
 	total_tokens: number;
+	/**
+	 * Total provider USD cost for this assistant turn, in micro-USD
+	 * (1_000_000 = $1.00). Optional because rows persisted before the
+	 * cost-credits migration won't have it.
+	 */
+	cost_micros?: number;
 	model_breakdown?: Record<
 		string,
-		{ prompt_tokens: number; completion_tokens: number; total_tokens: number }
+		{
+			prompt_tokens: number;
+			completion_tokens: number;
+			total_tokens: number;
+			cost_micros?: number;
+		}
 	> | null;
 }
diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts
index 0e6234db5..f483fa9b4 100644
--- a/surfsense_web/zero/schema/user.ts
+++ b/surfsense_web/zero/schema/user.ts
@@ -1,11 +1,20 @@
 import { number, string, table } from "@rocicorp/zero";
 
+/**
+ * Live-meter slice of the ``user`` table replicated through Zero.
+ *
+ * ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored
+ * as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M
+ * when displaying. Sensitive fields (email, hashed_password, oauth, etc.)
+ * are intentionally omitted via the Postgres column-list publication so
+ * they never enter WAL replication.
+ */
 export const userTable = table("user")
 	.columns({
 		id: string(),
 		pagesLimit: number().from("pages_limit"),
 		pagesUsed: number().from("pages_used"),
-		premiumTokensLimit: number().from("premium_tokens_limit"),
-		premiumTokensUsed: number().from("premium_tokens_used"),
+		premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"),
+		premiumCreditMicrosUsed: number().from("premium_credit_micros_used"),
 	})
 	.primaryKey("id");
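Consumers of the replicated meter columns compute the live balance client-side; a minimal sketch, with the row shape inferred from the column list above and the Zero query wiring omitted:

```ts
// Sketch: deriving display values from the Zero-replicated meter columns.
type UserMeterRow = {
	id: string;
	pagesLimit: number;
	pagesUsed: number;
	premiumCreditMicrosLimit: number;
	premiumCreditMicrosUsed: number;
};

function remainingCreditDollars(row: UserMeterRow): number {
	const remaining = row.premiumCreditMicrosLimit - row.premiumCreditMicrosUsed;
	return Math.max(0, remaining) / 1_000_000; // e.g. 5_000_000 - 1_250_000 -> 3.75
}
```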