feat: unified credits and its cost calculations

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-02 14:34:23 -07:00
parent 451a98936e
commit ae9d36d77f
61 changed files with 5835 additions and 272 deletions

View file

@ -0,0 +1,430 @@
"""
Per-call billable wrapper for image generation, vision LLM extraction, and
any other short-lived premium operation that must charge against the user's
shared premium credit pool.
The ``billable_call`` async context manager encapsulates the standard
"reserve → execute → finalize / release → record audit row" lifecycle in a
single primitive so callers (the image-generation REST route and the
vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issue A, B):
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
insert each run inside their own ``shielded_async_session()``. This
guarantees that a quota commit/rollback can never accidentally flush or
roll back rows the caller has staged in the request's main session
(e.g. a freshly-created ``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
nested ``billable_call`` inside an outer chat turn cannot corrupt the
chat turn's accumulator.
3. **Free configs are still audited.** Free calls bypass the reserve /
finalize dance entirely but still record a ``TokenUsage`` audit row with
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
pipeline complete for analytics even when nothing is debited.
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
responsible for translating that into HTTP 402. We *do not* catch the
denial inside ``billable_call`` letting it propagate also prevents
the image-generation route from creating an ``ImageGeneration`` row
for a request that never actually ran.
"""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import shielded_async_session
from app.services.token_quota_service import (
TokenQuotaService,
estimate_call_reserve_micros,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
record_token_usage,
scoped_turn,
)
logger = logging.getLogger(__name__)
class QuotaInsufficientError(Exception):
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
call because the user has exhausted their premium credit pool.
The route handler should catch this and return HTTP 402 Payment
Required (or the equivalent for the surface area). Outside of the HTTP
layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
callers may catch this and degrade gracefully e.g. fall back to OCR
when vision is unavailable.
"""
def __init__(
self,
*,
usage_type: str,
used_micros: int,
limit_micros: int,
remaining_micros: int,
) -> None:
self.usage_type = usage_type
self.used_micros = used_micros
self.limit_micros = limit_micros
self.remaining_micros = remaining_micros
super().__init__(
f"Premium credit exhausted for {usage_type}: "
f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
)
@asynccontextmanager
async def billable_call(
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None = None,
quota_reserve_micros_override: int | None = None,
usage_type: str,
thread_id: int | None = None,
message_id: int | None = None,
call_details: dict[str, Any] | None = None,
) -> AsyncIterator[TurnTokenAccumulator]:
"""Wrap a single billable LLM/image call.
Args:
user_id: Owner of the credit pool to debit. For vision-LLM during
indexing this is the *search-space owner* (issue M), not the
triggering user.
search_space_id: Required recorded on the ``TokenUsage`` audit row.
billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
the reserve/finalize dance but still records an audit row with
the captured cost.
base_model: Used by :func:`estimate_call_reserve_micros` to compute
a worst-case reservation from LiteLLM's pricing table.
quota_reserve_tokens: Optional per-config override for the chat-style
reserve estimator (vision LLM uses this).
quota_reserve_micros_override: Optional flat micro-USD reservation
(image generation uses this its cost shape is per-image, not
per-token).
usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
Recorded on the ``TokenUsage`` row.
thread_id, message_id: Optional FK columns on ``TokenUsage``.
call_details: Optional per-call metadata (model name, parameters)
forwarded to ``record_token_usage``.
Yields:
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
the underlying LLM/image API while inside the ``async with``; the
``TokenTrackingCallback`` populates the accumulator automatically.
Raises:
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
"""
is_premium = billing_tier == "premium"
async with scoped_turn() as acc:
# ---------- Free path: just audit -------------------------------
if not is_premium:
try:
yield acc
finally:
# Always audit, even on exception, so we capture cost when
# provider returns successfully but the caller raises later.
try:
async with shielded_async_session() as audit_session:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=acc.total_cost_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
)
await audit_session.commit()
except Exception:
logger.exception(
"[billable_call] free-path audit insert failed for "
"usage_type=%s user_id=%s",
usage_type,
user_id,
)
return
# ---------- Premium path: reserve → execute → finalize ----------
if quota_reserve_micros_override is not None:
reserve_micros = max(1, int(quota_reserve_micros_override))
else:
reserve_micros = estimate_call_reserve_micros(
base_model=base_model or "",
quota_reserve_tokens=quota_reserve_tokens,
)
request_id = str(uuid4())
async with shielded_async_session() as quota_session:
reserve_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=user_id,
request_id=request_id,
reserve_micros=reserve_micros,
)
if not reserve_result.allowed:
logger.info(
"[billable_call] reserve DENIED user=%s usage_type=%s "
"reserve=%d used=%d limit=%d remaining=%d",
user_id,
usage_type,
reserve_micros,
reserve_result.used,
reserve_result.limit,
reserve_result.remaining,
)
raise QuotaInsufficientError(
usage_type=usage_type,
used_micros=reserve_result.used,
limit_micros=reserve_result.limit,
remaining_micros=reserve_result.remaining,
)
logger.info(
"[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
"(remaining=%d)",
user_id,
usage_type,
reserve_micros,
reserve_result.remaining,
)
try:
yield acc
except BaseException:
# Release on any failure (including QuotaInsufficientError raised
# from a downstream call, asyncio cancellation, etc.). We use
# BaseException so cancellation also releases.
try:
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
reserved_micros=reserve_micros,
)
except Exception:
logger.exception(
"[billable_call] premium_release failed for user=%s "
"reserve_micros=%d (reservation will be GC'd by quota "
"reconciliation if/when implemented)",
user_id,
reserve_micros,
)
raise
# ---------- Success: finalize + audit ----------------------------
actual_micros = acc.total_cost_micros
try:
async with shielded_async_session() as quota_session:
final_result = await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=user_id,
request_id=request_id,
actual_micros=actual_micros,
reserved_micros=reserve_micros,
)
logger.info(
"[billable_call] finalize user=%s usage_type=%s actual=%d "
"reserved=%d → used=%d/%d (remaining=%d)",
user_id,
usage_type,
actual_micros,
reserve_micros,
final_result.used,
final_result.limit,
final_result.remaining,
)
except Exception:
# Last-ditch: if finalize itself fails, we must at least release
# so the reservation doesn't leak.
logger.exception(
"[billable_call] premium_finalize failed for user=%s; "
"attempting release",
user_id,
)
try:
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
reserved_micros=reserve_micros,
)
except Exception:
logger.exception(
"[billable_call] release after finalize failure ALSO failed "
"for user=%s",
user_id,
)
try:
async with shielded_async_session() as audit_session:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=actual_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
)
await audit_session.commit()
except Exception:
logger.exception(
"[billable_call] premium-path audit insert failed for "
"usage_type=%s user_id=%s (debit was applied)",
usage_type,
user_id,
)
async def _resolve_agent_billing_for_search_space(
session: AsyncSession,
search_space_id: int,
*,
thread_id: int | None = None,
) -> tuple[UUID, str, str]:
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
agent LLM.
Used by Celery tasks (podcast generation, video presentation) to bill the
search-space owner's premium credit pool when the agent LLM is premium.
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
* ``thread_id`` is set: delegate to
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
recurse into the resolved id. Reuses chat's existing pin if present
so the same model bills for chat + downstream podcast/video. If the
user is not premium-eligible, the pin service auto-restricts to free
deployments denial only happens later in
``billable_call.premium_reserve`` if the pin really is premium and
credit ran out mid-flow.
* ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
for any future direct-API path; today both Celery tasks always pass
``thread_id``.
- **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
``base_model = litellm_params.get("base_model") or model_name``
NOT provider-prefixed, matching chat's cost-map lookup convention.
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
``base_model`` from ``litellm_params`` or ``model_name``.
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
``llm_router_service`` are imported lazily inside the function body to
avoid hoisting litellm side-effects (``litellm.callbacks =
[token_tracker]``, ``litellm.drop_params``, etc.) into
``billable_calls.py``'s module load path.
"""
from sqlalchemy import select
from app.db import NewLLMConfig, SearchSpace
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if search_space is None:
raise ValueError(f"Search space {search_space_id} not found")
agent_llm_id = search_space.agent_llm_id
if agent_llm_id is None:
raise ValueError(
f"Search space {search_space_id} has no agent_llm_id configured"
)
owner_user_id: UUID = search_space.user_id
from app.services.auto_model_pin_service import (
AUTO_FASTEST_ID,
resolve_or_get_pinned_llm_config_id,
)
if agent_llm_id == AUTO_FASTEST_ID:
if thread_id is None:
return owner_user_id, "free", "auto"
try:
resolution = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=str(owner_user_id),
selected_llm_config_id=AUTO_FASTEST_ID,
)
except ValueError:
logger.warning(
"[agent_billing] Auto-mode pin resolution failed for "
"search_space=%s thread=%s; falling back to free",
search_space_id,
thread_id,
exc_info=True,
)
return owner_user_id, "free", "auto"
agent_llm_id = resolution.resolved_llm_config_id
if agent_llm_id < 0:
from app.services.llm_service import get_global_llm_config
cfg = get_global_llm_config(agent_llm_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
litellm_params = cfg.get("litellm_params") or {}
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
return owner_user_id, billing_tier, base_model
nlc_result = await session.execute(
select(NewLLMConfig).where(
NewLLMConfig.id == agent_llm_id,
NewLLMConfig.search_space_id == search_space_id,
)
)
nlc = nlc_result.scalars().first()
base_model = ""
if nlc is not None:
litellm_params = nlc.litellm_params or {}
base_model = litellm_params.get("base_model") or nlc.model_name or ""
return owner_user_id, "free", base_model
__all__ = [
"QuotaInsufficientError",
"_resolve_agent_billing_for_search_space",
"billable_call",
]
# Re-export the config knob so callers don't have to import config just for
# the default image reserve.
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS

View file

@ -134,42 +134,16 @@ PROVIDER_MAP = {
}
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
# request to an Azure endpoint, which then 404s with ``Resource not found``.
# Only providers with a well-known, stable public base URL are listed here —
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
# so their existing config-driven behaviour is preserved.
PROVIDER_DEFAULT_API_BASE = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
# call sites can share the exact same defense (OpenRouter / Groq / etc.
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
class LLMRouterService:
@ -466,14 +440,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
# requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
api_base = config.get("api_base")
if not api_base:
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
if not api_base:
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base

View file

@ -496,8 +496,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
so each ``ainvoke`` debits the search-space owner's premium credit
pool. User-owned BYOK configs and free global configs are returned
unwrapped they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
@ -519,6 +525,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
owner_user_id = search_space.user_id
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@ -526,6 +534,13 @@ async def get_vision_llm(
)
return None
try:
# Auto mode is currently treated as free at the wrapper
# level — the underlying router can dispatch to either
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
@ -562,8 +577,21 @@ async def get_vision_llm(
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
if billing_tier == "premium":
return QuotaCheckedVisionLLM(
inner_llm,
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=model_string,
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
)
return inner_llm
# User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,

View file

@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
def _is_image_output_model(model: dict) -> bool:
"""Return True if the model can produce image output.
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
``["image"]`` for pure image generators, ``["text", "image"]`` for
multi-modal generators that also emit captions). We accept any model
that can output images; the call site decides whether to use the
image-generation API or chat completion.
"""
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
return "image" in output_mods
def _is_vision_input_model(model: dict) -> bool:
"""Return True if the model can ingest an image AND emit text.
OpenRouter's ``architecture.input_modalities`` lists what the model
accepts; ``output_modalities`` lists what it produces. A vision LLM
is a model that takes images in and produces text out i.e. it can
answer questions about a screenshot or extract content from an
image. Pure image-to-image models (e.g. style transfer) and
text-only models are excluded.
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
output_mods = arch.get("output_modalities", []) or []
return "image" in input_mods and "text" in output_mods
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
return None
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
Pricing values are kept as the raw OpenRouter strings (e.g.
``"0.000003"``); ``pricing_registration`` converts them to floats
when registering with LiteLLM. Models with missing or malformed
pricing are simply omitted operator-side risk if any of those are
premium.
"""
pricing: dict[str, dict[str, str]] = {}
for model in raw_models:
model_id = str(model.get("id") or "").strip()
if not model_id:
continue
p = model.get("pricing") or {}
prompt = p.get("prompt")
completion = p.get("completion")
if prompt is None and completion is None:
continue
pricing[model_id] = {
"prompt": str(prompt) if prompt is not None else "",
"completion": str(completion) if completion is not None else "",
}
return pricing
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
@ -282,6 +337,162 @@ def _generate_configs(
return configs
# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
def _generate_image_gen_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter image-generation models into global image-gen
config dicts (matches the YAML shape consumed by ``image_generation_routes``).
Filter:
- architecture.output_modalities contains "image"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
``_has_sufficient_context``) because tool calls and context windows are
irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
derived per model the same way as chat (``_openrouter_tier``).
Cost is intentionally *not* registered with LiteLLM at startup
(``pricing_registration`` skips image gen): OpenRouter image-gen
models are not in LiteLLM's native cost map and OpenRouter populates
``response_cost`` directly from the response header. A defensive
branch in ``_extract_cost_usd`` handles the rare case where
``usage.cost`` is missing see ``token_tracking_service``.
"""
id_offset: int = int(
settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
free_rpm: int = settings.get("free_rpm", 20)
litellm_params: dict = settings.get("litellm_params") or {}
image_models = [
m
for m in raw_models
if _is_image_output_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in image_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (image generation)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
def _generate_vision_llm_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
Filter:
- architecture.input_modalities contains "image"
- architecture.output_modalities contains "text"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Vision-LLM is invoked from the indexer (image extraction during
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
filters do not apply: a small-context vision model that doesn't
advertise tool-calling is still perfectly viable for "describe this
image" prompts.
"""
id_offset: int = int(
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1_000_000)
free_rpm: int = settings.get("free_rpm", 20)
free_tpm: int = settings.get("free_tpm", 100_000)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
litellm_params: dict = settings.get("litellm_params") or {}
vision_models = [
m
for m in raw_models
if _is_vision_input_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in vision_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
pricing = model.get("pricing") or {}
# Capture per-token prices so ``pricing_registration`` can
# register them with LiteLLM at startup (and so the cost
# estimator in ``estimate_call_reserve_micros`` can resolve
# them at reserve time).
try:
input_cost = float(pricing.get("prompt", 0) or 0)
except (TypeError, ValueError):
input_cost = 0.0
try:
output_cost = float(pricing.get("completion", 0) or 0)
except (TypeError, ValueError):
output_cost = 0.0
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (vision)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
"quota_reserve_tokens": quota_reserve_tokens,
"input_cost_per_token": input_cost or None,
"output_cost_per_token": output_cost or None,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
# Raw OpenRouter pricing per model_id, captured at the same time
# we generate configs. Consumed by ``pricing_registration`` to
# teach LiteLLM the per-token cost of every dynamic deployment so
# the success-callback can populate ``response_cost`` correctly.
self._raw_pricing: dict[str, dict[str, str]] = {}
# Cached raw catalogue from the most recent fetch. Image / vision
# emitters reuse this to avoid a second network call per surface.
self._raw_models: list[dict] = []
# Image / vision config caches (only populated when the matching
# opt-in flag is true on initialize). Refreshed in lockstep with
# the chat catalogue.
self._image_configs: list[dict] = []
self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
self._initialized = True
return []
self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
# Populate image / vision caches when their opt-in flag is set.
# Empty otherwise so the accessors return [] without re-running
# filters every refresh.
if settings.get("image_generation_enabled"):
self._image_configs = _generate_image_gen_configs(raw_models, settings)
logger.info(
"OpenRouter integration: image-gen emission ON (%d models)",
len(self._image_configs),
)
else:
self._image_configs = []
if settings.get("vision_enabled"):
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
logger.info(
"OpenRouter integration: vision LLM emission ON (%d models)",
len(self._vision_configs),
)
else:
self._vision_configs = []
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@ -369,6 +617,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
self._raw_models = raw_models
from app.config import config as app_config
@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
# Image / vision lists are atomic-swapped the same way: filter out
# the previous dynamic entries from the live config list and append
# the freshly generated ones. No-ops when the opt-in flag is off.
if self._settings.get("image_generation_enabled"):
new_image = _generate_image_gen_configs(raw_models, self._settings)
static_image = [
c
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
self._image_configs = new_image
if self._settings.get("vision_enabled"):
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
static_vision = [
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
self._vision_configs = new_vision
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
# Re-register LiteLLM pricing for the freshly fetched catalogue
# so newly added OR models bill correctly on their first call.
# Runs before the router rebuild because the router may issue
# cost-table lookups during deployment registration.
try:
from app.services.pricing_registration import (
register_pricing_from_global_configs,
)
register_pricing_from_global_configs()
except Exception as exc:
logger.warning(
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
)
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
@ -635,3 +923,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)
def get_image_generation_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter image-generation configs (empty
list when the ``image_generation_enabled`` flag is off).
Each entry already has ``billing_tier`` derived per-model from
OpenRouter's signals and is shaped to drop directly into
``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
"""
return list(self._image_configs)
def get_vision_llm_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter vision-LLM configs (empty list
when the ``vision_enabled`` flag is off).
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
so ``pricing_registration`` can teach LiteLLM the cost of these
models the same way it does for chat which keeps the billable
wrapper able to debit accurate micro-USD on a vision call.
"""
return list(self._vision_configs)
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
"""Return the cached raw OpenRouter pricing map.
Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
values are the strings OpenRouter publishes (USD per token),
never converted to floats here so the caller can decide how to
handle malformed or unset entries.
"""
return dict(self._raw_pricing)

View file

@ -0,0 +1,274 @@
"""
Pricing registration with LiteLLM.
Many models reach our LiteLLM callback without LiteLLM knowing their
per-token cost namely:
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
pricing table).
* Static YAML deployments whose ``base_model`` name is operator-defined
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
not in LiteLLM's table either.
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
and the user gets billed nothing a fail-safe but wrong answer for a
cost-based credit system. This module runs once at startup, after the
OpenRouter integration has fetched its catalogue, and registers each
known model's pricing with ``litellm.register_model()`` under multiple
plausible alias keys (LiteLLM's cost lookup may use any of them
depending on whether the call went through the Router, ChatLiteLLM,
or a direct ``acompletion``).
Operators who run a custom Azure deployment whose ``base_model`` name
isn't in LiteLLM's table can declare per-token pricing inline in
``global_llm_config.yaml`` via ``input_cost_per_token`` and
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
that declaration the model's calls debit 0 — never overbilled.
"""
from __future__ import annotations
import logging
from typing import Any
import litellm
logger = logging.getLogger(__name__)
def _safe_float(value: Any) -> float:
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
if value is None:
return 0.0
try:
f = float(value)
except (TypeError, ValueError):
return 0.0
return f if f > 0 else 0.0
def _alias_set_for_openrouter(model_id: str) -> list[str]:
"""Return the alias keys to register an OpenRouter model under.
LiteLLM's cost-callback lookup key varies by call path:
- Router with ``model="openrouter/X"`` kwargs["model"] is
typically ``openrouter/X``.
- LiteLLM's own provider routing may strip the prefix and pass the
bare ``X`` to the cost-table lookup.
Registering under both keeps the lookup hermetic regardless of
which path the call took.
"""
aliases = [f"openrouter/{model_id}", model_id]
return list(dict.fromkeys(a for a in aliases if a))
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
"""Return the alias keys to register a static YAML deployment under.
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
bare ``model_name`` because callbacks sometimes see whichever was
configured first.
"""
provider_lower = (provider or "").lower()
aliases: list[str] = []
if base_model:
aliases.append(base_model)
if provider_lower and base_model:
aliases.append(f"{provider_lower}/{base_model}")
if model_name and model_name != base_model:
aliases.append(model_name)
if provider_lower and model_name and model_name != base_model:
aliases.append(f"{provider_lower}/{model_name}")
# Azure deployments often surface as "azure/<name>"; normalise the
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
if provider_lower == "azure_openai":
if base_model:
aliases.append(f"azure/{base_model}")
if model_name and model_name != base_model:
aliases.append(f"azure/{model_name}")
return list(dict.fromkeys(a for a in aliases if a))
def _register(
aliases: list[str],
*,
input_cost: float,
output_cost: float,
provider: str,
mode: str = "chat",
) -> int:
"""Register a single pricing entry under every alias in ``aliases``.
Returns the count of aliases successfully registered.
"""
payload: dict[str, dict[str, Any]] = {}
for alias in aliases:
payload[alias] = {
"input_cost_per_token": input_cost,
"output_cost_per_token": output_cost,
"litellm_provider": provider,
"mode": mode,
}
if not payload:
return 0
try:
litellm.register_model(payload)
except Exception as exc:
logger.warning(
"[PricingRegistration] register_model failed for aliases=%s: %s",
aliases,
exc,
)
return 0
return len(payload)
def _register_chat_shape_configs(
configs: list[dict],
*,
or_pricing: dict[str, dict[str, str]],
label: str,
) -> tuple[int, int, int, list[str]]:
"""Common loop that registers per-token pricing for a list of "chat-shape"
configs (chat or vision LLM both use ``input_cost_per_token`` /
``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
"""
registered_models = 0
registered_aliases = 0
skipped_no_pricing = 0
sample_keys: list[str] = []
for cfg in configs:
provider = str(cfg.get("provider") or "").upper()
model_name = str(cfg.get("model_name") or "").strip()
litellm_params = cfg.get("litellm_params") or {}
base_model = str(litellm_params.get("base_model") or model_name).strip()
if provider == "OPENROUTER":
entry = or_pricing.get(model_name)
if entry:
input_cost = _safe_float(entry.get("prompt"))
output_cost = _safe_float(entry.get("completion"))
else:
# Vision configs from ``_generate_vision_llm_configs``
# carry their pricing inline because the OpenRouter
# raw-pricing cache is keyed by chat-catalogue model_id;
# vision flows pick up the inline values here.
input_cost = _safe_float(cfg.get("input_cost_per_token"))
output_cost = _safe_float(cfg.get("output_cost_per_token"))
if input_cost == 0.0 and output_cost == 0.0:
skipped_no_pricing += 1
continue
aliases = _alias_set_for_openrouter(model_name)
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider="openrouter",
)
if count > 0:
registered_models += 1
registered_aliases += count
if len(sample_keys) < 6:
sample_keys.extend(aliases[:2])
continue
input_cost = _safe_float(
cfg.get("input_cost_per_token")
or litellm_params.get("input_cost_per_token")
)
output_cost = _safe_float(
cfg.get("output_cost_per_token")
or litellm_params.get("output_cost_per_token")
)
if input_cost == 0.0 and output_cost == 0.0:
skipped_no_pricing += 1
continue
aliases = _alias_set_for_yaml(provider, model_name, base_model)
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider=provider_slug,
)
if count > 0:
registered_models += 1
registered_aliases += count
if len(sample_keys) < 6:
sample_keys.extend(aliases[:2])
logger.info(
"[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
"%d configs had no pricing data; sample registered keys=%s",
label,
registered_models,
registered_aliases,
skipped_no_pricing,
sample_keys,
)
return registered_models, registered_aliases, skipped_no_pricing, sample_keys
def register_pricing_from_global_configs() -> None:
"""Register pricing for every known LLM deployment with LiteLLM.
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
so vision calls (during indexing) can resolve cost the same way chat
calls do namely:
1. ``OPENROUTER``: pulls the cached raw pricing from
``OpenRouterIntegrationService`` (populated during its own
startup fetch) and converts the per-token strings to floats. For
vision configs that carry pricing inline (``input_cost_per_token`` /
``output_cost_per_token`` set on the cfg itself) we fall back to
those values when the OR cache misses the model.
2. Anything else: looks for operator-declared
``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
config block (top-level or nested under ``litellm_params``).
**Image generation is intentionally NOT registered here.** The cost
shape for image-gen is per-image (``output_cost_per_image``), not
per-token, and LiteLLM's ``register_model`` doesn't accept those
keys via the chat-cost path. OpenRouter image-gen models populate
``response_cost`` directly from their response header instead, and
Azure-native image-gen models are already in LiteLLM's cost map.
Calls without a resolved pair of costs are skipped, not registered
with zeros operators who forget pricing get a "$0 debit" warning
in ``TokenTrackingCallback`` rather than silently overwriting any
pricing LiteLLM might know natively.
"""
from app.config import config as app_config
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
vision_configs: list[dict] = list(
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
)
if not chat_configs and not vision_configs:
logger.info("[PricingRegistration] no global configs to register")
return
or_pricing: dict[str, dict[str, str]] = {}
try:
from app.services.openrouter_integration_service import (
OpenRouterIntegrationService,
)
if OpenRouterIntegrationService.is_initialized():
or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
except Exception as exc:
logger.debug(
"[PricingRegistration] OpenRouter pricing not available yet: %s", exc
)
if chat_configs:
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
if vision_configs:
_register_chat_shape_configs(
vision_configs, or_pricing=or_pricing, label="vision"
)

View file

@ -0,0 +1,107 @@
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
LiteLLM falls back to the module-global ``litellm.api_base`` when an
individual call doesn't pass one, which silently inherits provider-agnostic
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
``/chat/completions`` to whatever inherited base it gets, regardless of
provider).
The chat router has had this defense for a while
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
into a tiny standalone helper so vision and image-gen can share the same
source of truth without an inter-service circular import.
"""
from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
Only providers with a well-known, stable public base URL are listed
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
huggingface, databricks, cloudflare, replicate) are intentionally omitted
so their existing config-driven behaviour is preserved."""
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
"""Canonical provider key (uppercase) → base URL.
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
config's ``provider`` field tells us which API it actually is (DeepSeek,
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
has its own base URL)."""
def resolve_api_base(
*,
provider: str | None,
provider_prefix: str | None,
config_api_base: str | None,
) -> str | None:
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
Cascade (first non-empty wins):
1. The config's own ``api_base`` (whitespace-only treated as missing).
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
4. ``None`` caller should NOT set ``api_base`` and let the LiteLLM
provider integration apply its own default (e.g. AzureOpenAI's
deployment-derived URL, custom provider's per-deployment URL).
Args:
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
``"DEEPSEEK"``). Case-insensitive.
provider_prefix: The LiteLLM model-string prefix the same call
site builds for the model id (e.g. ``"openrouter"``,
``"groq"``). Case-insensitive.
config_api_base: ``api_base`` from the global YAML / DB row /
OpenRouter dynamic config. Empty / whitespace-only means
"missing" the resolver still applies the cascade.
Returns:
A URL string, or ``None`` if no default applies for this provider.
"""
if config_api_base and config_api_base.strip():
return config_api_base
if provider:
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
if key_default:
return key_default
if provider_prefix:
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
if prefix_default:
return prefix_default
return None
__all__ = [
"PROVIDER_DEFAULT_API_BASE",
"PROVIDER_KEY_DEFAULT_API_BASE",
"resolve_api_base",
]

View file

@ -0,0 +1,105 @@
"""
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
indexing pipeline (file processors, connector indexers, etl pipeline) can
keep invoking the LLM exactly the way they do today ``await llm.ainvoke(...)``
without threading ``user_id`` through every parser. The wrapper looks like
a chat model from the outside; on the inside it routes each call through
``billable_call`` so the user's premium credit pool is reserved → finalized
or released, and a ``TokenUsage`` audit row is written.
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
need quota enforcement) so this class only ever wraps premium configs.
Why a wrapper instead of plumbing ``user_id`` through every caller:
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
Dropbox, local-folder, file-processor, ETL pipeline) each calling
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
invasive, error-prone, and easy for a future indexer to forget.
* Per the design (issue M), we always debit the *search-space owner*, not
the triggering user, so ``user_id`` is fully derivable from the search
space the caller is already operating on. The wrapper captures it once
at construction time.
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
call run this coroutine"; subclassing isn't safe across versions because
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
Composition via attribute proxying (``__getattr__``) is robust to
upstream changes every method other than ``ainvoke`` falls through to
the inner LLM unchanged.
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from app.services.billable_calls import QuotaInsufficientError, billable_call
logger = logging.getLogger(__name__)
class QuotaCheckedVisionLLM:
"""Composition wrapper around a langchain chat model that enforces
premium credit quota on every ``ainvoke``.
Anything other than ``ainvoke`` is forwarded to the inner model so
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
still work they simply bypass quota enforcement, which is fine
because the indexing pipeline only ever calls ``ainvoke`` today.
"""
def __init__(
self,
inner_llm: Any,
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None,
usage_type: str = "vision_extraction",
) -> None:
self._inner = inner_llm
self._user_id = user_id
self._search_space_id = search_space_id
self._billing_tier = billing_tier
self._base_model = base_model
self._quota_reserve_tokens = quota_reserve_tokens
self._usage_type = usage_type
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
"""Proxied async invoke that runs the underlying call inside
``billable_call``.
Raises:
QuotaInsufficientError: when the user has exhausted their
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
catches this and falls back to the document parser.
"""
async with billable_call(
user_id=self._user_id,
search_space_id=self._search_space_id,
billing_tier=self._billing_tier,
base_model=self._base_model,
quota_reserve_tokens=self._quota_reserve_tokens,
usage_type=self._usage_type,
call_details={"model": self._base_model},
):
return await self._inner.ainvoke(input, *args, **kwargs)
def __getattr__(self, name: str) -> Any:
"""Forward everything else (``invoke``, ``astream``, ``bind``,
``with_structured_output``, ) to the inner model.
``__getattr__`` is only consulted when the attribute is *not*
already found on the proxy, which is exactly the contract we
want methods we override stay on the proxy, the rest fall
through.
"""
return getattr(self._inner, name)
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]

View file

@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Per-call reservation estimator (USD micro-units)
# ---------------------------------------------------------------------------
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
# request, and so models without registered pricing reserve at least
# something while the call runs (debited 0 at finalize anyway when their
# cost can't be resolved).
_QUOTA_MIN_RESERVE_MICROS = 100
def estimate_call_reserve_micros(
*,
base_model: str,
quota_reserve_tokens: int | None,
) -> int:
"""Return the number of micro-USD to reserve for one premium call.
Computes a worst-case upper bound from LiteLLM's per-token pricing
table:
reserve_usd reserve_tokens x (input_cost + output_cost)
so the math scales with model cost Claude Opus + 4K reserve_tokens
naturally reserves $0.36, while a cheap model reserves only a few
cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
so a misconfigured "$1000/M" model can't lock the whole balance on
one call.
If ``litellm.get_model_info`` raises (model unknown) we fall back to
the floor 100 micros / $0.0001 which is enough to gate a sane
request without over-reserving for a model whose pricing the
operator hasn't declared yet.
"""
reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
if reserve_tokens <= 0:
reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
try:
from litellm import get_model_info
info = get_model_info(base_model) if base_model else {}
input_cost = float(info.get("input_cost_per_token") or 0.0)
output_cost = float(info.get("output_cost_per_token") or 0.0)
except Exception as exc:
logger.debug(
"[quota_reserve] cost lookup failed for base_model=%s: %s",
base_model,
exc,
)
input_cost = 0.0
output_cost = 0.0
if input_cost == 0.0 and output_cost == 0.0:
return _QUOTA_MIN_RESERVE_MICROS
reserve_usd = reserve_tokens * (input_cost + output_cost)
reserve_micros = round(reserve_usd * 1_000_000)
if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
reserve_micros = _QUOTA_MIN_RESERVE_MICROS
if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
return reserve_micros
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
reserve_tokens: int,
reserve_micros: int,
) -> QuotaResult:
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
premium credit balance.
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
all in micro-USD on this code path; callers (chat stream,
token-status route, FE display) convert to dollars by dividing
by 1_000_000.
"""
from app.db import User
user = (
@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0,
)
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
effective = used + reserved + reserve_tokens
effective = used + reserved + reserve_micros
if effective > limit:
remaining = max(0, limit - used - reserved)
await db_session.rollback()
@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining,
)
user.premium_tokens_reserved = reserved + reserve_tokens
user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit()
new_reserved = reserved + reserve_tokens
new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8)
@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
actual_tokens: int,
reserved_tokens: int,
actual_micros: int,
reserved_micros: int,
) -> QuotaResult:
"""Settle the reservation: release ``reserved_micros`` and debit
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
"""
from app.db import User
user = (
@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
user.premium_tokens_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens
user.premium_credit_micros_reserved = max(
0, user.premium_credit_micros_reserved - reserved_micros
)
user.premium_credit_micros_used = (
user.premium_credit_micros_used + actual_micros
)
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit()
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release(
db_session: AsyncSession,
user_id: Any,
reserved_tokens: int,
reserved_micros: int,
) -> None:
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
Used when a request fails before finalize (so the reservation
doesn't leak credit).
"""
from app.db import User
user = (
@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none()
)
if user is not None:
user.premium_tokens_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens
user.premium_credit_micros_reserved = max(
0, user.premium_credit_micros_reserved - reserved_micros
)
await db_session.commit()
@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)

View file

@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost_micros: int = 0
call_kind: str = "chat"
@dataclass
@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
cost_micros: int = 0,
call_kind: str = "chat",
) -> None:
self.calls.append(
TokenCallRecord(
@ -56,20 +63,28 @@ class TurnTokenAccumulator:
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
"""Return token counts grouped by model name."""
"""Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
{
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost_micros": 0,
},
)
entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens
entry["cost_micros"] += c.cost_micros
return by_model
@property
@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls)
@property
def total_cost_micros(self) -> int:
"""Sum of per-call ``cost_micros`` across the entire turn.
Used by ``stream_new_chat`` to debit a premium turn's actual
provider cost (in micro-USD) from the user's premium credit
balance. ``cost_micros`` per call is captured by
``TokenTrackingCallback.async_log_success_event`` from
``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
with multiple fallback paths so OpenRouter dynamic models and
custom Azure deployments still bill correctly when our
``pricing_registration`` ran at startup.
"""
return sum(c.cost_micros for c in self.calls)
def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls]
@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator:
"""Create a fresh accumulator for the current async context and return it."""
"""Create a fresh accumulator for the current async context and return it.
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
short-lived per-call billable wrappers (image generation REST endpoint,
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
ContextVar reset token to restore the *previous* accumulator on exit and
avoids leaking call records across reservations (issue B).
"""
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
"""Async context manager that scopes a fresh ``TurnTokenAccumulator``
for the duration of the ``async with`` block, then *resets* the
ContextVar to its previous value on exit.
This is the safe primitive for per-call billable operations
(image generation, vision LLM extraction, podcasts) that may run
inside an outer chat turn or be called sequentially from the same
background worker. Using ``ContextVar.set`` without ``reset`` (as
:func:`start_turn` does) would leak the inner accumulator into the
outer scope, causing the outer chat turn to debit cost twice.
Usage::
async with scoped_turn() as acc:
await llm.ainvoke(...)
# acc.total_cost_micros captures cost from the LiteLLM callback
# Outer accumulator (if any) is restored here.
"""
acc = TurnTokenAccumulator()
token = _turn_accumulator.set(acc)
logger.debug(
"[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
id(acc),
token,
)
try:
yield acc
finally:
_turn_accumulator.reset(token)
logger.debug(
"[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
id(acc),
len(acc.calls),
acc.total_cost_micros,
)
def _extract_cost_usd(
kwargs: dict[str, Any],
response_obj: Any,
model: str,
prompt_tokens: int,
completion_tokens: int,
is_image: bool = False,
) -> float:
"""Best-effort USD cost extraction for a single LLM/image call.
Tries four sources in priority order and returns the first that
yields a positive number; returns 0.0 if all four fail (the call
will then debit nothing from the user's balance — fail-safe).
Sources:
1. ``kwargs["response_cost"]`` LiteLLM's standard callback
field, populated for ``Router.acompletion`` since PR #12500.
2. ``response_obj._hidden_params["response_cost"]`` same value
exposed on the response itself.
3. ``litellm.completion_cost(completion_response=response_obj)``
recompute from the response and LiteLLM's pricing table.
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
manual fallback for OpenRouter/custom-Azure models that
only resolve via aliases registered by
``pricing_registration`` at startup. **Skipped for image
responses** ``cost_per_token`` does not support ``ImageResponse``
and would raise; the cost map for image-gen lives in different
keys (``output_cost_per_image``) handled by ``completion_cost``.
"""
cost = kwargs.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
hidden = getattr(response_obj, "_hidden_params", None) or {}
if isinstance(hidden, dict):
cost = hidden.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
try:
value = float(litellm.completion_cost(completion_response=response_obj))
if value > 0:
return value
except Exception as exc:
if is_image:
# Image-gen path: OpenRouter's image responses can omit
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
# then *raises* (no cost map for OpenRouter image models).
# Bail out with a warning rather than falling through to
# cost_per_token (which is also incompatible with ImageResponse).
logger.warning(
"[TokenTracking] completion_cost failed for image model=%s "
"(provider may have omitted usage.cost). Debiting 0. "
"Cause: %s",
model,
exc,
)
return 0.0
logger.debug(
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
)
if is_image:
# Never call cost_per_token for ImageResponse — keys mismatch and
# the function is documented chat-only.
return 0.0
if model and (prompt_tokens > 0 or completion_tokens > 0):
try:
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
value = float(prompt_cost) + float(completion_cost)
if value > 0:
return value
except Exception as exc:
logger.debug(
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
)
return 0.0
class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator."""
@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
)
return
# Detect image generation responses — they have a different usage
# shape (ImageUsage with input_tokens/output_tokens) and require a
# different cost-extraction path. We probe by class name to avoid a
# hard import dependency on litellm internals.
response_cls = type(response_obj).__name__
is_image = response_cls == "ImageResponse"
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug(
@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
)
return
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
if is_image:
# ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
# (not prompt_tokens/completion_tokens). Several providers
# populate only one or neither (e.g. OpenRouter's gpt-image-1
# passes through `input_tokens` from the prompt but no
# completion); fall through gracefully to 0.
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
completion_tokens = getattr(usage, "output_tokens", 0) or 0
total_tokens = (
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
)
call_kind = "image_generation"
else:
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
call_kind = "chat"
model = kwargs.get("model", "unknown")
cost_usd = _extract_cost_usd(
kwargs=kwargs,
response_obj=response_obj,
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
is_image=is_image,
)
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
logger.warning(
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
"for image generation).",
model,
prompt_tokens,
completion_tokens,
call_kind,
)
acc.add(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
)
logger.info(
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
model,
call_kind,
prompt_tokens,
completion_tokens,
total_tokens,
cost_usd,
cost_micros,
len(acc.calls),
)
@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
@ -194,11 +416,12 @@ async def record_token_usage(
)
session.add(record)
logger.debug(
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type,
prompt_tokens,
completion_tokens,
total_tokens,
cost_micros,
)
return record
except Exception:

View file

@ -3,6 +3,8 @@ from typing import Any
from litellm import Router
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
}
if config.get("api_base"):
litellm_params["api_base"] = config["api_base"]
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]