feat: unified credits and its cost calculations

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-02 14:34:23 -07:00
parent 451a98936e
commit ae9d36d77f
61 changed files with 5835 additions and 272 deletions

View file

@ -31,6 +31,7 @@ from app.config import (
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables()
initialize_openrouter_integration()
_start_openrouter_background_refresh()
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()

View file

@ -22,10 +22,12 @@ def init_worker(**kwargs):
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()

View file

@ -138,7 +138,11 @@ def load_global_image_gen_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
return data.get("global_image_generation_configs", [])
configs = data.get("global_image_generation_configs", []) or []
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e:
print(f"Warning: Failed to load global image generation configs: {e}")
return []
@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
return data.get("global_vision_llm_configs", [])
configs = data.get("global_vision_llm_configs", []) or []
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}")
return []
@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
"anonymous_enabled_free", settings["anonymous_enabled"]
)
# Image generation + vision LLM emission are opt-in (issue L).
# OpenRouter's catalogue contains hundreds of image / vision
# capable models; auto-injecting all of them into every
# deployment would explode the model selector and surprise
# operators upgrading from prior versions. Default to False so
# admins must explicitly turn them on.
settings.setdefault("image_generation_enabled", False)
settings.setdefault("vision_enabled", False)
return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
@ -296,10 +313,60 @@ def initialize_openrouter_integration():
)
else:
print("Info: OpenRouter integration enabled but no models fetched")
# Image generation + vision LLM emissions are opt-in (issue L).
# Both reuse the catalogue already cached by ``service.initialize``
# so we don't make additional network calls here.
if settings.get("image_generation_enabled"):
try:
image_configs = service.get_image_generation_configs()
if image_configs:
config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
print(
f"Info: OpenRouter integration added {len(image_configs)} "
f"image-generation models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
if settings.get("vision_enabled"):
try:
vision_configs = service.get_vision_llm_configs()
if vision_configs:
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
print(
f"Info: OpenRouter integration added {len(vision_configs)} "
f"vision LLM models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
def initialize_pricing_registration():
    """Register per-token pricing for every configured deployment with LiteLLM.

    Teaches LiteLLM the per-token cost of each deployment in
    ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled from the
    OpenRouter catalogue plus any operator-declared YAML pricing).

    Ordering: must run AFTER ``initialize_openrouter_integration()`` so the
    OpenRouter catalogue is populated, and BEFORE the first LLM call so
    ``response_cost`` is available in ``TokenTrackingCallback``.

    Failures are logged but never raised — startup must not be blocked by a
    missing pricing entry; the worst case is that the model debits 0.
    """
    # Import lazily so a broken/absent pricing module degrades to a warning
    # instead of failing application startup.
    try:
        from app.services.pricing_registration import (
            register_pricing_from_global_configs,
        )
    except Exception as exc:
        print(f"Warning: Failed to register LiteLLM pricing: {exc}")
        return

    try:
        register_pricing_from_global_configs()
    except Exception as exc:
        print(f"Warning: Failed to register LiteLLM pricing: {exc}")
def initialize_llm_router():
"""
Initialize the LLM Router service for Auto mode.
@ -444,14 +511,54 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
# Premium token quota settings
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
# Premium credit (micro-USD) quota settings.
#
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
# still honoured for one release as fall-back values — the prior
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
# to micros, so operators upgrading without changing their .env still
# get correct behaviour. A startup deprecation warning fires below if
# they're set.
PREMIUM_CREDIT_MICROS_LIMIT = int(
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
)
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
STRIPE_CREDIT_MICROS_PER_UNIT = int(
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
)
STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
)
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
# config's ``quota_reserve_tokens`` and clamps the result to this value
# so a misconfigured "$1000/M" model can't lock the user's whole balance
# on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
# reserve_tokens ≈ $0.36) with headroom.
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
"PREMIUM_CREDIT_MICROS_LIMIT"
):
print(
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
"current Stripe price). The old key will be removed in a "
"future release."
)
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
"STRIPE_CREDIT_MICROS_PER_UNIT"
):
print(
"Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
"The old key will be removed in a future release."
)
# Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
@ -464,6 +571,35 @@ class Config:
# Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
# Per-image reservation (in micro-USD) used by ``billable_call`` for the
# ``POST /image-generations`` endpoint when the global config does not
# override it. $0.05 covers realistic worst-cases for current OpenAI /
# OpenRouter image-gen pricing. Bypassed entirely for free configs.
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
)
# Per-podcast reservation (in micro-USD). One agent LLM call generating
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
# premium-model run. Tune via env.
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
)
# Per-video-presentation reservation (in micro-USD). Fan-out of N
# slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
# plus refine retries; can produce many premium completions. $1.00
# covers worst-case. Tune via env.
#
# NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
# 1_000_000. The override path in ``billable_call`` bypasses the
# per-call clamp in ``estimate_call_reserve_micros``, so this is the
# *actual* hold — raising it via env is fine but means a single video
# task can lock $1+ of credit.
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
)
# Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int(

View file

@ -19,6 +19,24 @@
# Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
#
# COST-BASED PREMIUM CREDITS:
# Each premium config bills the user's USD-credit balance based on the
# actual provider cost reported by LiteLLM. For models LiteLLM already
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
# or any model LiteLLM doesn't have in its built-in pricing table, declare
# per-token costs inline so they bill correctly:
#
# litellm_params:
# base_model: "my-custom-azure-deploy"
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
# input_cost_per_token: 0.000003
# output_cost_per_token: 0.000015
#
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
# API — no inline declaration needed. Models without resolvable pricing
# debit $0 from the user's balance and log a WARNING.
# Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models
@ -292,6 +310,17 @@ openrouter_integration:
free_rpm: 20
free_tpm: 100000
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
# contains hundreds of image- and vision-capable models; turning these on
# injects them into the global Image-Generation / Vision-LLM model
# selectors alongside any static configs. Tier (free/premium) is derived
# per model the same way it is for chat (`:free` suffix or zero pricing).
# When a user picks a premium image/vision model the call debits the
# shared $5 USD-cost-based premium credit pool — so leaving these off
# avoids surprise quota burn on existing deployments. Default: false.
image_generation_enabled: false
vision_enabled: false
litellm_params:
max_tokens: 16384
system_instructions: ""

View file

@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)
cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
model_breakdown = Column(JSONB, nullable=True)
call_details = Column(JSONB, nullable=True)
@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
class PremiumTokenPurchase(Base, TimestampMixin):
"""Tracks Stripe checkout sessions used to grant additional premium token credits."""
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
Note: the table name is preserved (``premium_token_purchases``) for
operational continuity even though the unit is now USD micro-credits
instead of raw tokens. The ``credit_micros_granted`` column replaced
the legacy ``tokens_granted`` in migration 140; the stored values
were not transformed because the prior $1 = 1M tokens Stripe price
makes the unit conversion 1:1 numerically.
"""
__tablename__ = "premium_token_purchases"
__allow_unmapped__ = True
@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
)
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
quantity = Column(Integer, nullable=False)
tokens_granted = Column(BigInteger, nullable=False)
credit_micros_granted = Column(BigInteger, nullable=False)
amount_total = Column(Integer, nullable=True)
currency = Column(String(10), nullable=True)
status = Column(
@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column(
premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
default=config.PREMIUM_TOKEN_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT),
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
premium_tokens_used = Column(
premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
premium_tokens_reserved = Column(
premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
@ -2241,16 +2250,16 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column(
premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
default=config.PREMIUM_TOKEN_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT),
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
premium_tokens_used = Column(
premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
premium_tokens_reserved = Column(
premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)

View file

@ -68,12 +68,25 @@ class EtlPipelineService:
etl_service="VISION_LLM",
content_type="image",
)
except Exception:
logging.warning(
"Vision LLM failed for %s, falling back to document parser",
request.filename,
exc_info=True,
)
except Exception as exc:
# Special-case quota exhaustion so we log a clearer message
# — the vision LLM didn't "fail", the user just ran out of
# premium credit. Falling through to the document parser
# is a graceful degradation: OCR/Unstructured still
# extracts text from the image without burning credit.
from app.services.billable_calls import QuotaInsufficientError
if isinstance(exc, QuotaInsufficientError):
logging.info(
"Vision LLM quota exhausted for %s; falling back to document parser",
request.filename,
)
else:
logging.warning(
"Vision LLM failed for %s, falling back to document parser",
request.filename,
exc_info=True,
)
else:
logging.info(
"No vision LLM provided, falling back to document parser for %s",

View file

@ -36,6 +36,11 @@ from app.schemas import (
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.billable_calls import (
DEFAULT_IMAGE_RESERVE_MICROS,
QuotaInsufficientError,
billable_call,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
@ -92,6 +97,50 @@ def _build_model_string(
return f"{prefix}/{model_name}"
async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.

    The resolution mirrors ``_execute_image_generation``'s lookup tree but
    only extracts the fields needed for billing — we do this *before*
    ``billable_call`` so the reservation is correctly sized for the
    config that will actually run, and so we don't open an
    ``ImageGeneration`` row for a request that's about to 402.

    User-owned (positive ID) BYOK configs are always free — they cost
    the user nothing on our side. Auto mode is currently treated as free
    because the underlying router can dispatch to either premium or
    free YAML configs and we don't surface the resolved deployment up
    here yet. Bringing Auto under premium billing would require
    threading the chosen deployment back from ``ImageGenRouterService``.
    """
    # NOTE(review): ``session`` is not read anywhere in this body —
    # presumably kept for signature parity with sibling resolvers; confirm.
    resolved_id = config_id
    if resolved_id is None:
        # No explicit config on the request: fall back to the search space's
        # pinned config, then to Auto mode.
        resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID

    if is_image_gen_auto_mode(resolved_id):
        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)

    if resolved_id < 0:
        # Negative IDs address global YAML configs. A missing entry resolves
        # to an empty dict so the "free"/default fallbacks below apply.
        cfg = _get_global_image_gen_config(resolved_id) or {}
        billing_tier = str(cfg.get("billing_tier", "free")).lower()
        base_model = _build_model_string(
            cfg.get("provider", ""),
            cfg.get("model_name", ""),
            cfg.get("custom_provider"),
        )
        # Per-config override wins; ``or`` also replaces an explicit 0/None
        # with the global default reservation.
        reserve_micros = int(
            cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
        )
        return (billing_tier, base_model, reserve_micros)

    # Positive ID = user-owned BYOK image-gen config — always free.
    return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
async def _execute_image_generation(
session: AsyncSession,
image_gen: ImageGeneration,
@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free",
}
)
@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
@ -454,7 +508,26 @@ async def create_image_generation(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Create and execute an image generation request."""
"""Create and execute an image generation request.
Premium configs are gated by the user's shared premium credit pool.
The flow is:
1. Permission check + load the search space (cheap, no provider call).
2. Resolve which config will run so we know its billing tier and the
worst-case reservation size *before* opening any DB rows.
3. Wrap the entire ImageGeneration row insert + provider call in
``billable_call``. If quota is denied, ``billable_call`` raises
``QuotaInsufficientError`` *before* we flush a row, which we
translate to HTTP 402 (no orphaned rows on the user's account,
no inserted error rows for "you ran out of credit").
4. On success, the actual ``response_cost`` flows through the
LiteLLM callback into the accumulator, and ``billable_call``
finalizes the debit at exit. Inner ``try/except`` still catches
provider errors and stores them on ``error_message`` (HTTP 200
with ``error_message`` set is preserved for failed-but-not-quota
scenarios clients already know how to surface those).
"""
try:
await check_permission(
session,
@ -471,33 +544,70 @@ async def create_image_generation(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
db_image_gen = ImageGeneration(
prompt=data.prompt,
model=data.model,
n=data.n,
quality=data.quality,
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
session, data.image_generation_config_id, search_space
)
session.add(db_image_gen)
await session.flush()
try:
await _execute_image_generation(session, db_image_gen, search_space)
except Exception as e:
logger.exception("Image generation call failed")
db_image_gen.error_message = str(e)
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
# propagates to the outer ``except QuotaInsufficientError`` handler
# below as HTTP 402 — it is intentionally NOT swallowed into
# ``error_message`` because that would (1) imply a successful row
# exists when none does, and (2) return HTTP 200 to a client
# whose request was actively *denied* (issue K).
async with billable_call(
user_id=search_space.user_id,
search_space_id=data.search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=reserve_micros,
usage_type="image_generation",
call_details={"model": base_model, "prompt": data.prompt[:100]},
):
db_image_gen = ImageGeneration(
prompt=data.prompt,
model=data.model,
n=data.n,
quality=data.quality,
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
)
session.add(db_image_gen)
await session.flush()
await session.commit()
await session.refresh(db_image_gen)
return db_image_gen
try:
await _execute_image_generation(session, db_image_gen, search_space)
except Exception as e:
logger.exception("Image generation call failed")
db_image_gen.error_message = str(e)
await session.commit()
await session.refresh(db_image_gen)
return db_image_gen
except HTTPException:
raise
except QuotaInsufficientError as exc:
# The user's premium credit pool is empty. No DB row is created
# because ``billable_call`` denies before yielding (issue K).
await session.rollback()
raise HTTPException(
status_code=402,
detail={
"error_code": "premium_quota_exhausted",
"usage_type": exc.usage_type,
"used_micros": exc.used_micros,
"limit_micros": exc.limit_micros,
"remaining_micros": exc.remaining_micros,
"message": (
"Out of premium credits for image generation. "
"Purchase additional credits or switch to a free model."
),
},
) from exc
except SQLAlchemyError:
await session.rollback()
raise HTTPException(

View file

@ -1366,7 +1366,11 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Persist token usage if provided (for assistant messages)
# Persist token usage if provided (for assistant messages).
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
# forwarded by the FE through the appendMessage round-trip so
# the historical TokenUsage row matches the credit debit applied
# at finalize time.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
@ -1377,6 +1381,7 @@ async def append_message(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,

View file

@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None
@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None

View file

@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
metadata = _get_metadata(checkout_session)
user_id = metadata.get("user_id")
quantity = int(metadata.get("quantity", "0"))
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
# Read the new metadata key first, fall back to the legacy one so
# in-flight checkout sessions created before the cost-credits
# release still fulfil correctly (the unit is numerically the
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
credit_micros_per_unit = int(
metadata.get("credit_micros_per_unit")
or metadata.get("tokens_per_unit", "0")
)
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
logger.error(
"Skipping token fulfillment for session %s: incomplete metadata %s",
checkout_session_id,
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
getattr(checkout_session, "payment_intent", None)
),
quantity=quantity,
tokens_granted=quantity * tokens_per_unit,
credit_micros_granted=quantity * credit_micros_per_unit,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
purchase.stripe_payment_intent_id = _normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
)
user.premium_tokens_limit = (
max(user.premium_tokens_used, user.premium_tokens_limit)
+ purchase.tokens_granted
# Top up the user's credit balance by the granted micro-USD amount.
# ``max(used, limit)`` clamps the case where the legacy code wrote a
# used value above the limit (e.g. underbilling rounding) so adding
# ``credit_micros_granted`` always lifts the limit by the full pack
# size rather than disappearing into past overuse.
user.premium_credit_micros_limit = (
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+ purchase.credit_micros_granted
)
await db_session.commit()
@ -532,12 +544,18 @@ async def create_token_checkout_session(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
"""Create a Stripe Checkout Session for buying premium token packs."""
"""Create a Stripe Checkout Session for buying premium credit packs.
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
credit (default 1_000_000 = $1.00). The user's balance is debited
at the actual provider cost reported by LiteLLM at finalize time,
so $1 of credit always buys $1 worth of provider usage at cost.
"""
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
@ -556,8 +574,8 @@ async def create_token_checkout_session(
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
"purchase_type": "premium_tokens",
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
"purchase_type": "premium_credit",
},
}
)
@ -583,7 +601,7 @@ async def create_token_checkout_session(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
tokens_granted=tokens_granted,
credit_micros_granted=credit_micros_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@ -598,14 +616,19 @@ async def create_token_checkout_session(
async def get_token_status(
user: User = Depends(current_active_user),
):
"""Return token-buying availability and current premium quota for frontend."""
used = user.premium_tokens_used
limit = user.premium_tokens_limit
"""Return token-buying availability and current premium credit quota for frontend.
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
when displaying. The route name is preserved for back-compat with
pinned client deployments.
"""
used = user.premium_credit_micros_used
limit = user.premium_credit_micros_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
premium_tokens_used=used,
premium_tokens_limit=limit,
premium_tokens_remaining=max(0, limit - used),
premium_credit_micros_used=used,
premium_credit_micros_limit=limit,
premium_credit_micros_remaining=max(0, limit - used),
)

View file

@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free",
}
)
@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"),
}
)

View file

@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
Schema for reading global image generation configs from YAML.
Global configs have negative IDs. API key is hidden.
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
The ``billing_tier`` field allows the frontend to show a Premium/Free
badge and (more importantly) tells the backend whether to debit the
user's premium credit pool when this config is used. ``"free"`` is
the default for backward compatibility — admins must explicitly opt
a global config into ``"premium"``.
"""
id: int = Field(
@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
billing_tier: str = Field(
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
quota_reserve_micros: int | None = Field(
default=None,
description=(
"Optional override for the reservation amount (in micro-USD) used when "
"this image generation is premium. Falls back to "
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
),
)

View file

@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
cost_micros: int = 0
model_breakdown: dict | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
class TokenPurchaseRead(BaseModel):
"""Serialized premium token purchase record."""
"""Serialized premium credit purchase record.
``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
schema name kept ``Token`` for API back-compat with pinned clients.
"""
id: uuid.UUID
stripe_checkout_session_id: str
stripe_payment_intent_id: str | None = None
quantity: int
tokens_granted: int
credit_micros_granted: int
amount_total: int | None = None
currency: str | None = None
status: str
@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
class TokenPurchaseHistoryResponse(BaseModel):
"""Response containing the user's premium token purchases."""
"""Response containing the user's premium credit purchases."""
purchases: list[TokenPurchaseRead]
class TokenStripeStatusResponse(BaseModel):
"""Response describing token-buying availability and current quota."""
"""Response describing premium-credit-buying availability and balance.
All ``premium_credit_micros_*`` fields are in micro-USD; the FE
divides by 1_000_000 to display USD.
"""
token_buying_enabled: bool
premium_tokens_used: int = 0
premium_tokens_limit: int = 0
premium_tokens_remaining: int = 0
premium_credit_micros_used: int = 0
premium_credit_micros_limit: int = 0
premium_credit_micros_remaining: int = 0

View file

@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
class GlobalVisionLLMConfigRead(BaseModel):
"""Schema for reading global vision LLM configs from YAML.
The ``billing_tier`` field allows the frontend to show a Premium/Free
badge and (more importantly) tells the backend whether to debit the
user's premium credit pool when this config is used. ``"free"`` is
the default for backward compatibility — admins must explicitly opt
a global config into ``"premium"``.
"""
id: int = Field(...)
name: str
description: str | None = None
@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
billing_tier: str = Field(
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
quota_reserve_tokens: int | None = Field(
default=None,
description=(
"Optional override for the per-call reservation in *tokens* — "
"converted to micro-USD via the model's input/output prices at "
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
),
)
input_cost_per_token: float | None = Field(
default=None,
description=(
"Optional input price in USD/token. Used by pricing_registration to "
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
),
)
output_cost_per_token: float | None = Field(
default=None,
description="Optional output price in USD/token. Pair with input_cost_per_token.",
)

View file

@ -0,0 +1,430 @@
"""
Per-call billable wrapper for image generation, vision LLM extraction, and
any other short-lived premium operation that must charge against the user's
shared premium credit pool.
The ``billable_call`` async context manager encapsulates the standard
"reserve → execute → finalize / release → record audit row" lifecycle in a
single primitive so callers (the image-generation REST route and the
vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issue A, B):
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
insert each run inside their own ``shielded_async_session()``. This
guarantees that a quota commit/rollback can never accidentally flush or
roll back rows the caller has staged in the request's main session
(e.g. a freshly-created ``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
nested ``billable_call`` inside an outer chat turn cannot corrupt the
chat turn's accumulator.
3. **Free configs are still audited.** Free calls bypass the reserve /
finalize dance entirely but still record a ``TokenUsage`` audit row with
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
pipeline complete for analytics even when nothing is debited.
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
responsible for translating that into HTTP 402. We *do not* catch the
denial inside ``billable_call`` letting it propagate also prevents
the image-generation route from creating an ``ImageGeneration`` row
for a request that never actually ran.
"""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import shielded_async_session
from app.services.token_quota_service import (
TokenQuotaService,
estimate_call_reserve_micros,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
record_token_usage,
scoped_turn,
)
logger = logging.getLogger(__name__)
class QuotaInsufficientError(Exception):
    """Premium credit pool exhausted — a billable call was denied.

    Raised when ``TokenQuotaService.premium_reserve`` refuses to reserve
    credit for a billable call. HTTP route handlers should translate this
    into 402 Payment Required (or the surface's equivalent); non-HTTP
    callers (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during
    indexing) may instead catch it and degrade gracefully — e.g. fall back
    to OCR when vision is unavailable.

    All ``*_micros`` attributes are micro-USD (1_000_000 == $1.00).
    """

    def __init__(
        self,
        *,
        usage_type: str,
        used_micros: int,
        limit_micros: int,
        remaining_micros: int,
    ) -> None:
        message = (
            f"Premium credit exhausted for {usage_type}: "
            f"used={used_micros} limit={limit_micros} "
            f"remaining={remaining_micros} (micro-USD)"
        )
        super().__init__(message)
        # Kept as attributes so the HTTP layer can build a structured
        # 402 payload without re-parsing the message string.
        self.usage_type = usage_type
        self.used_micros = used_micros
        self.limit_micros = limit_micros
        self.remaining_micros = remaining_micros
@asynccontextmanager
async def billable_call(
    *,
    user_id: UUID,
    search_space_id: int,
    billing_tier: str,
    base_model: str,
    quota_reserve_tokens: int | None = None,
    quota_reserve_micros_override: int | None = None,
    usage_type: str,
    thread_id: int | None = None,
    message_id: int | None = None,
    call_details: dict[str, Any] | None = None,
) -> AsyncIterator[TurnTokenAccumulator]:
    """Wrap a single billable LLM/image call.

    Lifecycle: reserve → execute (caller's body) → finalize / release →
    record audit row. Free-tier calls skip reserve/finalize but are still
    audited.

    Args:
        user_id: Owner of the credit pool to debit. For vision-LLM during
            indexing this is the *search-space owner* (issue M), not the
            triggering user.
        search_space_id: Required — recorded on the ``TokenUsage`` audit row.
        billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
            the reserve/finalize dance but still records an audit row with
            the captured cost.
        base_model: Used by :func:`estimate_call_reserve_micros` to compute
            a worst-case reservation from LiteLLM's pricing table.
        quota_reserve_tokens: Optional per-config override for the chat-style
            reserve estimator (vision LLM uses this).
        quota_reserve_micros_override: Optional flat micro-USD reservation
            (image generation uses this — its cost shape is per-image, not
            per-token).
        usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
            Recorded on the ``TokenUsage`` row.
        thread_id: Optional FK column on ``TokenUsage``.
        message_id: Optional FK column on ``TokenUsage``.
        call_details: Optional per-call metadata (model name, parameters)
            forwarded to ``record_token_usage``.

    Yields:
        The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
        the underlying LLM/image API while inside the ``async with``; the
        ``TokenTrackingCallback`` populates the accumulator automatically.

    Raises:
        QuotaInsufficientError: when premium and ``premium_reserve`` denies.
    """
    # Anything other than the literal string "premium" is treated as free.
    is_premium = billing_tier == "premium"

    # scoped_turn() installs a fresh accumulator via ContextVar set/reset,
    # so a nested billable_call cannot corrupt an outer chat turn's totals.
    async with scoped_turn() as acc:
        # ---------- Free path: just audit -------------------------------
        if not is_premium:
            try:
                yield acc
            finally:
                # Always audit, even on exception, so we capture cost when
                # provider returns successfully but the caller raises later.
                try:
                    async with shielded_async_session() as audit_session:
                        await record_token_usage(
                            audit_session,
                            usage_type=usage_type,
                            search_space_id=search_space_id,
                            user_id=user_id,
                            prompt_tokens=acc.total_prompt_tokens,
                            completion_tokens=acc.total_completion_tokens,
                            total_tokens=acc.grand_total,
                            cost_micros=acc.total_cost_micros,
                            model_breakdown=acc.per_message_summary(),
                            call_details=call_details,
                            thread_id=thread_id,
                            message_id=message_id,
                        )
                        await audit_session.commit()
                except Exception:
                    # Audit failure must never mask the caller's result.
                    logger.exception(
                        "[billable_call] free-path audit insert failed for "
                        "usage_type=%s user_id=%s",
                        usage_type,
                        user_id,
                    )
            return

        # ---------- Premium path: reserve → execute → finalize ----------
        if quota_reserve_micros_override is not None:
            # Flat per-call reservation (image gen). Clamped to >= 1
            # micro-USD so a zero/negative override can't bypass the gate.
            reserve_micros = max(1, int(quota_reserve_micros_override))
        else:
            reserve_micros = estimate_call_reserve_micros(
                base_model=base_model or "",
                quota_reserve_tokens=quota_reserve_tokens,
            )

        # Correlates this reserve with its finalize in the quota service.
        request_id = str(uuid4())

        # Reserve in its own shielded session (session-isolation invariant:
        # quota commits never touch the caller's request session).
        async with shielded_async_session() as quota_session:
            reserve_result = await TokenQuotaService.premium_reserve(
                db_session=quota_session,
                user_id=user_id,
                request_id=request_id,
                reserve_micros=reserve_micros,
            )

        if not reserve_result.allowed:
            logger.info(
                "[billable_call] reserve DENIED user=%s usage_type=%s "
                "reserve=%d used=%d limit=%d remaining=%d",
                user_id,
                usage_type,
                reserve_micros,
                reserve_result.used,
                reserve_result.limit,
                reserve_result.remaining,
            )
            # Deliberately NOT caught here — the route handler maps this
            # to HTTP 402 before any downstream rows are created.
            raise QuotaInsufficientError(
                usage_type=usage_type,
                used_micros=reserve_result.used,
                limit_micros=reserve_result.limit,
                remaining_micros=reserve_result.remaining,
            )

        logger.info(
            "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
            "(remaining=%d)",
            user_id,
            usage_type,
            reserve_micros,
            reserve_result.remaining,
        )

        try:
            yield acc
        except BaseException:
            # Release on any failure (including QuotaInsufficientError raised
            # from a downstream call, asyncio cancellation, etc.). We use
            # BaseException so cancellation also releases.
            try:
                async with shielded_async_session() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] premium_release failed for user=%s "
                    "reserve_micros=%d (reservation will be GC'd by quota "
                    "reconciliation if/when implemented)",
                    user_id,
                    reserve_micros,
                )
            raise

        # ---------- Success: finalize + audit ----------------------------
        # Actual cost as reported by the tracking callback (micro-USD).
        actual_micros = acc.total_cost_micros
        try:
            async with shielded_async_session() as quota_session:
                final_result = await TokenQuotaService.premium_finalize(
                    db_session=quota_session,
                    user_id=user_id,
                    request_id=request_id,
                    actual_micros=actual_micros,
                    reserved_micros=reserve_micros,
                )
            logger.info(
                "[billable_call] finalize user=%s usage_type=%s actual=%d "
                "reserved=%d → used=%d/%d (remaining=%d)",
                user_id,
                usage_type,
                actual_micros,
                reserve_micros,
                final_result.used,
                final_result.limit,
                final_result.remaining,
            )
        except Exception:
            # Last-ditch: if finalize itself fails, we must at least release
            # so the reservation doesn't leak.
            logger.exception(
                "[billable_call] premium_finalize failed for user=%s; "
                "attempting release",
                user_id,
            )
            try:
                async with shielded_async_session() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] release after finalize failure ALSO failed "
                    "for user=%s",
                    user_id,
                )

        # Audit row in its own session; a failure here is logged but the
        # debit above stands (hence the "(debit was applied)" marker).
        try:
            async with shielded_async_session() as audit_session:
                await record_token_usage(
                    audit_session,
                    usage_type=usage_type,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    prompt_tokens=acc.total_prompt_tokens,
                    completion_tokens=acc.total_completion_tokens,
                    total_tokens=acc.grand_total,
                    cost_micros=actual_micros,
                    model_breakdown=acc.per_message_summary(),
                    call_details=call_details,
                    thread_id=thread_id,
                    message_id=message_id,
                )
                await audit_session.commit()
        except Exception:
            logger.exception(
                "[billable_call] premium-path audit insert failed for "
                "usage_type=%s user_id=%s (debit was applied)",
                usage_type,
                user_id,
            )
async def _resolve_agent_billing_for_search_space(
    session: AsyncSession,
    search_space_id: int,
    *,
    thread_id: int | None = None,
) -> tuple[UUID, str, str]:
    """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
    agent LLM.

    Used by Celery tasks (podcast generation, video presentation) to bill the
    search-space owner's premium credit pool when the agent LLM is premium.

    Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:

    - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
    - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
      * ``thread_id`` is set: delegate to
        ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
        recurse into the resolved id. Reuses chat's existing pin if present
        so the same model bills for chat + downstream podcast/video. If the
        user is not premium-eligible, the pin service auto-restricts to free
        deployments — denial only happens later in
        ``billable_call.premium_reserve`` if the pin really is premium and
        credit ran out mid-flow.
      * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
        for any future direct-API path; today both Celery tasks always pass
        ``thread_id``.
    - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
      (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
      ``base_model = litellm_params.get("base_model") or model_name`` —
      NOT provider-prefixed, matching chat's cost-map lookup convention.
    - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
      ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
      ``base_model`` from ``litellm_params`` or ``model_name``.

    Note on imports: ``llm_service``, ``auto_model_pin_service``, and
    ``llm_router_service`` are imported lazily inside the function body to
    avoid hoisting litellm side-effects (``litellm.callbacks =
    [token_tracker]``, ``litellm.drop_params``, etc.) into
    ``billable_calls.py``'s module load path.

    Raises:
        ValueError: unknown search space, or no ``agent_llm_id`` configured.
    """
    from sqlalchemy import select

    from app.db import NewLLMConfig, SearchSpace

    result = await session.execute(
        select(SearchSpace).where(SearchSpace.id == search_space_id)
    )
    search_space = result.scalars().first()
    if search_space is None:
        raise ValueError(f"Search space {search_space_id} not found")

    agent_llm_id = search_space.agent_llm_id
    if agent_llm_id is None:
        raise ValueError(
            f"Search space {search_space_id} has no agent_llm_id configured"
        )

    # Billing always targets the search-space OWNER, not the task trigger.
    owner_user_id: UUID = search_space.user_id

    from app.services.auto_model_pin_service import (
        AUTO_FASTEST_ID,
        resolve_or_get_pinned_llm_config_id,
    )

    if agent_llm_id == AUTO_FASTEST_ID:
        if thread_id is None:
            # No thread → no pin to reuse; treat as free (see docstring).
            return owner_user_id, "free", "auto"
        try:
            resolution = await resolve_or_get_pinned_llm_config_id(
                session,
                thread_id=thread_id,
                search_space_id=search_space_id,
                user_id=str(owner_user_id),
                selected_llm_config_id=AUTO_FASTEST_ID,
            )
        except ValueError:
            # Pin resolution is best-effort for billing purposes; fall back
            # to free rather than failing the whole Celery task.
            logger.warning(
                "[agent_billing] Auto-mode pin resolution failed for "
                "search_space=%s thread=%s; falling back to free",
                search_space_id,
                thread_id,
                exc_info=True,
            )
            return owner_user_id, "free", "auto"
        # Continue classification with the concrete pinned config id.
        agent_llm_id = resolution.resolved_llm_config_id

    if agent_llm_id < 0:
        # Global YAML / dynamic OpenRouter config (negative-id namespace).
        from app.services.llm_service import get_global_llm_config

        cfg = get_global_llm_config(agent_llm_id) or {}
        billing_tier = str(cfg.get("billing_tier", "free")).lower()
        litellm_params = cfg.get("litellm_params") or {}
        base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
        return owner_user_id, billing_tier, base_model

    # Positive id → user-owned BYOK config; always billed as free.
    nlc_result = await session.execute(
        select(NewLLMConfig).where(
            NewLLMConfig.id == agent_llm_id,
            NewLLMConfig.search_space_id == search_space_id,
        )
    )
    nlc = nlc_result.scalars().first()
    base_model = ""
    if nlc is not None:
        litellm_params = nlc.litellm_params or {}
        base_model = litellm_params.get("base_model") or nlc.model_name or ""
    return owner_user_id, "free", base_model
# Re-export the config knob so callers don't have to import config just for
# the default image reserve.
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS

# Fix: DEFAULT_IMAGE_RESERVE_MICROS is documented above as a re-export for
# callers, so it must be part of the public API declared by __all__
# (otherwise `from ... import *` consumers never see it).
__all__ = [
    "DEFAULT_IMAGE_RESERVE_MICROS",
    "QuotaInsufficientError",
    "_resolve_agent_billing_for_search_space",
    "billable_call",
]

View file

@ -134,42 +134,16 @@ PROVIDER_MAP = {
}
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
# request to an Azure endpoint, which then 404s with ``Resource not found``.
# Only providers with a well-known, stable public base URL are listed here —
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
# so their existing config-driven behaviour is preserved.
PROVIDER_DEFAULT_API_BASE = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
# call sites can share the exact same defense (OpenRouter / Groq / etc.
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
class LLMRouterService:
@ -466,14 +440,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
# requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
api_base = config.get("api_base")
if not api_base:
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
if not api_base:
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base

View file

@ -496,8 +496,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
so each ``ainvoke`` debits the search-space owner's premium credit
pool. User-owned BYOK configs and free global configs are returned
unwrapped they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
@ -519,6 +525,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
owner_user_id = search_space.user_id
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@ -526,6 +534,13 @@ async def get_vision_llm(
)
return None
try:
# Auto mode is currently treated as free at the wrapper
# level — the underlying router can dispatch to either
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
@ -562,8 +577,21 @@ async def get_vision_llm(
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
if billing_tier == "premium":
return QuotaCheckedVisionLLM(
inner_llm,
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=model_string,
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
)
return inner_llm
# User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,

View file

@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
def _is_image_output_model(model: dict) -> bool:
"""Return True if the model can produce image output.
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
``["image"]`` for pure image generators, ``["text", "image"]`` for
multi-modal generators that also emit captions). We accept any model
that can output images; the call site decides whether to use the
image-generation API or chat completion.
"""
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
return "image" in output_mods
def _is_vision_input_model(model: dict) -> bool:
"""Return True if the model can ingest an image AND emit text.
OpenRouter's ``architecture.input_modalities`` lists what the model
accepts; ``output_modalities`` lists what it produces. A vision LLM
is a model that takes images in and produces text out i.e. it can
answer questions about a screenshot or extract content from an
image. Pure image-to-image models (e.g. style transfer) and
text-only models are excluded.
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
output_mods = arch.get("output_modalities", []) or []
return "image" in input_mods and "text" in output_mods
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
return None
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
Pricing values are kept as the raw OpenRouter strings (e.g.
``"0.000003"``); ``pricing_registration`` converts them to floats
when registering with LiteLLM. Models with missing or malformed
pricing are simply omitted — an operator-side risk if any of those are
premium.
"""
pricing: dict[str, dict[str, str]] = {}
for model in raw_models:
model_id = str(model.get("id") or "").strip()
if not model_id:
continue
p = model.get("pricing") or {}
prompt = p.get("prompt")
completion = p.get("completion")
if prompt is None and completion is None:
continue
pricing[model_id] = {
"prompt": str(prompt) if prompt is not None else "",
"completion": str(completion) if completion is not None else "",
}
return pricing
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
@ -282,6 +337,162 @@ def _generate_configs(
return configs
# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
# These are the fallbacks consumed by _generate_image_gen_configs /
# _generate_vision_llm_configs when settings omit "image_id_offset" /
# "vision_id_offset".
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
def _generate_image_gen_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Build global image-gen config dicts from OpenRouter image models.

    Output matches the YAML shape consumed by ``image_generation_routes``.
    Eligibility: the model must emit images (``output_modalities``), come
    from a compatible provider, pass the allow-list, and have a namespaced
    id (``vendor/model``). The chat-only filters (``_supports_tool_calling``
    and ``_has_sufficient_context``) are deliberately dropped — tool calls
    and context windows are irrelevant for the ``aimage_generation`` API.
    ``billing_tier`` is derived per model the same way as chat
    (``_openrouter_tier``).

    Cost is intentionally *not* registered with LiteLLM at startup
    (``pricing_registration`` skips image gen): OpenRouter image-gen models
    are not in LiteLLM's native cost map and OpenRouter populates
    ``response_cost`` directly from the response header. A defensive branch
    in ``_extract_cost_usd`` handles the rare case where ``usage.cost`` is
    missing — see ``token_tracking_service``.
    """
    id_offset: int = int(
        settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
    )
    api_key: str = settings.get("api_key", "")
    rpm: int = settings.get("rpm", 200)
    free_rpm: int = settings.get("free_rpm", 20)
    litellm_params: dict = settings.get("litellm_params") or {}

    def _eligible(candidate: dict) -> bool:
        # Same gate as the chat emitter minus the chat-only filters.
        return (
            _is_image_output_model(candidate)
            and _is_compatible_provider(candidate)
            and _is_allowed_model(candidate)
            and "/" in candidate.get("id", "")
        )

    taken: set[int] = set()
    configs: list[dict] = []
    for model in raw_models:
        if not _eligible(model):
            continue
        model_id: str = model["id"]
        display_name: str = model.get("name", model_id)
        tier = _openrouter_tier(model)
        configs.append(
            {
                "id": _stable_config_id(model_id, id_offset, taken),
                "name": display_name,
                "description": f"{display_name} via OpenRouter (image generation)",
                "provider": "OPENROUTER",
                "model_name": model_id,
                "api_key": api_key,
                "api_base": "",
                "api_version": None,
                # Free-tier models get the throttled RPM bucket.
                "rpm": free_rpm if tier == "free" else rpm,
                "litellm_params": dict(litellm_params),
                "billing_tier": tier,
                _OPENROUTER_DYNAMIC_MARKER: True,
            }
        )
    return configs
def _generate_vision_llm_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Build global vision-LLM config dicts from OpenRouter vision models.

    Output matches the YAML shape consumed by ``vision_llm_routes``.
    Eligibility: the model must accept image input AND emit text output,
    come from a compatible provider, pass the allow-list, and have a
    namespaced id (``vendor/model``). Vision-LLM is invoked from the
    indexer (image extraction during document upload) via
    ``langchain_litellm.ChatLiteLLM.ainvoke``, so the chat-only
    ``_supports_tool_calling`` and ``_has_sufficient_context`` filters do
    not apply: a small-context vision model without tool-calling is still
    perfectly viable for "describe this image" prompts.

    Per-token OpenRouter prices are copied onto each config
    (``input_cost_per_token`` / ``output_cost_per_token``) so
    ``pricing_registration`` can register them with LiteLLM at startup and
    ``estimate_call_reserve_micros`` can resolve them at reserve time.
    Malformed or zero prices normalise to ``None``.
    """
    id_offset: int = int(
        settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
    )
    api_key: str = settings.get("api_key", "")
    rpm: int = settings.get("rpm", 200)
    tpm: int = settings.get("tpm", 1_000_000)
    free_rpm: int = settings.get("free_rpm", 20)
    free_tpm: int = settings.get("free_tpm", 100_000)
    quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
    litellm_params: dict = settings.get("litellm_params") or {}

    def _price(raw: Any) -> float:
        # OpenRouter publishes prices as strings; unset or malformed
        # values become 0.0 (normalised to None on the config below).
        try:
            return float(raw or 0)
        except (TypeError, ValueError):
            return 0.0

    def _eligible(candidate: dict) -> bool:
        return (
            _is_vision_input_model(candidate)
            and _is_compatible_provider(candidate)
            and _is_allowed_model(candidate)
            and "/" in candidate.get("id", "")
        )

    taken: set[int] = set()
    configs: list[dict] = []
    for model in raw_models:
        if not _eligible(model):
            continue
        model_id: str = model["id"]
        display_name: str = model.get("name", model_id)
        tier = _openrouter_tier(model)
        pricing = model.get("pricing") or {}
        input_cost = _price(pricing.get("prompt", 0))
        output_cost = _price(pricing.get("completion", 0))
        configs.append(
            {
                "id": _stable_config_id(model_id, id_offset, taken),
                "name": display_name,
                "description": f"{display_name} via OpenRouter (vision)",
                "provider": "OPENROUTER",
                "model_name": model_id,
                "api_key": api_key,
                "api_base": "",
                "api_version": None,
                # Free-tier models get the throttled RPM/TPM buckets.
                "rpm": free_rpm if tier == "free" else rpm,
                "tpm": free_tpm if tier == "free" else tpm,
                "litellm_params": dict(litellm_params),
                "billing_tier": tier,
                "quota_reserve_tokens": quota_reserve_tokens,
                "input_cost_per_token": input_cost or None,
                "output_cost_per_token": output_cost or None,
                _OPENROUTER_DYNAMIC_MARKER: True,
            }
        )
    return configs
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
# Raw OpenRouter pricing per model_id, captured at the same time
# we generate configs. Consumed by ``pricing_registration`` to
# teach LiteLLM the per-token cost of every dynamic deployment so
# the success-callback can populate ``response_cost`` correctly.
self._raw_pricing: dict[str, dict[str, str]] = {}
# Cached raw catalogue from the most recent fetch. Image / vision
# emitters reuse this to avoid a second network call per surface.
self._raw_models: list[dict] = []
# Image / vision config caches (only populated when the matching
# opt-in flag is true on initialize). Refreshed in lockstep with
# the chat catalogue.
self._image_configs: list[dict] = []
self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
self._initialized = True
return []
self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
# Populate image / vision caches when their opt-in flag is set.
# Empty otherwise so the accessors return [] without re-running
# filters every refresh.
if settings.get("image_generation_enabled"):
self._image_configs = _generate_image_gen_configs(raw_models, settings)
logger.info(
"OpenRouter integration: image-gen emission ON (%d models)",
len(self._image_configs),
)
else:
self._image_configs = []
if settings.get("vision_enabled"):
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
logger.info(
"OpenRouter integration: vision LLM emission ON (%d models)",
len(self._vision_configs),
)
else:
self._vision_configs = []
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
self._raw_models = raw_models
from app.config import config as app_config
@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
# Image / vision lists are atomic-swapped the same way: filter out
# the previous dynamic entries from the live config list and append
# the freshly generated ones. No-ops when the opt-in flag is off.
if self._settings.get("image_generation_enabled"):
new_image = _generate_image_gen_configs(raw_models, self._settings)
static_image = [
c
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
self._image_configs = new_image
if self._settings.get("vision_enabled"):
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
static_vision = [
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
self._vision_configs = new_vision
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
# Re-register LiteLLM pricing for the freshly fetched catalogue
# so newly added OR models bill correctly on their first call.
# Runs before the router rebuild because the router may issue
# cost-table lookups during deployment registration.
try:
from app.services.pricing_registration import (
register_pricing_from_global_configs,
)
register_pricing_from_global_configs()
except Exception as exc:
logger.warning(
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
)
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)
def get_image_generation_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter image-generation configs (empty
list when the ``image_generation_enabled`` flag is off).
Each entry already has ``billing_tier`` derived per-model from
OpenRouter's signals and is shaped to drop directly into
``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
"""
return list(self._image_configs)
def get_vision_llm_configs(self) -> list[dict]:
    """Dynamic OpenRouter vision-LLM configs.

    Empty when the ``vision_enabled`` flag is off. Each entry exposes
    ``input_cost_per_token`` / ``output_cost_per_token`` so
    ``pricing_registration`` can teach LiteLLM the cost of these models
    the same way it does for chat — which keeps the billable wrapper
    able to debit accurate micro-USD on a vision call.

    Returns:
        A shallow copy, so callers may mutate the result freely.
    """
    return [*self._vision_configs]
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
    """Cached raw OpenRouter pricing map.

    Shape: ``{model_id: {"prompt": str, "completion": str}}``. Values
    stay as the strings OpenRouter publishes (USD per token) — never
    converted to floats here, so the caller decides how to handle
    malformed or unset entries.

    Returns:
        A shallow copy of the cache.
    """
    return {**self._raw_pricing}

View file

@ -0,0 +1,274 @@
"""
Pricing registration with LiteLLM.
Many models reach our LiteLLM callback without LiteLLM knowing their
per-token cost namely:
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
pricing table).
* Static YAML deployments whose ``base_model`` name is operator-defined
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
not in LiteLLM's table either.
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
and the user gets billed nothing — a fail-safe but wrong answer for a
cost-based credit system. This module runs once at startup, after the
OpenRouter integration has fetched its catalogue, and registers each
known model's pricing with ``litellm.register_model()`` under multiple
plausible alias keys (LiteLLM's cost lookup may use any of them
depending on whether the call went through the Router, ChatLiteLLM,
or a direct ``acompletion``).
Operators who run a custom Azure deployment whose ``base_model`` name
isn't in LiteLLM's table can declare per-token pricing inline in
``global_llm_config.yaml`` via ``input_cost_per_token`` and
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
that declaration the model's calls debit 0 — never overbilled.
"""
from __future__ import annotations
import logging
from typing import Any
import litellm
logger = logging.getLogger(__name__)
def _safe_float(value: Any) -> float:
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
if value is None:
return 0.0
try:
f = float(value)
except (TypeError, ValueError):
return 0.0
return f if f > 0 else 0.0
def _alias_set_for_openrouter(model_id: str) -> list[str]:
"""Return the alias keys to register an OpenRouter model under.
LiteLLM's cost-callback lookup key varies by call path:
- Router with ``model="openrouter/X"`` kwargs["model"] is
typically ``openrouter/X``.
- LiteLLM's own provider routing may strip the prefix and pass the
bare ``X`` to the cost-table lookup.
Registering under both keeps the lookup hermetic regardless of
which path the call took.
"""
aliases = [f"openrouter/{model_id}", model_id]
return list(dict.fromkeys(a for a in aliases if a))
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
"""Return the alias keys to register a static YAML deployment under.
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
bare ``model_name`` because callbacks sometimes see whichever was
configured first.
"""
provider_lower = (provider or "").lower()
aliases: list[str] = []
if base_model:
aliases.append(base_model)
if provider_lower and base_model:
aliases.append(f"{provider_lower}/{base_model}")
if model_name and model_name != base_model:
aliases.append(model_name)
if provider_lower and model_name and model_name != base_model:
aliases.append(f"{provider_lower}/{model_name}")
# Azure deployments often surface as "azure/<name>"; normalise the
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
if provider_lower == "azure_openai":
if base_model:
aliases.append(f"azure/{base_model}")
if model_name and model_name != base_model:
aliases.append(f"azure/{model_name}")
return list(dict.fromkeys(a for a in aliases if a))
def _register(
    aliases: list[str],
    *,
    input_cost: float,
    output_cost: float,
    provider: str,
    mode: str = "chat",
) -> int:
    """Register one pricing entry with LiteLLM under every alias.

    Each alias gets its own copy of the pricing record so the cost
    lookup resolves no matter which key a call path uses.

    Returns:
        The number of aliases registered — 0 when ``aliases`` is empty
        or ``litellm.register_model`` raised.
    """
    template: dict[str, Any] = {
        "input_cost_per_token": input_cost,
        "output_cost_per_token": output_cost,
        "litellm_provider": provider,
        "mode": mode,
    }
    payload = {alias: dict(template) for alias in aliases}
    if not payload:
        return 0
    try:
        litellm.register_model(payload)
    except Exception as exc:
        logger.warning(
            "[PricingRegistration] register_model failed for aliases=%s: %s",
            aliases,
            exc,
        )
        return 0
    return len(payload)
def _register_chat_shape_configs(
    configs: list[dict],
    *,
    or_pricing: dict[str, dict[str, str]],
    label: str,
) -> tuple[int, int, int, list[str]]:
    """Register per-token pricing for a list of "chat-shape" configs.

    Chat and vision LLM configs share the same LiteLLM cost shape
    (``input_cost_per_token`` / ``output_cost_per_token`` with
    ``mode="chat"``), so one loop serves both.

    Pricing resolution per config:

    * ``OPENROUTER`` providers: prefer the cached raw OpenRouter pricing
      (``or_pricing``, keyed by model id). When the cache misses the
      model *or* yields no positive price (malformed / zero strings),
      fall back to inline ``input_cost_per_token`` /
      ``output_cost_per_token`` on the config itself — vision configs
      from ``_generate_vision_llm_configs`` always carry their pricing
      inline.
    * Everything else: operator-declared per-token costs, top-level or
      nested under ``litellm_params``.

    Configs with no resolvable positive price are skipped — never
    registered with zeros.

    Args:
        configs: Config dicts (``provider``, ``model_name``, optional
            ``litellm_params`` and inline per-token costs).
        or_pricing: Cached raw OpenRouter pricing map,
            ``{model_id: {"prompt": str, "completion": str}}``.
        label: Tag ("chat" / "vision") for the summary log line.

    Returns:
        ``(registered_models, registered_aliases, skipped, sample_keys)``.
    """
    registered_models = 0
    registered_aliases = 0
    skipped_no_pricing = 0
    sample_keys: list[str] = []

    def _record_success(aliases: list[str], alias_count: int) -> None:
        # Shared bookkeeping after a _register() call; no-op on failure.
        nonlocal registered_models, registered_aliases
        if alias_count > 0:
            registered_models += 1
            registered_aliases += alias_count
            if len(sample_keys) < 6:
                sample_keys.extend(aliases[:2])

    for cfg in configs:
        provider = str(cfg.get("provider") or "").upper()
        model_name = str(cfg.get("model_name") or "").strip()
        litellm_params = cfg.get("litellm_params") or {}
        base_model = str(litellm_params.get("base_model") or model_name).strip()

        if provider == "OPENROUTER":
            cached = or_pricing.get(model_name) or {}
            input_cost = _safe_float(cached.get("prompt"))
            output_cost = _safe_float(cached.get("completion"))
            if input_cost == 0.0 and output_cost == 0.0:
                # Cache miss, or cached strings that don't parse to a
                # positive price: fall back to inline pricing carried on
                # the config itself. (Previously a present-but-unusable
                # cache entry blocked this fallback entirely.)
                input_cost = _safe_float(cfg.get("input_cost_per_token"))
                output_cost = _safe_float(cfg.get("output_cost_per_token"))
            if input_cost == 0.0 and output_cost == 0.0:
                skipped_no_pricing += 1
                continue
            aliases = _alias_set_for_openrouter(model_name)
            _record_success(
                aliases,
                _register(
                    aliases,
                    input_cost=input_cost,
                    output_cost=output_cost,
                    provider="openrouter",
                ),
            )
            continue

        input_cost = _safe_float(
            cfg.get("input_cost_per_token")
            or litellm_params.get("input_cost_per_token")
        )
        output_cost = _safe_float(
            cfg.get("output_cost_per_token")
            or litellm_params.get("output_cost_per_token")
        )
        if input_cost == 0.0 and output_cost == 0.0:
            skipped_no_pricing += 1
            continue

        aliases = _alias_set_for_yaml(provider, model_name, base_model)
        # LiteLLM's canonical slug for Azure OpenAI is plain "azure".
        provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
        _record_success(
            aliases,
            _register(
                aliases,
                input_cost=input_cost,
                output_cost=output_cost,
                provider=provider_slug,
            ),
        )

    logger.info(
        "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
        "%d configs had no pricing data; sample registered keys=%s",
        label,
        registered_models,
        registered_aliases,
        skipped_no_pricing,
        sample_keys,
    )
    return registered_models, registered_aliases, skipped_no_pricing, sample_keys
def register_pricing_from_global_configs() -> None:
    """Teach LiteLLM the per-token cost of every known LLM deployment.

    Walks both ``config.GLOBAL_LLM_CONFIGS`` and
    ``config.GLOBAL_VISION_LLM_CONFIGS`` so vision calls (during
    indexing) resolve cost the same way chat calls do:

    1. ``OPENROUTER`` configs use the raw pricing cached by
       ``OpenRouterIntegrationService`` during its startup fetch (the
       per-token strings are converted to floats downstream); vision
       configs that carry pricing inline (``input_cost_per_token`` /
       ``output_cost_per_token`` on the cfg itself) are covered when the
       OR cache misses the model.
    2. Everything else relies on operator-declared
       ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
       config block (top-level or nested under ``litellm_params``).

    **Image generation is intentionally NOT registered here.** Its cost
    shape is per-image (``output_cost_per_image``), not per-token, and
    ``litellm.register_model`` doesn't accept those keys via the
    chat-cost path. OpenRouter image-gen models populate
    ``response_cost`` directly from their response header, and
    Azure-native image-gen models already live in LiteLLM's cost map.

    Configs without a resolved cost pair are skipped rather than
    registered with zeros — operators who forget pricing get a "$0
    debit" warning in ``TokenTrackingCallback`` instead of silently
    overwriting pricing LiteLLM might know natively.
    """
    from app.config import config as app_config

    chat_configs: list[dict] = list(
        getattr(app_config, "GLOBAL_LLM_CONFIGS", None) or []
    )
    vision_configs: list[dict] = list(
        getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", None) or []
    )
    if not (chat_configs or vision_configs):
        logger.info("[PricingRegistration] no global configs to register")
        return

    or_pricing: dict[str, dict[str, str]] = {}
    try:
        from app.services.openrouter_integration_service import (
            OpenRouterIntegrationService,
        )

        if OpenRouterIntegrationService.is_initialized():
            or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
    except Exception as exc:
        logger.debug(
            "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
        )

    # Chat first, then vision — matches the startup log ordering.
    for label, group in (("chat", chat_configs), ("vision", vision_configs)):
        if group:
            _register_chat_shape_configs(group, or_pricing=or_pricing, label=label)

View file

@ -0,0 +1,107 @@
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
LiteLLM falls back to the module-global ``litellm.api_base`` when an
individual call doesn't pass one, which silently inherits provider-agnostic
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
``/chat/completions`` to whatever inherited base it gets, regardless of
provider).
The chat router has had this defense for a while
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
into a tiny standalone helper so vision and image-gen can share the same
source of truth without an inter-service circular import.
"""
from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
Only providers with a well-known, stable public base URL are listed
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
huggingface, databricks, cloudflare, replicate) are intentionally omitted
so their existing config-driven behaviour is preserved."""
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
"""Canonical provider key (uppercase) → base URL.
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
config's ``provider`` field tells us which API it actually is (DeepSeek,
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
has its own base URL)."""
def resolve_api_base(
*,
provider: str | None,
provider_prefix: str | None,
config_api_base: str | None,
) -> str | None:
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
Cascade (first non-empty wins):
1. The config's own ``api_base`` (whitespace-only treated as missing).
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
4. ``None`` caller should NOT set ``api_base`` and let the LiteLLM
provider integration apply its own default (e.g. AzureOpenAI's
deployment-derived URL, custom provider's per-deployment URL).
Args:
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
``"DEEPSEEK"``). Case-insensitive.
provider_prefix: The LiteLLM model-string prefix the same call
site builds for the model id (e.g. ``"openrouter"``,
``"groq"``). Case-insensitive.
config_api_base: ``api_base`` from the global YAML / DB row /
OpenRouter dynamic config. Empty / whitespace-only means
"missing" the resolver still applies the cascade.
Returns:
A URL string, or ``None`` if no default applies for this provider.
"""
if config_api_base and config_api_base.strip():
return config_api_base
if provider:
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
if key_default:
return key_default
if provider_prefix:
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
if prefix_default:
return prefix_default
return None
__all__ = [
"PROVIDER_DEFAULT_API_BASE",
"PROVIDER_KEY_DEFAULT_API_BASE",
"resolve_api_base",
]

View file

@ -0,0 +1,105 @@
"""
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
indexing pipeline (file processors, connector indexers, etl pipeline) can
keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)`` —
without threading ``user_id`` through every parser. The wrapper looks like
a chat model from the outside; on the inside it routes each call through
``billable_call`` so the user's premium credit pool is reserved → finalized
or released, and a ``TokenUsage`` audit row is written.
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
need quota enforcement) so this class only ever wraps premium configs.
Why a wrapper instead of plumbing ``user_id`` through every caller:
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
Dropbox, local-folder, file-processor, ETL pipeline) each calling
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
invasive, error-prone, and easy for a future indexer to forget.
* Per the design (issue M), we always debit the *search-space owner*, not
the triggering user, so ``user_id`` is fully derivable from the search
space the caller is already operating on. The wrapper captures it once
at construction time.
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
call run this coroutine"; subclassing isn't safe across versions because
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
Composition via attribute proxying (``__getattr__``) is robust to
upstream changes every method other than ``ainvoke`` falls through to
the inner LLM unchanged.
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from app.services.billable_calls import QuotaInsufficientError, billable_call
logger = logging.getLogger(__name__)
class QuotaCheckedVisionLLM:
"""Composition wrapper around a langchain chat model that enforces
premium credit quota on every ``ainvoke``.
Anything other than ``ainvoke`` is forwarded to the inner model so
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
still work they simply bypass quota enforcement, which is fine
because the indexing pipeline only ever calls ``ainvoke`` today.
"""
def __init__(
self,
inner_llm: Any,
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None,
usage_type: str = "vision_extraction",
) -> None:
self._inner = inner_llm
self._user_id = user_id
self._search_space_id = search_space_id
self._billing_tier = billing_tier
self._base_model = base_model
self._quota_reserve_tokens = quota_reserve_tokens
self._usage_type = usage_type
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
"""Proxied async invoke that runs the underlying call inside
``billable_call``.
Raises:
QuotaInsufficientError: when the user has exhausted their
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
catches this and falls back to the document parser.
"""
async with billable_call(
user_id=self._user_id,
search_space_id=self._search_space_id,
billing_tier=self._billing_tier,
base_model=self._base_model,
quota_reserve_tokens=self._quota_reserve_tokens,
usage_type=self._usage_type,
call_details={"model": self._base_model},
):
return await self._inner.ainvoke(input, *args, **kwargs)
def __getattr__(self, name: str) -> Any:
"""Forward everything else (``invoke``, ``astream``, ``bind``,
``with_structured_output``, ) to the inner model.
``__getattr__`` is only consulted when the attribute is *not*
already found on the proxy, which is exactly the contract we
want methods we override stay on the proxy, the rest fall
through.
"""
return getattr(self._inner, name)
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]

View file

@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Per-call reservation estimator (USD micro-units)
# ---------------------------------------------------------------------------
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
# request, and so models without registered pricing reserve at least
# something while the call runs (debited 0 at finalize anyway when their
# cost can't be resolved).
# Floor for a single reservation: $0.0001 expressed in micro-USD.
_QUOTA_MIN_RESERVE_MICROS = 100


def estimate_call_reserve_micros(
    *,
    base_model: str,
    quota_reserve_tokens: int | None,
) -> int:
    """Estimate the micro-USD amount to reserve for one premium call.

    The estimate is a worst-case upper bound built from LiteLLM's
    per-token pricing table:

        reserve_usd ≈ reserve_tokens × (input_cost + output_cost)

    so the reserve scales with model cost — Claude Opus with a 4K token
    reserve naturally holds ~$0.36 while a cheap model holds pennies.
    The result is clamped to
    ``[_QUOTA_MIN_RESERVE_MICROS, config.QUOTA_MAX_RESERVE_MICROS]`` so
    a misconfigured "$1000/M" model can't lock the whole balance on one
    call.

    If ``litellm.get_model_info`` raises (model unknown) the floor — 100
    micros / $0.0001 — is returned: enough to gate a sane request
    without over-reserving for a model whose pricing the operator hasn't
    declared yet.
    """
    reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
    if reserve_tokens <= 0:
        reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL

    input_cost = 0.0
    output_cost = 0.0
    try:
        from litellm import get_model_info

        info = get_model_info(base_model) if base_model else {}
        input_cost = float(info.get("input_cost_per_token") or 0.0)
        output_cost = float(info.get("output_cost_per_token") or 0.0)
    except Exception as exc:
        logger.debug(
            "[quota_reserve] cost lookup failed for base_model=%s: %s",
            base_model,
            exc,
        )
        input_cost = 0.0
        output_cost = 0.0

    if input_cost == 0.0 and output_cost == 0.0:
        return _QUOTA_MIN_RESERVE_MICROS

    reserve_usd = reserve_tokens * (input_cost + output_cost)
    reserve_micros = round(reserve_usd * 1_000_000)
    # Floor first, cap last — the cap dominates if both would apply.
    return min(
        max(reserve_micros, _QUOTA_MIN_RESERVE_MICROS),
        config.QUOTA_MAX_RESERVE_MICROS,
    )
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
reserve_tokens: int,
reserve_micros: int,
) -> QuotaResult:
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
premium credit balance.
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
all in micro-USD on this code path; callers (chat stream,
token-status route, FE display) convert to dollars by dividing
by 1_000_000.
"""
from app.db import User
user = (
@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0,
)
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
effective = used + reserved + reserve_tokens
effective = used + reserved + reserve_micros
if effective > limit:
remaining = max(0, limit - used - reserved)
await db_session.rollback()
@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining,
)
user.premium_tokens_reserved = reserved + reserve_tokens
user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit()
new_reserved = reserved + reserve_tokens
new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8)
@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
actual_tokens: int,
reserved_tokens: int,
actual_micros: int,
reserved_micros: int,
) -> QuotaResult:
"""Settle the reservation: release ``reserved_micros`` and debit
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
"""
from app.db import User
user = (
@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
user.premium_tokens_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens
user.premium_credit_micros_reserved = max(
0, user.premium_credit_micros_reserved - reserved_micros
)
user.premium_credit_micros_used = (
user.premium_credit_micros_used + actual_micros
)
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit()
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release(
db_session: AsyncSession,
user_id: Any,
reserved_tokens: int,
reserved_micros: int,
) -> None:
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
Used when a request fails before finalize (so the reservation
doesn't leak credit).
"""
from app.db import User
user = (
@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none()
)
if user is not None:
user.premium_tokens_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens
user.premium_credit_micros_reserved = max(
0, user.premium_credit_micros_reserved - reserved_micros
)
await db_session.commit()
@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)

View file

@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost_micros: int = 0
call_kind: str = "chat"
@dataclass
@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
cost_micros: int = 0,
call_kind: str = "chat",
) -> None:
self.calls.append(
TokenCallRecord(
@ -56,20 +63,28 @@ class TurnTokenAccumulator:
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
"""Return token counts grouped by model name."""
"""Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
{
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost_micros": 0,
},
)
entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens
entry["cost_micros"] += c.cost_micros
return by_model
@property
@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls)
@property
def total_cost_micros(self) -> int:
"""Sum of per-call ``cost_micros`` across the entire turn.
Used by ``stream_new_chat`` to debit a premium turn's actual
provider cost (in micro-USD) from the user's premium credit
balance. ``cost_micros`` per call is captured by
``TokenTrackingCallback.async_log_success_event`` from
``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
with multiple fallback paths so OpenRouter dynamic models and
custom Azure deployments still bill correctly when our
``pricing_registration`` ran at startup.
"""
return sum(c.cost_micros for c in self.calls)
def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls]
@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator:
"""Create a fresh accumulator for the current async context and return it."""
"""Create a fresh accumulator for the current async context and return it.
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
short-lived per-call billable wrappers (image generation REST endpoint,
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
ContextVar reset token to restore the *previous* accumulator on exit and
avoids leaking call records across reservations (issue B).
"""
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
    """Scope a fresh ``TurnTokenAccumulator`` to an ``async with`` block,
    restoring the previous accumulator on exit via the ContextVar reset
    token.

    This is the safe primitive for per-call billable operations (image
    generation, vision LLM extraction, podcasts) that may run inside an
    outer chat turn or back-to-back on the same background worker. A
    bare ``ContextVar.set`` without ``reset`` (what :func:`start_turn`
    does) would leak the inner accumulator into the outer scope and make
    the outer chat turn debit cost twice.

    Usage::

        async with scoped_turn() as acc:
            await llm.ainvoke(...)
            # acc.total_cost_micros holds the cost captured by the
            # LiteLLM callback for calls made inside the block.
        # The outer accumulator (if any) is restored here.
    """
    accumulator = TurnTokenAccumulator()
    reset_token = _turn_accumulator.set(accumulator)
    logger.debug(
        "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
        id(accumulator),
        reset_token,
    )
    try:
        yield accumulator
    finally:
        _turn_accumulator.reset(reset_token)
        logger.debug(
            "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
            id(accumulator),
            len(accumulator.calls),
            accumulator.total_cost_micros,
        )
def _extract_cost_usd(
kwargs: dict[str, Any],
response_obj: Any,
model: str,
prompt_tokens: int,
completion_tokens: int,
is_image: bool = False,
) -> float:
"""Best-effort USD cost extraction for a single LLM/image call.
Tries four sources in priority order and returns the first that
yields a positive number; returns 0.0 if all four fail (the call
will then debit nothing from the user's balance — fail-safe).
Sources:
1. ``kwargs["response_cost"]`` LiteLLM's standard callback
field, populated for ``Router.acompletion`` since PR #12500.
2. ``response_obj._hidden_params["response_cost"]`` same value
exposed on the response itself.
3. ``litellm.completion_cost(completion_response=response_obj)``
recompute from the response and LiteLLM's pricing table.
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
manual fallback for OpenRouter/custom-Azure models that
only resolve via aliases registered by
``pricing_registration`` at startup. **Skipped for image
responses** ``cost_per_token`` does not support ``ImageResponse``
and would raise; the cost map for image-gen lives in different
keys (``output_cost_per_image``) handled by ``completion_cost``.
"""
cost = kwargs.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
hidden = getattr(response_obj, "_hidden_params", None) or {}
if isinstance(hidden, dict):
cost = hidden.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
try:
value = float(litellm.completion_cost(completion_response=response_obj))
if value > 0:
return value
except Exception as exc:
if is_image:
# Image-gen path: OpenRouter's image responses can omit
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
# then *raises* (no cost map for OpenRouter image models).
# Bail out with a warning rather than falling through to
# cost_per_token (which is also incompatible with ImageResponse).
logger.warning(
"[TokenTracking] completion_cost failed for image model=%s "
"(provider may have omitted usage.cost). Debiting 0. "
"Cause: %s",
model,
exc,
)
return 0.0
logger.debug(
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
)
if is_image:
# Never call cost_per_token for ImageResponse — keys mismatch and
# the function is documented chat-only.
return 0.0
if model and (prompt_tokens > 0 or completion_tokens > 0):
try:
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
value = float(prompt_cost) + float(completion_cost)
if value > 0:
return value
except Exception as exc:
logger.debug(
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
)
return 0.0
class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator."""
@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
)
return
# Detect image generation responses — they have a different usage
# shape (ImageUsage with input_tokens/output_tokens) and require a
# different cost-extraction path. We probe by class name to avoid a
# hard import dependency on litellm internals.
response_cls = type(response_obj).__name__
is_image = response_cls == "ImageResponse"
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug(
@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
)
return
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
if is_image:
# ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
# (not prompt_tokens/completion_tokens). Several providers
# populate only one or neither (e.g. OpenRouter's gpt-image-1
# passes through `input_tokens` from the prompt but no
# completion); fall through gracefully to 0.
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
completion_tokens = getattr(usage, "output_tokens", 0) or 0
total_tokens = (
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
)
call_kind = "image_generation"
else:
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
call_kind = "chat"
model = kwargs.get("model", "unknown")
cost_usd = _extract_cost_usd(
kwargs=kwargs,
response_obj=response_obj,
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
is_image=is_image,
)
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
logger.warning(
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
"for image generation).",
model,
prompt_tokens,
completion_tokens,
call_kind,
)
acc.add(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
)
logger.info(
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
model,
call_kind,
prompt_tokens,
completion_tokens,
total_tokens,
cost_usd,
cost_micros,
len(acc.calls),
)
@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
@ -194,11 +416,12 @@ async def record_token_usage(
)
session.add(record)
logger.debug(
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type,
prompt_tokens,
completion_tokens,
total_tokens,
cost_micros,
)
return record
except Exception:

View file

@ -3,6 +3,8 @@ from typing import Any
from litellm import Router
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
}
if config.get("api_base"):
litellm_params["api_base"] = config["api_base"]
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]

View file

@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config as app_config
from app.db import Podcast, PodcastStatus
from app.services.billable_calls import (
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -96,6 +102,31 @@ async def _generate_content_podcast(
podcast.status = PodcastStatus.GENERATING
await session.commit()
try:
(
owner_user_id,
billing_tier,
base_model,
) = await _resolve_agent_billing_for_search_space(
session,
search_space_id,
thread_id=podcast.thread_id,
)
except ValueError as resolve_err:
logger.error(
"Podcast %s: cannot resolve billing for search_space=%s: %s",
podcast.id,
search_space_id,
resolve_err,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "billing_resolution_failed",
}
graph_config = {
"configurable": {
"podcast_title": podcast.title,
@ -109,9 +140,39 @@ async def _generate_content_podcast(
db_session=session,
)
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
)
try:
async with billable_call(
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
usage_type="podcast_generation",
thread_id=podcast.thread_id,
call_details={
"podcast_id": podcast.id,
"title": podcast.title,
},
):
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
)
except QuotaInsufficientError as exc:
logger.info(
"Podcast %s denied: out of premium credits "
"(used=%d/%d remaining=%d)",
podcast.id,
exc.used_micros,
exc.limit_micros,
exc.remaining_micros,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "premium_quota_exhausted",
}
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")

View file

@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app
from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
from app.services.billable_calls import (
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -97,6 +103,32 @@ async def _generate_video_presentation(
video_pres.status = VideoPresentationStatus.GENERATING
await session.commit()
try:
(
owner_user_id,
billing_tier,
base_model,
) = await _resolve_agent_billing_for_search_space(
session,
search_space_id,
thread_id=video_pres.thread_id,
)
except ValueError as resolve_err:
logger.error(
"VideoPresentation %s: cannot resolve billing for "
"search_space=%s: %s",
video_pres.id,
search_space_id,
resolve_err,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "billing_resolution_failed",
}
graph_config = {
"configurable": {
"video_title": video_pres.title,
@ -110,9 +142,39 @@ async def _generate_video_presentation(
db_session=session,
)
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
)
try:
async with billable_call(
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
usage_type="video_presentation_generation",
thread_id=video_pres.thread_id,
call_details={
"video_presentation_id": video_pres.id,
"title": video_pres.title,
},
):
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
)
except QuotaInsufficientError as exc:
logger.info(
"VideoPresentation %s denied: out of premium credits "
"(used=%d/%d remaining=%d)",
video_pres.id,
exc.used_micros,
exc.limit_micros,
exc.remaining_micros,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "premium_quota_exhausted",
}
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])

View file

@ -2236,8 +2236,10 @@ async def stream_new_chat(
accumulator = start_turn()
# Premium quota tracking state
_premium_reserved = 0
# Premium credit (USD micro-units) tracking state. Stores the
# amount reserved up front so we can release it on cancellation
# and finalize-debit the actual provider cost reported by LiteLLM.
_premium_reserved_micros = 0
_premium_request_id: str | None = None
_emit_stream_error = partial(
@ -2331,23 +2333,28 @@ async def stream_new_chat(
if _needs_premium_quota:
import uuid as _uuid
from app.config import config as _app_config
from app.services.token_quota_service import TokenQuotaService
from app.services.token_quota_service import (
TokenQuotaService,
estimate_call_reserve_micros,
)
_premium_request_id = _uuid.uuid4().hex[:16]
reserve_amount = min(
agent_config.quota_reserve_tokens
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
_agent_litellm_params = agent_config.litellm_params or {}
_agent_base_model = (
_agent_litellm_params.get("base_model") or agent_config.model_name or ""
)
reserve_amount_micros = estimate_call_reserve_micros(
base_model=_agent_base_model,
quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
reserve_tokens=reserve_amount,
reserve_micros=reserve_amount_micros,
)
_premium_reserved = reserve_amount
_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
if requested_llm_config_id == 0:
try:
@ -2382,7 +2389,7 @@ async def stream_new_chat(
yield streaming_service.format_done()
return
_premium_request_id = None
_premium_reserved = 0
_premium_reserved_micros = 0
_log_chat_stream_error(
flow=flow,
error_kind="premium_quota_exhausted",
@ -3020,9 +3027,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
"[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@ -3033,6 +3041,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@ -3060,7 +3069,11 @@ async def stream_new_chat(
chat_id, generated_title
)
# Finalize premium quota with actual tokens.
# Finalize premium credit debit with the actual provider cost
# reported by LiteLLM, summed across every call in the turn.
# Mirrors the pre-cost behaviour of "premium turn → all calls
# count" so free sub-agent calls during a premium turn still
# contribute to the bill (they're $0 in practice anyway).
if _premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@ -3070,11 +3083,11 @@ async def stream_new_chat(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
actual_tokens=accumulator.grand_total,
reserved_tokens=_premium_reserved,
actual_micros=accumulator.total_cost_micros,
reserved_micros=_premium_reserved_micros,
)
_premium_request_id = None
_premium_reserved = 0
_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s",
@ -3084,9 +3097,10 @@ async def stream_new_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
"[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@ -3097,6 +3111,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@ -3190,7 +3205,7 @@ async def stream_new_chat(
end_turn(str(chat_id))
# Release premium reservation if not finalized
if _premium_request_id and _premium_reserved > 0 and user_id:
if _premium_request_id and _premium_reserved_micros > 0 and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@ -3198,9 +3213,9 @@ async def stream_new_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
reserved_tokens=_premium_reserved,
reserved_micros=_premium_reserved_micros,
)
_premium_reserved = 0
_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s", user_id
@ -3369,8 +3384,8 @@ async def stream_resume_chat(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
# Premium quota reservation (same logic as stream_new_chat)
_resume_premium_reserved = 0
# Premium credit reservation (same logic as stream_new_chat).
_resume_premium_reserved_micros = 0
_resume_premium_request_id: str | None = None
_resume_needs_premium = (
agent_config is not None and user_id and agent_config.is_premium
@ -3378,23 +3393,30 @@ async def stream_resume_chat(
if _resume_needs_premium:
import uuid as _uuid
from app.config import config as _app_config
from app.services.token_quota_service import TokenQuotaService
from app.services.token_quota_service import (
TokenQuotaService,
estimate_call_reserve_micros,
)
_resume_premium_request_id = _uuid.uuid4().hex[:16]
reserve_amount = min(
agent_config.quota_reserve_tokens
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
_resume_litellm_params = agent_config.litellm_params or {}
_resume_base_model = (
_resume_litellm_params.get("base_model")
or agent_config.model_name
or ""
)
reserve_amount_micros = estimate_call_reserve_micros(
base_model=_resume_base_model,
quota_reserve_tokens=agent_config.quota_reserve_tokens,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
reserve_tokens=reserve_amount,
reserve_micros=reserve_amount_micros,
)
_resume_premium_reserved = reserve_amount
_resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
if requested_llm_config_id == 0:
try:
@ -3429,7 +3451,7 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
_resume_premium_request_id = None
_resume_premium_reserved = 0
_resume_premium_reserved_micros = 0
_log_chat_stream_error(
flow="resume",
error_kind="premium_quota_exhausted",
@ -3746,9 +3768,10 @@ async def stream_resume_chat(
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@ -3759,6 +3782,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@ -3768,7 +3792,9 @@ async def stream_resume_chat(
yield streaming_service.format_done()
return
# Finalize premium quota for resume path
# Finalize premium credit debit for resume path with the actual
# provider cost reported by LiteLLM (sum of cost across all
# calls in the turn).
if _resume_premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
@ -3778,11 +3804,11 @@ async def stream_resume_chat(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
actual_tokens=accumulator.grand_total,
reserved_tokens=_resume_premium_reserved,
actual_micros=accumulator.total_cost_micros,
reserved_micros=_resume_premium_reserved_micros,
)
_resume_premium_request_id = None
_resume_premium_reserved = 0
_resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)",
@ -3792,9 +3818,10 @@ async def stream_resume_chat(
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
"[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
len(accumulator.calls),
accumulator.grand_total,
accumulator.total_cost_micros,
usage_summary,
)
if usage_summary:
@ -3805,6 +3832,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"cost_micros": accumulator.total_cost_micros,
"call_details": accumulator.serialized_calls(),
},
)
@ -3855,7 +3883,11 @@ async def stream_resume_chat(
end_turn(str(chat_id))
# Release premium reservation if not finalized
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
if (
_resume_premium_request_id
and _resume_premium_reserved_micros > 0
and user_id
):
try:
from app.services.token_quota_service import TokenQuotaService
@ -3863,9 +3895,9 @@ async def stream_resume_chat(
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
reserved_tokens=_resume_premium_reserved,
reserved_micros=_resume_premium_reserved_micros,
)
_resume_premium_reserved = 0
_resume_premium_reserved_micros = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s (resume)", user_id