Mirror of https://github.com/MODSetter/SurfSense.git

commit ae9d36d77f (parent 451a98936e)
feat: unified credits and its cost calculations

61 changed files with 5835 additions and 272 deletions
@@ -31,6 +31,7 @@ from app.config import (
     initialize_image_gen_router,
     initialize_llm_router,
     initialize_openrouter_integration,
+    initialize_pricing_registration,
     initialize_vision_llm_router,
 )
 from app.db import User, create_db_and_tables, get_async_session
@@ -432,6 +433,7 @@ async def lifespan(app: FastAPI):
     await setup_checkpointer_tables()
     initialize_openrouter_integration()
     _start_openrouter_background_refresh()
+    initialize_pricing_registration()
     initialize_llm_router()
     initialize_image_gen_router()
     initialize_vision_llm_router()
@@ -22,10 +22,12 @@ def init_worker(**kwargs):
         initialize_image_gen_router,
         initialize_llm_router,
         initialize_openrouter_integration,
+        initialize_pricing_registration,
         initialize_vision_llm_router,
     )

     initialize_openrouter_integration()
+    initialize_pricing_registration()
     initialize_llm_router()
     initialize_image_gen_router()
     initialize_vision_llm_router()
@@ -138,7 +138,11 @@ def load_global_image_gen_configs():
     try:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
-            return data.get("global_image_generation_configs", [])
+            configs = data.get("global_image_generation_configs", []) or []
+            for cfg in configs:
+                if isinstance(cfg, dict):
+                    cfg.setdefault("billing_tier", "free")
+            return configs
     except Exception as e:
         print(f"Warning: Failed to load global image generation configs: {e}")
         return []
@@ -153,7 +157,11 @@ def load_global_vision_llm_configs():
     try:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
-            return data.get("global_vision_llm_configs", [])
+            configs = data.get("global_vision_llm_configs", []) or []
+            for cfg in configs:
+                if isinstance(cfg, dict):
+                    cfg.setdefault("billing_tier", "free")
+            return configs
     except Exception as e:
         print(f"Warning: Failed to load global vision LLM configs: {e}")
         return []
@@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None:
             "anonymous_enabled_free", settings["anonymous_enabled"]
         )
+
+        # Image generation + vision LLM emission are opt-in (issue L).
+        # OpenRouter's catalogue contains hundreds of image / vision
+        # capable models; auto-injecting all of them into every
+        # deployment would explode the model selector and surprise
+        # operators upgrading from prior versions. Default to False so
+        # admins must explicitly turn them on.
+        settings.setdefault("image_generation_enabled", False)
+        settings.setdefault("vision_enabled", False)

         return settings
     except Exception as e:
         print(f"Warning: Failed to load OpenRouter integration settings: {e}")
@@ -296,10 +313,60 @@ def initialize_openrouter_integration():
             )
         else:
             print("Info: OpenRouter integration enabled but no models fetched")
+
+        # Image generation + vision LLM emissions are opt-in (issue L).
+        # Both reuse the catalogue already cached by ``service.initialize``
+        # so we don't make additional network calls here.
+        if settings.get("image_generation_enabled"):
+            try:
+                image_configs = service.get_image_generation_configs()
+                if image_configs:
+                    config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
+                    print(
+                        f"Info: OpenRouter integration added {len(image_configs)} "
+                        f"image-generation models"
+                    )
+            except Exception as e:
+                print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
+
+        if settings.get("vision_enabled"):
+            try:
+                vision_configs = service.get_vision_llm_configs()
+                if vision_configs:
+                    config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
+                    print(
+                        f"Info: OpenRouter integration added {len(vision_configs)} "
+                        f"vision LLM models"
+                    )
+            except Exception as e:
+                print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
     except Exception as e:
         print(f"Warning: Failed to initialize OpenRouter integration: {e}")


+def initialize_pricing_registration():
+    """
+    Teach LiteLLM the per-token cost of every deployment in
+    ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
+    from the OpenRouter catalogue + any operator-declared YAML pricing).
+
+    Must run AFTER ``initialize_openrouter_integration()`` so the
+    OpenRouter catalogue is populated and BEFORE the first LLM call so
+    ``response_cost`` is available in ``TokenTrackingCallback``.
+
+    Failures are logged but never raised — startup must not be blocked
+    by a missing pricing entry; the worst-case is the model debits 0.
+    """
+    try:
+        from app.services.pricing_registration import (
+            register_pricing_from_global_configs,
+        )
+
+        register_pricing_from_global_configs()
+    except Exception as e:
+        print(f"Warning: Failed to register LiteLLM pricing: {e}")
+
+
 def initialize_llm_router():
     """
     Initialize the LLM Router service for Auto mode.
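For illustration, the registration that ``initialize_pricing_registration`` delegates to plausibly reduces to one ``litellm.register_model`` call per deployment. A minimal sketch, assuming an OpenAI-compatible custom deployment — ``register_pricing_from_global_configs`` itself is not part of this diff, so the exact mapping is an assumption:

# Hedged sketch — not the project's actual pricing_registration module.
import litellm

def register_custom_pricing(base_model: str, input_usd: float, output_usd: float) -> None:
    # litellm.register_model() merges entries into LiteLLM's model-cost table,
    # so later calls report a non-None response_cost for this model name.
    litellm.register_model(
        {
            base_model: {
                "input_cost_per_token": input_usd,    # USD per input token
                "output_cost_per_token": output_usd,  # USD per output token
                "litellm_provider": "openai",         # assumption: OpenAI-compatible
                "mode": "chat",
            }
        }
    )

# e.g. $3.00 / $15.00 per million tokens for a custom Azure deployment name:
register_custom_pricing("my-custom-azure-deploy", 0.000003, 0.000015)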
@@ -444,14 +511,54 @@ class Config:
         os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
     )

-    # Premium token quota settings
-    PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
+    # Premium credit (micro-USD) quota settings.
+    #
+    # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
+    # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
+    # still honoured for one release as fall-back values — the prior
+    # $1-per-1M-tokens Stripe price means every existing value maps 1:1
+    # to micros, so operators upgrading without changing their .env still
+    # get correct behaviour. A startup deprecation warning fires below if
+    # they're set.
+    PREMIUM_CREDIT_MICROS_LIMIT = int(
+        os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
+        or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
+    )
     STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
-    STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
+    STRIPE_CREDIT_MICROS_PER_UNIT = int(
+        os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
+        or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
+    )
     STRIPE_TOKEN_BUYING_ENABLED = (
         os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
     )

+    # Safety ceiling on the per-call premium reservation. ``stream_new_chat``
+    # estimates an upper-bound cost from ``litellm.get_model_info`` x the
+    # config's ``quota_reserve_tokens`` and clamps the result to this value
+    # so a misconfigured "$1000/M" model can't lock the user's whole balance
+    # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
+    # reserve_tokens ≈ $0.36) with headroom.
+    QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
+
+    if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
+        "PREMIUM_CREDIT_MICROS_LIMIT"
+    ):
+        print(
+            "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
+            "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
+            "current Stripe price). The old key will be removed in a "
+            "future release."
+        )
+    if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
+        "STRIPE_CREDIT_MICROS_PER_UNIT"
+    ):
+        print(
+            "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
+            "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
+            "The old key will be removed in a future release."
+        )
+
     # Anonymous / no-login mode settings
     NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
     ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
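As a worked check of the micro-USD unit and the reservation clamp described in the comments above (the per-million price below is illustrative, not a value pinned by this commit):

MICROS_PER_USD = 1_000_000  # 1_000_000 micro-USD == $1.00

# Legacy fallback is 1:1: under the old $1-per-1M-tokens Stripe price,
# PREMIUM_TOKEN_LIMIT=3_000_000 tokens == $3.00 == 3_000_000 micros.
assert 3_000_000 * 1 == 3_000_000

# Worst-case hold for a 4_000-token reserve on a model priced ~$75/M output:
reserve_micros = int(4_000 * (75 / 1_000_000) * MICROS_PER_USD)  # 300_000 == $0.30
clamped = min(reserve_micros, 1_000_000)  # QUOTA_MAX_RESERVE_MICROS default ($1.00)
assert clamped == 300_000  # well under the ceiling, as the comment expects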
@@ -464,6 +571,35 @@ class Config:
     # Default quota reserve tokens when not specified per-model
     QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))

+    # Per-image reservation (in micro-USD) used by ``billable_call`` for the
+    # ``POST /image-generations`` endpoint when the global config does not
+    # override it. $0.05 covers realistic worst-cases for current OpenAI /
+    # OpenRouter image-gen pricing. Bypassed entirely for free configs.
+    QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
+    )
+
+    # Per-podcast reservation (in micro-USD). One agent LLM call generating
+    # a transcript, typically 5k-20k completion tokens. $0.20 covers a long
+    # premium-model run. Tune via env.
+    QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
+    )
+
+    # Per-video-presentation reservation (in micro-USD). Fan-out of N
+    # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
+    # plus refine retries; can produce many premium completions. $1.00
+    # covers worst-case. Tune via env.
+    #
+    # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
+    # 1_000_000. The override path in ``billable_call`` bypasses the
+    # per-call clamp in ``estimate_call_reserve_micros``, so this is the
+    # *actual* hold — raising it via env is fine but means a single video
+    # task can lock $1+ of credit.
+    QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
+    )
+
     # Abuse prevention: concurrent stream cap and CAPTCHA
     ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
     ANON_CAPTCHA_REQUEST_THRESHOLD = int(
@@ -19,6 +19,24 @@
 # Structure matches NewLLMConfig:
 # - Model configuration (provider, model_name, api_key, etc.)
 # - Prompt configuration (system_instructions, citations_enabled)
 #
+# COST-BASED PREMIUM CREDITS:
+# Each premium config bills the user's USD-credit balance based on the
+# actual provider cost reported by LiteLLM. For models LiteLLM already
+# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
+# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
+# or any model LiteLLM doesn't have in its built-in pricing table, declare
+# per-token costs inline so they bill correctly:
+#
+#   litellm_params:
+#     base_model: "my-custom-azure-deploy"
+#     # USD per token; e.g. 0.000003 == $3.00 per million input tokens
+#     input_cost_per_token: 0.000003
+#     output_cost_per_token: 0.000015
+#
+# OpenRouter dynamic models pull pricing automatically from OpenRouter's
+# API — no inline declaration needed. Models without resolvable pricing
+# debit $0 from the user's balance and log a WARNING.
+
 # Router Settings for Auto Mode
 # These settings control how the LiteLLM Router distributes requests across models
@@ -292,6 +310,17 @@ openrouter_integration:
   free_rpm: 20
   free_tpm: 100000

+  # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
+  # contains hundreds of image- and vision-capable models; turning these on
+  # injects them into the global Image-Generation / Vision-LLM model
+  # selectors alongside any static configs. Tier (free/premium) is derived
+  # per model the same way it is for chat (`:free` suffix or zero pricing).
+  # When a user picks a premium image/vision model the call debits the
+  # shared $5 USD-cost-based premium credit pool — so leaving these off
+  # avoids surprise quota burn on existing deployments. Default: false.
+  image_generation_enabled: false
+  vision_enabled: false
+
 litellm_params:
   max_tokens: 16384
 system_instructions: ""
@@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
     prompt_tokens = Column(Integer, nullable=False, default=0)
     completion_tokens = Column(Integer, nullable=False, default=0)
     total_tokens = Column(Integer, nullable=False, default=0)
+    cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
     model_breakdown = Column(JSONB, nullable=True)
     call_details = Column(JSONB, nullable=True)
@@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):


 class PremiumTokenPurchase(Base, TimestampMixin):
-    """Tracks Stripe checkout sessions used to grant additional premium token credits."""
+    """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
+
+    Note: the table name is preserved (``premium_token_purchases``) for
+    operational continuity even though the unit is now USD micro-credits
+    instead of raw tokens. The ``credit_micros_granted`` column replaced
+    the legacy ``tokens_granted`` in migration 140; the stored values
+    were not transformed because the prior $1 = 1M tokens Stripe price
+    makes the unit conversion 1:1 numerically.
+    """

     __tablename__ = "premium_token_purchases"
     __allow_unmapped__ = True
@@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
     )
     stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
     quantity = Column(Integer, nullable=False)
-    tokens_granted = Column(BigInteger, nullable=False)
+    credit_micros_granted = Column(BigInteger, nullable=False)
     amount_total = Column(Integer, nullable=True)
     currency = Column(String(10), nullable=True)
     status = Column(
@@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
         )
         pages_used = Column(Integer, nullable=False, default=0, server_default="0")

-        premium_tokens_limit = Column(
+        premium_credit_micros_limit = Column(
             BigInteger,
             nullable=False,
-            default=config.PREMIUM_TOKEN_LIMIT,
-            server_default=str(config.PREMIUM_TOKEN_LIMIT),
+            default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+            server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
         )
-        premium_tokens_used = Column(
+        premium_credit_micros_used = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
-        premium_tokens_reserved = Column(
+        premium_credit_micros_reserved = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
@@ -2241,16 +2250,16 @@ else:
         )
         pages_used = Column(Integer, nullable=False, default=0, server_default="0")

-        premium_tokens_limit = Column(
+        premium_credit_micros_limit = Column(
             BigInteger,
             nullable=False,
-            default=config.PREMIUM_TOKEN_LIMIT,
-            server_default=str(config.PREMIUM_TOKEN_LIMIT),
+            default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+            server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
         )
-        premium_tokens_used = Column(
+        premium_credit_micros_used = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
-        premium_tokens_reserved = Column(
+        premium_credit_micros_reserved = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
@@ -68,12 +68,25 @@ class EtlPipelineService:
                     etl_service="VISION_LLM",
                     content_type="image",
                 )
-            except Exception:
-                logging.warning(
-                    "Vision LLM failed for %s, falling back to document parser",
-                    request.filename,
-                    exc_info=True,
-                )
+            except Exception as exc:
+                # Special-case quota exhaustion so we log a clearer message
+                # — the vision LLM didn't "fail", the user just ran out of
+                # premium credit. Falling through to the document parser
+                # is a graceful degradation: OCR/Unstructured still
+                # extracts text from the image without burning credit.
+                from app.services.billable_calls import QuotaInsufficientError
+
+                if isinstance(exc, QuotaInsufficientError):
+                    logging.info(
+                        "Vision LLM quota exhausted for %s; falling back to document parser",
+                        request.filename,
+                    )
+                else:
+                    logging.warning(
+                        "Vision LLM failed for %s, falling back to document parser",
+                        request.filename,
+                        exc_info=True,
+                    )
         else:
             logging.info(
                 "No vision LLM provided, falling back to document parser for %s",
@@ -36,6 +36,11 @@ from app.schemas import (
     ImageGenerationListRead,
     ImageGenerationRead,
 )
+from app.services.billable_calls import (
+    DEFAULT_IMAGE_RESERVE_MICROS,
+    QuotaInsufficientError,
+    billable_call,
+)
 from app.services.image_gen_router_service import (
     IMAGE_GEN_AUTO_MODE_ID,
     ImageGenRouterService,
@@ -92,6 +97,50 @@ def _build_model_string(
     return f"{prefix}/{model_name}"


+async def _resolve_billing_for_image_gen(
+    session: AsyncSession,
+    config_id: int | None,
+    search_space: SearchSpace,
+) -> tuple[str, str, int]:
+    """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
+
+    The resolution mirrors ``_execute_image_generation``'s lookup tree but
+    only extracts the fields needed for billing — we do this *before*
+    ``billable_call`` so the reservation is correctly sized for the
+    config that will actually run, and so we don't open an
+    ``ImageGeneration`` row for a request that's about to 402.
+
+    User-owned (positive ID) BYOK configs are always free — they cost
+    the user nothing on our side. Auto mode is currently treated as free
+    because the underlying router can dispatch to either premium or
+    free YAML configs and we don't surface the resolved deployment up
+    here yet. Bringing Auto under premium billing would require
+    threading the chosen deployment back from ``ImageGenRouterService``.
+    """
+    resolved_id = config_id
+    if resolved_id is None:
+        resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
+
+    if is_image_gen_auto_mode(resolved_id):
+        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
+
+    if resolved_id < 0:
+        cfg = _get_global_image_gen_config(resolved_id) or {}
+        billing_tier = str(cfg.get("billing_tier", "free")).lower()
+        base_model = _build_model_string(
+            cfg.get("provider", ""),
+            cfg.get("model_name", ""),
+            cfg.get("custom_provider"),
+        )
+        reserve_micros = int(
+            cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
+        )
+        return (billing_tier, base_model, reserve_micros)
+
+    # Positive ID = user-owned BYOK image-gen config — always free.
+    return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
+
+
 async def _execute_image_generation(
     session: AsyncSession,
     image_gen: ImageGeneration,
@@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
             "litellm_params": {},
             "is_global": True,
             "is_auto_mode": True,
+            # Auto mode currently treated as free until per-deployment
+            # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
+            "billing_tier": "free",
         }
     )
@@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
             "api_version": cfg.get("api_version") or None,
             "litellm_params": cfg.get("litellm_params", {}),
             "is_global": True,
+            "billing_tier": cfg.get("billing_tier", "free"),
+            "quota_reserve_micros": cfg.get("quota_reserve_micros"),
         }
     )
@@ -454,7 +508,26 @@ async def create_image_generation(
     session: AsyncSession = Depends(get_async_session),
     user: User = Depends(current_active_user),
 ):
-    """Create and execute an image generation request."""
+    """Create and execute an image generation request.
+
+    Premium configs are gated by the user's shared premium credit pool.
+    The flow is:
+
+    1. Permission check + load the search space (cheap, no provider call).
+    2. Resolve which config will run so we know its billing tier and the
+       worst-case reservation size *before* opening any DB rows.
+    3. Wrap the entire ImageGeneration row insert + provider call in
+       ``billable_call``. If quota is denied, ``billable_call`` raises
+       ``QuotaInsufficientError`` *before* we flush a row, which we
+       translate to HTTP 402 (no orphaned rows on the user's account,
+       no inserted error rows for "you ran out of credit").
+    4. On success, the actual ``response_cost`` flows through the
+       LiteLLM callback into the accumulator, and ``billable_call``
+       finalizes the debit at exit. Inner ``try/except`` still catches
+       provider errors and stores them on ``error_message`` (HTTP 200
+       with ``error_message`` set is preserved for failed-but-not-quota
+       scenarios — clients already know how to surface those).
+    """
     try:
         await check_permission(
             session,
@@ -471,33 +544,70 @@ async def create_image_generation(
         if not search_space:
             raise HTTPException(status_code=404, detail="Search space not found")

-        db_image_gen = ImageGeneration(
-            prompt=data.prompt,
-            model=data.model,
-            n=data.n,
-            quality=data.quality,
-            size=data.size,
-            style=data.style,
-            response_format=data.response_format,
-            image_generation_config_id=data.image_generation_config_id,
-            search_space_id=data.search_space_id,
-            created_by_id=user.id,
+        billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
+            session, data.image_generation_config_id, search_space
         )
-        session.add(db_image_gen)
-        await session.flush()
-
-        try:
-            await _execute_image_generation(session, db_image_gen, search_space)
-        except Exception as e:
-            logger.exception("Image generation call failed")
-            db_image_gen.error_message = str(e)
+        # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
+        # propagates to the outer ``except QuotaInsufficientError`` handler
+        # below as HTTP 402 — it is intentionally NOT swallowed into
+        # ``error_message`` because that would (1) imply a successful row
+        # exists when none does, and (2) return HTTP 200 to a client
+        # whose request was actively *denied* (issue K).
+        async with billable_call(
+            user_id=search_space.user_id,
+            search_space_id=data.search_space_id,
+            billing_tier=billing_tier,
+            base_model=base_model,
+            quota_reserve_micros_override=reserve_micros,
+            usage_type="image_generation",
+            call_details={"model": base_model, "prompt": data.prompt[:100]},
+        ):
+            db_image_gen = ImageGeneration(
+                prompt=data.prompt,
+                model=data.model,
+                n=data.n,
+                quality=data.quality,
+                size=data.size,
+                style=data.style,
+                response_format=data.response_format,
+                image_generation_config_id=data.image_generation_config_id,
+                search_space_id=data.search_space_id,
+                created_by_id=user.id,
+            )
+            session.add(db_image_gen)
+            await session.flush()

-        await session.commit()
-        await session.refresh(db_image_gen)
-        return db_image_gen
+            try:
+                await _execute_image_generation(session, db_image_gen, search_space)
+            except Exception as e:
+                logger.exception("Image generation call failed")
+                db_image_gen.error_message = str(e)
+
+            await session.commit()
+            await session.refresh(db_image_gen)
+            return db_image_gen

     except HTTPException:
         raise
+    except QuotaInsufficientError as exc:
+        # The user's premium credit pool is empty. No DB row is created
+        # because ``billable_call`` denies before yielding (issue K).
+        await session.rollback()
+        raise HTTPException(
+            status_code=402,
+            detail={
+                "error_code": "premium_quota_exhausted",
+                "usage_type": exc.usage_type,
+                "used_micros": exc.used_micros,
+                "limit_micros": exc.limit_micros,
+                "remaining_micros": exc.remaining_micros,
+                "message": (
+                    "Out of premium credits for image generation. "
+                    "Purchase additional credits or switch to a free model."
+                ),
+            },
+        ) from exc
     except SQLAlchemyError:
         await session.rollback()
         raise HTTPException(
@@ -1366,7 +1366,11 @@ async def append_message(
             # flush assigns the PK/defaults without a round-trip SELECT
             await session.flush()

-            # Persist token usage if provided (for assistant messages)
+            # Persist token usage if provided (for assistant messages).
+            # ``cost_micros`` is the provider USD cost reported by LiteLLM,
+            # forwarded by the FE through the appendMessage round-trip so
+            # the historical TokenUsage row matches the credit debit applied
+            # at finalize time.
             token_usage_data = raw_body.get("token_usage")
             if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
                 await record_token_usage(
@@ -1377,6 +1381,7 @@ async def append_message(
                     prompt_tokens=token_usage_data.get("prompt_tokens", 0),
                     completion_tokens=token_usage_data.get("completion_tokens", 0),
                     total_tokens=token_usage_data.get("total_tokens", 0),
+                    cost_micros=token_usage_data.get("cost_micros", 0),
                     model_breakdown=token_usage_data.get("usage"),
                     call_details=token_usage_data.get("call_details"),
                     thread_id=thread_id,
@@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
             "model_name": "auto",
             "is_global": True,
             "is_auto_mode": True,
+            "billing_tier": "free",
         }

     if config_id < 0:
@@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
             "api_version": cfg.get("api_version") or None,
             "litellm_params": cfg.get("litellm_params", {}),
             "is_global": True,
+            "billing_tier": cfg.get("billing_tier", "free"),
         }
     return None
@@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
             "model_name": "auto",
             "is_global": True,
             "is_auto_mode": True,
+            "billing_tier": "free",
         }

     if config_id < 0:
@@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
             "api_version": cfg.get("api_version") or None,
             "litellm_params": cfg.get("litellm_params", {}),
             "is_global": True,
+            "billing_tier": cfg.get("billing_tier", "free"),
         }
     return None
@@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
     metadata = _get_metadata(checkout_session)
     user_id = metadata.get("user_id")
     quantity = int(metadata.get("quantity", "0"))
-    tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
+    # Read the new metadata key first, fall back to the legacy one so
+    # in-flight checkout sessions created before the cost-credits
+    # release still fulfil correctly (the unit is numerically the
+    # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
+    credit_micros_per_unit = int(
+        metadata.get("credit_micros_per_unit")
+        or metadata.get("tokens_per_unit", "0")
+    )

-    if not user_id or quantity <= 0 or tokens_per_unit <= 0:
+    if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
         logger.error(
             "Skipping token fulfillment for session %s: incomplete metadata %s",
             checkout_session_id,
@@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
             getattr(checkout_session, "payment_intent", None)
         ),
         quantity=quantity,
-        tokens_granted=quantity * tokens_per_unit,
+        credit_micros_granted=quantity * credit_micros_per_unit,
         amount_total=getattr(checkout_session, "amount_total", None),
         currency=getattr(checkout_session, "currency", None),
         status=PremiumTokenPurchaseStatus.PENDING,
@@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
     purchase.stripe_payment_intent_id = _normalize_optional_string(
         getattr(checkout_session, "payment_intent", None)
     )
-    user.premium_tokens_limit = (
-        max(user.premium_tokens_used, user.premium_tokens_limit)
-        + purchase.tokens_granted
+    # Top up the user's credit balance by the granted micro-USD amount.
+    # ``max(used, limit)`` clamps the case where the legacy code wrote a
+    # used value above the limit (e.g. underbilling rounding) so adding
+    # ``credit_micros_granted`` always lifts the limit by the full pack
+    # size rather than disappearing into past overuse.
+    user.premium_credit_micros_limit = (
+        max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+        + purchase.credit_micros_granted
     )

     await db_session.commit()
@@ -532,12 +544,18 @@ async def create_token_checkout_session(
     user: User = Depends(current_active_user),
     db_session: AsyncSession = Depends(get_async_session),
 ):
-    """Create a Stripe Checkout Session for buying premium token packs."""
+    """Create a Stripe Checkout Session for buying premium credit packs.
+
+    Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
+    credit (default 1_000_000 = $1.00). The user's balance is debited
+    at the actual provider cost reported by LiteLLM at finalize time,
+    so $1 of credit always buys $1 worth of provider usage at cost.
+    """
     _ensure_token_buying_enabled()
     stripe_client = get_stripe_client()
     price_id = _get_required_token_price_id()
     success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
-    tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
+    credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT

     try:
         checkout_session = stripe_client.v1.checkout.sessions.create(
@@ -556,8 +574,8 @@ async def create_token_checkout_session(
                 "metadata": {
                     "user_id": str(user.id),
                     "quantity": str(body.quantity),
-                    "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
-                    "purchase_type": "premium_tokens",
+                    "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
+                    "purchase_type": "premium_credit",
                 },
             }
         )
@@ -583,7 +601,7 @@ async def create_token_checkout_session(
             getattr(checkout_session, "payment_intent", None)
         ),
         quantity=body.quantity,
-        tokens_granted=tokens_granted,
+        credit_micros_granted=credit_micros_granted,
         amount_total=getattr(checkout_session, "amount_total", None),
         currency=getattr(checkout_session, "currency", None),
         status=PremiumTokenPurchaseStatus.PENDING,
@@ -598,14 +616,19 @@ async def create_token_checkout_session(
 async def get_token_status(
     user: User = Depends(current_active_user),
 ):
-    """Return token-buying availability and current premium quota for frontend."""
-    used = user.premium_tokens_used
-    limit = user.premium_tokens_limit
+    """Return token-buying availability and current premium credit quota for frontend.
+
+    Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
+    when displaying. The route name is preserved for back-compat with
+    pinned client deployments.
+    """
+    used = user.premium_credit_micros_used
+    limit = user.premium_credit_micros_limit
     return TokenStripeStatusResponse(
         token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
-        premium_tokens_used=used,
-        premium_tokens_limit=limit,
-        premium_tokens_remaining=max(0, limit - used),
+        premium_credit_micros_used=used,
+        premium_credit_micros_limit=limit,
+        premium_credit_micros_remaining=max(0, limit - used),
     )
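The division the docstring attributes to the FE is trivial but worth pinning down; a minimal sketch (this helper is illustrative, not part of the codebase):

def micros_to_usd(micros: int) -> str:
    # 1_000_000 micro-USD == $1.00; two decimals for display.
    return f"${micros / 1_000_000:.2f}"

# e.g. a fresh $5.00 pool with $0.36 of usage finalized:
assert micros_to_usd(5_000_000 - 360_000) == "$4.64"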
@@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
             "litellm_params": {},
             "is_global": True,
             "is_auto_mode": True,
+            # Auto mode treated as free until per-deployment billing-tier
+            # surfacing lands; see ``get_vision_llm`` for parity.
+            "billing_tier": "free",
         }
     )
@@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
             "api_version": cfg.get("api_version") or None,
             "litellm_params": cfg.get("litellm_params", {}),
             "is_global": True,
+            "billing_tier": cfg.get("billing_tier", "free"),
+            "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
+            "input_cost_per_token": cfg.get("input_cost_per_token"),
+            "output_cost_per_token": cfg.get("output_cost_per_token"),
         }
     )
@@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
     Schema for reading global image generation configs from YAML.
     Global configs have negative IDs. API key is hidden.
     ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
+
+    The ``billing_tier`` field allows the frontend to show a Premium/Free
+    badge and (more importantly) tells the backend whether to debit the
+    user's premium credit pool when this config is used. ``"free"`` is
+    the default for backward compatibility — admins must explicitly opt
+    a global config into ``"premium"``.
     """

     id: int = Field(
@@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel):
     litellm_params: dict[str, Any] | None = None
     is_global: bool = True
     is_auto_mode: bool = False
+    billing_tier: str = Field(
+        default="free",
+        description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+    )
+    quota_reserve_micros: int | None = Field(
+        default=None,
+        description=(
+            "Optional override for the reservation amount (in micro-USD) used when "
+            "this image generation is premium. Falls back to "
+            "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
+        ),
+    )
@@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: int = 0
     total_tokens: int = 0
+    cost_micros: int = 0
     model_breakdown: dict | None = None
     model_config = ConfigDict(from_attributes=True)
@@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):


 class TokenPurchaseRead(BaseModel):
-    """Serialized premium token purchase record."""
+    """Serialized premium credit purchase record.
+
+    ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
+    schema name keeps ``Token`` for API back-compat with pinned clients.
+    """

     id: uuid.UUID
     stripe_checkout_session_id: str
     stripe_payment_intent_id: str | None = None
     quantity: int
-    tokens_granted: int
+    credit_micros_granted: int
     amount_total: int | None = None
     currency: str | None = None
     status: str
@@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):


 class TokenPurchaseHistoryResponse(BaseModel):
-    """Response containing the user's premium token purchases."""
+    """Response containing the user's premium credit purchases."""

     purchases: list[TokenPurchaseRead]


 class TokenStripeStatusResponse(BaseModel):
-    """Response describing token-buying availability and current quota."""
+    """Response describing premium-credit-buying availability and balance.
+
+    All ``premium_credit_micros_*`` fields are in micro-USD; the FE
+    divides by 1_000_000 to display USD.
+    """

     token_buying_enabled: bool
-    premium_tokens_used: int = 0
-    premium_tokens_limit: int = 0
-    premium_tokens_remaining: int = 0
+    premium_credit_micros_used: int = 0
+    premium_credit_micros_limit: int = 0
+    premium_credit_micros_remaining: int = 0
@@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):


 class GlobalVisionLLMConfigRead(BaseModel):
+    """Schema for reading global vision LLM configs from YAML.
+
+    The ``billing_tier`` field allows the frontend to show a Premium/Free
+    badge and (more importantly) tells the backend whether to debit the
+    user's premium credit pool when this config is used. ``"free"`` is
+    the default for backward compatibility — admins must explicitly opt
+    a global config into ``"premium"``.
+    """
+
     id: int = Field(...)
     name: str
     description: str | None = None
@@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel):
     litellm_params: dict[str, Any] | None = None
     is_global: bool = True
     is_auto_mode: bool = False
+    billing_tier: str = Field(
+        default="free",
+        description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
+    )
+    quota_reserve_tokens: int | None = Field(
+        default=None,
+        description=(
+            "Optional override for the per-call reservation in *tokens* — "
+            "converted to micro-USD via the model's input/output prices at "
+            "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
+        ),
+    )
+    input_cost_per_token: float | None = Field(
+        default=None,
+        description=(
+            "Optional input price in USD/token. Used by pricing_registration to "
+            "register custom Azure / OpenRouter aliases with LiteLLM at startup."
+        ),
+    )
+    output_cost_per_token: float | None = Field(
+        default=None,
+        description="Optional output price in USD/token. Pair with input_cost_per_token.",
+    )
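Putting the new fields together, a global vision-LLM YAML entry that opts into premium billing with inline pricing might look like this (a sketch — the id, model name and prices are placeholders, not values from this commit):

global_vision_llm_configs:
  - id: -201                          # global configs use negative IDs
    name: "Premium Vision (example)"
    provider: "OPENAI"                # placeholder
    model_name: "gpt-4o"              # placeholder
    billing_tier: "premium"           # debits the shared premium credit pool
    quota_reserve_tokens: 2000        # converted to micro-USD at reservation time
    input_cost_per_token: 0.0000025   # optional; registered with LiteLLM at startup
    output_cost_per_token: 0.00001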
surfsense_backend/app/services/billable_calls.py (new file, 430 lines)
@@ -0,0 +1,430 @@
"""
Per-call billable wrapper for image generation, vision LLM extraction, and
any other short-lived premium operation that must charge against the user's
shared premium credit pool.

The ``billable_call`` async context manager encapsulates the standard
"reserve → execute → finalize / release → record audit row" lifecycle in a
single primitive so callers (the image-generation REST route and the
vision-LLM wrapper used during indexing) don't have to re-implement it.

KEY DESIGN POINTS (issue A, B):

1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
   argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
   insert each run inside their own ``shielded_async_session()``. This
   guarantees that a quota commit/rollback can never accidentally flush or
   roll back rows the caller has staged in the request's main session
   (e.g. a freshly-created ``ImageGeneration`` row).

2. **ContextVar safety.** The accumulator is scoped via
   :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
   nested ``billable_call`` inside an outer chat turn cannot corrupt the
   chat turn's accumulator.

3. **Free configs are still audited.** Free calls bypass the reserve /
   finalize dance entirely but still record a ``TokenUsage`` audit row with
   the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
   pipeline complete for analytics even when nothing is debited.

4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
   responsible for translating that into HTTP 402. We *do not* catch the
   denial inside ``billable_call`` — letting it propagate also prevents
   the image-generation route from creating an ``ImageGeneration`` row
   for a request that never actually ran.
"""

from __future__ import annotations

import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID, uuid4

from sqlalchemy.ext.asyncio import AsyncSession

from app.config import config
from app.db import shielded_async_session
from app.services.token_quota_service import (
    TokenQuotaService,
    estimate_call_reserve_micros,
)
from app.services.token_tracking_service import (
    TurnTokenAccumulator,
    record_token_usage,
    scoped_turn,
)

logger = logging.getLogger(__name__)


class QuotaInsufficientError(Exception):
    """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
    call because the user has exhausted their premium credit pool.

    The route handler should catch this and return HTTP 402 Payment
    Required (or the equivalent for the surface area). Outside of the HTTP
    layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
    callers may catch this and degrade gracefully — e.g. fall back to OCR
    when vision is unavailable.
    """

    def __init__(
        self,
        *,
        usage_type: str,
        used_micros: int,
        limit_micros: int,
        remaining_micros: int,
    ) -> None:
        self.usage_type = usage_type
        self.used_micros = used_micros
        self.limit_micros = limit_micros
        self.remaining_micros = remaining_micros
        super().__init__(
            f"Premium credit exhausted for {usage_type}: "
            f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
        )


@asynccontextmanager
async def billable_call(
    *,
    user_id: UUID,
    search_space_id: int,
    billing_tier: str,
    base_model: str,
    quota_reserve_tokens: int | None = None,
    quota_reserve_micros_override: int | None = None,
    usage_type: str,
    thread_id: int | None = None,
    message_id: int | None = None,
    call_details: dict[str, Any] | None = None,
) -> AsyncIterator[TurnTokenAccumulator]:
    """Wrap a single billable LLM/image call.

    Args:
        user_id: Owner of the credit pool to debit. For vision-LLM during
            indexing this is the *search-space owner* (issue M), not the
            triggering user.
        search_space_id: Required — recorded on the ``TokenUsage`` audit row.
        billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
            the reserve/finalize dance but still records an audit row with
            the captured cost.
        base_model: Used by :func:`estimate_call_reserve_micros` to compute
            a worst-case reservation from LiteLLM's pricing table.
        quota_reserve_tokens: Optional per-config override for the chat-style
            reserve estimator (vision LLM uses this).
        quota_reserve_micros_override: Optional flat micro-USD reservation
            (image generation uses this — its cost shape is per-image, not
            per-token).
        usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
            Recorded on the ``TokenUsage`` row.
        thread_id, message_id: Optional FK columns on ``TokenUsage``.
        call_details: Optional per-call metadata (model name, parameters)
            forwarded to ``record_token_usage``.

    Yields:
        The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
        the underlying LLM/image API while inside the ``async with``; the
        ``TokenTrackingCallback`` populates the accumulator automatically.

    Raises:
        QuotaInsufficientError: when premium and ``premium_reserve`` denies.
    """
    is_premium = billing_tier == "premium"

    async with scoped_turn() as acc:
        # ---------- Free path: just audit -------------------------------
        if not is_premium:
            try:
                yield acc
            finally:
                # Always audit, even on exception, so we capture cost when
                # provider returns successfully but the caller raises later.
                try:
                    async with shielded_async_session() as audit_session:
                        await record_token_usage(
                            audit_session,
                            usage_type=usage_type,
                            search_space_id=search_space_id,
                            user_id=user_id,
                            prompt_tokens=acc.total_prompt_tokens,
                            completion_tokens=acc.total_completion_tokens,
                            total_tokens=acc.grand_total,
                            cost_micros=acc.total_cost_micros,
                            model_breakdown=acc.per_message_summary(),
                            call_details=call_details,
                            thread_id=thread_id,
                            message_id=message_id,
                        )
                        await audit_session.commit()
                except Exception:
                    logger.exception(
                        "[billable_call] free-path audit insert failed for "
                        "usage_type=%s user_id=%s",
                        usage_type,
                        user_id,
                    )
            return

        # ---------- Premium path: reserve → execute → finalize ----------
        if quota_reserve_micros_override is not None:
            reserve_micros = max(1, int(quota_reserve_micros_override))
        else:
            reserve_micros = estimate_call_reserve_micros(
                base_model=base_model or "",
                quota_reserve_tokens=quota_reserve_tokens,
            )

        request_id = str(uuid4())

        async with shielded_async_session() as quota_session:
            reserve_result = await TokenQuotaService.premium_reserve(
                db_session=quota_session,
                user_id=user_id,
                request_id=request_id,
                reserve_micros=reserve_micros,
            )

        if not reserve_result.allowed:
            logger.info(
                "[billable_call] reserve DENIED user=%s usage_type=%s "
                "reserve=%d used=%d limit=%d remaining=%d",
                user_id,
                usage_type,
                reserve_micros,
                reserve_result.used,
                reserve_result.limit,
                reserve_result.remaining,
            )
            raise QuotaInsufficientError(
                usage_type=usage_type,
                used_micros=reserve_result.used,
                limit_micros=reserve_result.limit,
                remaining_micros=reserve_result.remaining,
            )

        logger.info(
            "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
            "(remaining=%d)",
            user_id,
            usage_type,
            reserve_micros,
            reserve_result.remaining,
        )

        try:
            yield acc
        except BaseException:
            # Release on any failure (including QuotaInsufficientError raised
            # from a downstream call, asyncio cancellation, etc.). We use
            # BaseException so cancellation also releases.
            try:
                async with shielded_async_session() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] premium_release failed for user=%s "
                    "reserve_micros=%d (reservation will be GC'd by quota "
                    "reconciliation if/when implemented)",
                    user_id,
                    reserve_micros,
                )
            raise

        # ---------- Success: finalize + audit ----------------------------
        actual_micros = acc.total_cost_micros
        try:
            async with shielded_async_session() as quota_session:
                final_result = await TokenQuotaService.premium_finalize(
                    db_session=quota_session,
                    user_id=user_id,
                    request_id=request_id,
                    actual_micros=actual_micros,
                    reserved_micros=reserve_micros,
                )
            logger.info(
                "[billable_call] finalize user=%s usage_type=%s actual=%d "
                "reserved=%d → used=%d/%d (remaining=%d)",
                user_id,
                usage_type,
                actual_micros,
                reserve_micros,
                final_result.used,
                final_result.limit,
                final_result.remaining,
            )
        except Exception:
            # Last-ditch: if finalize itself fails, we must at least release
            # so the reservation doesn't leak.
            logger.exception(
                "[billable_call] premium_finalize failed for user=%s; "
                "attempting release",
                user_id,
            )
            try:
                async with shielded_async_session() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] release after finalize failure ALSO failed "
                    "for user=%s",
                    user_id,
                )

        try:
            async with shielded_async_session() as audit_session:
                await record_token_usage(
                    audit_session,
                    usage_type=usage_type,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    prompt_tokens=acc.total_prompt_tokens,
                    completion_tokens=acc.total_completion_tokens,
                    total_tokens=acc.grand_total,
                    cost_micros=actual_micros,
                    model_breakdown=acc.per_message_summary(),
                    call_details=call_details,
                    thread_id=thread_id,
                    message_id=message_id,
                )
                await audit_session.commit()
        except Exception:
            logger.exception(
                "[billable_call] premium-path audit insert failed for "
                "usage_type=%s user_id=%s (debit was applied)",
                usage_type,
                user_id,
            )


async def _resolve_agent_billing_for_search_space(
    session: AsyncSession,
    search_space_id: int,
    *,
    thread_id: int | None = None,
) -> tuple[UUID, str, str]:
    """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
    agent LLM.

    Used by Celery tasks (podcast generation, video presentation) to bill the
    search-space owner's premium credit pool when the agent LLM is premium.

    Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:

    - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
    - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
      * ``thread_id`` is set: delegate to
        ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
        recurse into the resolved id. Reuses chat's existing pin if present
        so the same model bills for chat + downstream podcast/video. If the
        user is not premium-eligible, the pin service auto-restricts to free
        deployments — denial only happens later in
        ``billable_call.premium_reserve`` if the pin really is premium and
        credit ran out mid-flow.
      * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
        for any future direct-API path; today both Celery tasks always pass
        ``thread_id``.
    - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
      (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
      ``base_model = litellm_params.get("base_model") or model_name`` —
      NOT provider-prefixed, matching chat's cost-map lookup convention.
    - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
      ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
      ``base_model`` from ``litellm_params`` or ``model_name``.

    Note on imports: ``llm_service``, ``auto_model_pin_service``, and
    ``llm_router_service`` are imported lazily inside the function body to
    avoid hoisting litellm side-effects (``litellm.callbacks =
    [token_tracker]``, ``litellm.drop_params``, etc.) into
    ``billable_calls.py``'s module load path.
    """
    from sqlalchemy import select

    from app.db import NewLLMConfig, SearchSpace

    result = await session.execute(
        select(SearchSpace).where(SearchSpace.id == search_space_id)
    )
    search_space = result.scalars().first()
    if search_space is None:
        raise ValueError(f"Search space {search_space_id} not found")

    agent_llm_id = search_space.agent_llm_id
    if agent_llm_id is None:
        raise ValueError(
            f"Search space {search_space_id} has no agent_llm_id configured"
        )

    owner_user_id: UUID = search_space.user_id

    from app.services.auto_model_pin_service import (
        AUTO_FASTEST_ID,
        resolve_or_get_pinned_llm_config_id,
    )

    if agent_llm_id == AUTO_FASTEST_ID:
        if thread_id is None:
            return owner_user_id, "free", "auto"
        try:
            resolution = await resolve_or_get_pinned_llm_config_id(
                session,
                thread_id=thread_id,
                search_space_id=search_space_id,
                user_id=str(owner_user_id),
                selected_llm_config_id=AUTO_FASTEST_ID,
            )
        except ValueError:
            logger.warning(
                "[agent_billing] Auto-mode pin resolution failed for "
                "search_space=%s thread=%s; falling back to free",
                search_space_id,
                thread_id,
                exc_info=True,
            )
            return owner_user_id, "free", "auto"
        agent_llm_id = resolution.resolved_llm_config_id

    if agent_llm_id < 0:
        from app.services.llm_service import get_global_llm_config

        cfg = get_global_llm_config(agent_llm_id) or {}
        billing_tier = str(cfg.get("billing_tier", "free")).lower()
        litellm_params = cfg.get("litellm_params") or {}
        base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
        return owner_user_id, billing_tier, base_model

    nlc_result = await session.execute(
        select(NewLLMConfig).where(
            NewLLMConfig.id == agent_llm_id,
            NewLLMConfig.search_space_id == search_space_id,
        )
    )
    nlc = nlc_result.scalars().first()
    base_model = ""
    if nlc is not None:
        litellm_params = nlc.litellm_params or {}
        base_model = litellm_params.get("base_model") or nlc.model_name or ""
    return owner_user_id, "free", base_model


__all__ = [
    "QuotaInsufficientError",
    "_resolve_agent_billing_for_search_space",
    "billable_call",
]


# Re-export the config knob so callers don't have to import config just for
# the default image reserve.
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
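To make the lifecycle concrete, a caller wraps its provider call like this (a condensed sketch following the image-generation route above; ``run_provider_call`` is a stand-in name and the model string is a placeholder):

from app.services.billable_calls import QuotaInsufficientError, billable_call

async def generate_billable_image(owner_id, space_id):
    try:
        # premium_reserve happens on __aenter__; QuotaInsufficientError fires
        # *before* the body runs if the pool can't cover the hold.
        async with billable_call(
            user_id=owner_id,
            search_space_id=space_id,
            billing_tier="premium",
            base_model="openai/gpt-image-1",       # placeholder
            quota_reserve_micros_override=50_000,  # flat $0.05 hold
            usage_type="image_generation",
        ) as acc:
            result = await run_provider_call()  # stand-in for the real call
            # TokenTrackingCallback fills `acc`; on exit the debit is
            # acc.total_cost_micros (the actual cost), not the reservation.
            return result
    except QuotaInsufficientError:
        ...  # translate to HTTP 402, or degrade to a free config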
@@ -134,42 +134,16 @@ PROVIDER_MAP = {
 }


-# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
-# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
-# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
-# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
-# request to an Azure endpoint, which then 404s with ``Resource not found``.
-# Only providers with a well-known, stable public base URL are listed here —
-# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
-# huggingface, databricks, cloudflare, replicate) are intentionally omitted
-# so their existing config-driven behaviour is preserved.
-PROVIDER_DEFAULT_API_BASE = {
-    "openrouter": "https://openrouter.ai/api/v1",
-    "groq": "https://api.groq.com/openai/v1",
-    "mistral": "https://api.mistral.ai/v1",
-    "perplexity": "https://api.perplexity.ai",
-    "xai": "https://api.x.ai/v1",
-    "cerebras": "https://api.cerebras.ai/v1",
-    "deepinfra": "https://api.deepinfra.com/v1/openai",
-    "fireworks_ai": "https://api.fireworks.ai/inference/v1",
-    "together_ai": "https://api.together.xyz/v1",
-    "anyscale": "https://api.endpoints.anyscale.com/v1",
-    "cometapi": "https://api.cometapi.com/v1",
-    "sambanova": "https://api.sambanova.ai/v1",
-}
-
-
-# Canonical provider → base URL when a config uses a generic ``openai``-style
-# prefix but the ``provider`` field tells us which API it really is
-# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
-# each has its own base URL).
-PROVIDER_KEY_DEFAULT_API_BASE = {
-    "DEEPSEEK": "https://api.deepseek.com/v1",
-    "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
-    "MOONSHOT": "https://api.moonshot.ai/v1",
-    "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
-    "MINIMAX": "https://api.minimax.io/v1",
-}
+# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
+# hoisted to ``app.services.provider_api_base`` so vision and image-gen
+# call sites can share the exact same defense (OpenRouter / Groq / etc.
+# 404-ing against an inherited Azure endpoint). Re-exported here for
+# backward compatibility with any external import.
+from app.services.provider_api_base import (  # noqa: E402
+    PROVIDER_DEFAULT_API_BASE,
+    PROVIDER_KEY_DEFAULT_API_BASE,
+    resolve_api_base,
+)


 class LLMRouterService:
@@ -466,14 +440,14 @@ class LLMRouterService:
        # Resolve ``api_base``. Config value wins; otherwise apply a
        # provider-aware default so the deployment does not silently
        # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
        # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
        # requests to the wrong endpoint. See ``provider_api_base``
        # docstring for the motivating bug (OpenRouter models 404-ing
        # against an Azure endpoint).
        api_base = config.get("api_base")
        if not api_base:
            api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
        if not api_base:
            api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
        api_base = resolve_api_base(
            provider=provider,
            provider_prefix=provider_prefix,
            config_api_base=config.get("api_base"),
        )
        if api_base:
            litellm_params["api_base"] = api_base

@@ -496,8 +496,14 @@ async def get_vision_llm(
    - Auto mode (ID 0): VisionLLMRouterService
    - Global (negative ID): YAML configs
    - DB (positive ID): VisionLLMConfig table

    Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
    so each ``ainvoke`` debits the search-space owner's premium credit
    pool. User-owned BYOK configs and free global configs are returned
    unwrapped — they don't consume premium credit (issue M).
    """
    from app.db import VisionLLMConfig
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
    from app.services.vision_llm_router_service import (
        VISION_PROVIDER_MAP,
        VisionLLMRouterService,

@@ -519,6 +525,8 @@ async def get_vision_llm(
        logger.error(f"No vision LLM configured for search space {search_space_id}")
        return None

    owner_user_id = search_space.user_id

    if is_vision_auto_mode(config_id):
        if not VisionLLMRouterService.is_initialized():
            logger.error(

@@ -526,6 +534,13 @@ async def get_vision_llm(
            )
            return None
        try:
            # Auto mode is currently treated as free at the wrapper
            # level — the underlying router can dispatch to either
            # premium or free YAML configs but routing decisions are
            # opaque. If/when we want to bill Auto-routed vision
            # calls we'd need to thread the resolved deployment's
            # billing_tier back from the router. For now we keep
            # parity with chat Auto, which also doesn't pre-classify.
            return ChatLiteLLMRouter(
                router=VisionLLMRouterService.get_router(),
                streaming=True,

@@ -562,8 +577,21 @@ async def get_vision_llm(

        from app.agents.new_chat.llm_config import SanitizedChatLiteLLM

        return SanitizedChatLiteLLM(**litellm_kwargs)
        inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)

        billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
        if billing_tier == "premium":
            return QuotaCheckedVisionLLM(
                inner_llm,
                user_id=owner_user_id,
                search_space_id=search_space_id,
                billing_tier=billing_tier,
                base_model=model_string,
                quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
            )
        return inner_llm

    # User-owned (positive ID) BYOK configs — always free.
    result = await session.execute(
        select(VisionLLMConfig).where(
            VisionLLMConfig.id == config_id,

@@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool:
    return output_mods == ["text"]


def _is_image_output_model(model: dict) -> bool:
    """Return True if the model can produce image output.

    OpenRouter's ``architecture.output_modalities`` is a list (e.g.
    ``["image"]`` for pure image generators, ``["text", "image"]`` for
    multi-modal generators that also emit captions). We accept any model
    that can output images; the call site decides whether to use the
    image-generation API or chat completion.
    """
    output_mods = model.get("architecture", {}).get("output_modalities", []) or []
    return "image" in output_mods


def _is_vision_input_model(model: dict) -> bool:
    """Return True if the model can ingest an image AND emit text.

    OpenRouter's ``architecture.input_modalities`` lists what the model
    accepts; ``output_modalities`` lists what it produces. A vision LLM
    is a model that takes images in and produces text out — i.e. it can
    answer questions about a screenshot or extract content from an
    image. Pure image-to-image models (e.g. style transfer) and
    text-only models are excluded.
    """
    arch = model.get("architecture", {}) or {}
    input_mods = arch.get("input_modalities", []) or []
    output_mods = arch.get("output_modalities", []) or []
    return "image" in input_mods and "text" in output_mods


def _supports_tool_calling(model: dict) -> bool:
    """Return True if the model supports function/tool calling."""
    supported = model.get("supported_parameters") or []

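A quick worked example of how the two predicates classify a catalogue entry
(dict shape per OpenRouter's ``/api/v1/models``; the model id and modality
values here are illustrative):

    sample = {
        "id": "google/gemini-2.0-flash",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
    }
    _is_image_output_model(sample)  # False: no "image" among the outputs
    _is_vision_input_model(sample)  # True: image in, text out
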
@@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None:
    return None


def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
    """Return a ``{model_id: {"prompt": str, "completion": str}}`` map.

    Pricing values are kept as the raw OpenRouter strings (e.g.
    ``"0.000003"``); ``pricing_registration`` converts them to floats
    when registering with LiteLLM. Models with no pricing at all are
    simply omitted (malformed values survive here and resolve to 0.0 at
    registration) — an operator-side risk if any of those are premium.
    """
    pricing: dict[str, dict[str, str]] = {}
    for model in raw_models:
        model_id = str(model.get("id") or "").strip()
        if not model_id:
            continue
        p = model.get("pricing") or {}
        prompt = p.get("prompt")
        completion = p.get("completion")
        if prompt is None and completion is None:
            continue
        pricing[model_id] = {
            "prompt": str(prompt) if prompt is not None else "",
            "completion": str(completion) if completion is not None else "",
        }
    return pricing


def _generate_configs(
    raw_models: list[dict],
    settings: dict[str, Any],

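Input/output sketch for the extractor (prices are illustrative OpenRouter
strings):

    raw = [{
        "id": "openai/gpt-4o-mini",
        "pricing": {"prompt": "0.00000015", "completion": "0.0000006"},
    }]
    _extract_raw_pricing(raw)
    # {"openai/gpt-4o-mini": {"prompt": "0.00000015", "completion": "0.0000006"}}
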
@@ -282,6 +337,162 @@ def _generate_configs(
    return configs


# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000


def _generate_image_gen_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Convert OpenRouter image-generation models into global image-gen
    config dicts (matches the YAML shape consumed by ``image_generation_routes``).

    Filter:
    - architecture.output_modalities contains "image"
    - compatible provider (excluded slugs blocked)
    - allowed model id (excluded list blocked)

    Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
    ``_has_sufficient_context``) because tool calls and context windows are
    irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
    derived per model the same way as chat (``_openrouter_tier``).

    Cost is intentionally *not* registered with LiteLLM at startup
    (``pricing_registration`` skips image gen): OpenRouter image-gen
    models are not in LiteLLM's native cost map and OpenRouter populates
    ``response_cost`` directly from the response header. A defensive
    branch in ``_extract_cost_usd`` handles the rare case where
    ``usage.cost`` is missing — see ``token_tracking_service``.
    """
    id_offset: int = int(
        settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
    )
    api_key: str = settings.get("api_key", "")
    rpm: int = settings.get("rpm", 200)
    free_rpm: int = settings.get("free_rpm", 20)
    litellm_params: dict = settings.get("litellm_params") or {}

    image_models = [
        m
        for m in raw_models
        if _is_image_output_model(m)
        and _is_compatible_provider(m)
        and _is_allowed_model(m)
        and "/" in m.get("id", "")
    ]

    configs: list[dict] = []
    taken: set[int] = set()
    for model in image_models:
        model_id: str = model["id"]
        name: str = model.get("name", model_id)
        tier = _openrouter_tier(model)

        cfg: dict[str, Any] = {
            "id": _stable_config_id(model_id, id_offset, taken),
            "name": name,
            "description": f"{name} via OpenRouter (image generation)",
            "provider": "OPENROUTER",
            "model_name": model_id,
            "api_key": api_key,
            "api_base": "",
            "api_version": None,
            "rpm": free_rpm if tier == "free" else rpm,
            "litellm_params": dict(litellm_params),
            "billing_tier": tier,
            _OPENROUTER_DYNAMIC_MARKER: True,
        }
        configs.append(cfg)

    return configs

def _generate_vision_llm_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Convert OpenRouter vision-capable LLMs into global vision-LLM config
    dicts (matches the YAML shape consumed by ``vision_llm_routes``).

    Filter:
    - architecture.input_modalities contains "image"
    - architecture.output_modalities contains "text"
    - compatible provider (excluded slugs blocked)
    - allowed model id (excluded list blocked)

    Vision-LLM is invoked from the indexer (image extraction during
    document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
    the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
    filters do not apply: a small-context vision model that doesn't
    advertise tool-calling is still perfectly viable for "describe this
    image" prompts.
    """
    id_offset: int = int(
        settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
    )
    api_key: str = settings.get("api_key", "")
    rpm: int = settings.get("rpm", 200)
    tpm: int = settings.get("tpm", 1_000_000)
    free_rpm: int = settings.get("free_rpm", 20)
    free_tpm: int = settings.get("free_tpm", 100_000)
    quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
    litellm_params: dict = settings.get("litellm_params") or {}

    vision_models = [
        m
        for m in raw_models
        if _is_vision_input_model(m)
        and _is_compatible_provider(m)
        and _is_allowed_model(m)
        and "/" in m.get("id", "")
    ]

    configs: list[dict] = []
    taken: set[int] = set()
    for model in vision_models:
        model_id: str = model["id"]
        name: str = model.get("name", model_id)
        tier = _openrouter_tier(model)
        pricing = model.get("pricing") or {}

        # Capture per-token prices so ``pricing_registration`` can
        # register them with LiteLLM at startup (and so the cost
        # estimator in ``estimate_call_reserve_micros`` can resolve
        # them at reserve time).
        try:
            input_cost = float(pricing.get("prompt", 0) or 0)
        except (TypeError, ValueError):
            input_cost = 0.0
        try:
            output_cost = float(pricing.get("completion", 0) or 0)
        except (TypeError, ValueError):
            output_cost = 0.0

        cfg: dict[str, Any] = {
            "id": _stable_config_id(model_id, id_offset, taken),
            "name": name,
            "description": f"{name} via OpenRouter (vision)",
            "provider": "OPENROUTER",
            "model_name": model_id,
            "api_key": api_key,
            "api_base": "",
            "api_version": None,
            "rpm": free_rpm if tier == "free" else rpm,
            "tpm": free_tpm if tier == "free" else tpm,
            "litellm_params": dict(litellm_params),
            "billing_tier": tier,
            "quota_reserve_tokens": quota_reserve_tokens,
            "input_cost_per_token": input_cost or None,
            "output_cost_per_token": output_cost or None,
            _OPENROUTER_DYNAMIC_MARKER: True,
        }
        configs.append(cfg)

    return configs

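Shape of one emitted vision config, for reference. Field values are
illustrative; the id is a deterministic slot in the -30000 band chosen by
``_stable_config_id``, and ``api_key`` / ``api_version`` / ``litellm_params``
plus the dynamic marker are also present as in the loop above:

    {
        "id": -30017,
        "name": "Qwen2.5 VL 72B Instruct",
        "description": "Qwen2.5 VL 72B Instruct via OpenRouter (vision)",
        "provider": "OPENROUTER",
        "model_name": "qwen/qwen2.5-vl-72b-instruct",
        "rpm": 20,                     # free-tier limits
        "tpm": 100_000,
        "billing_tier": "free",
        "quota_reserve_tokens": 4000,
        "input_cost_per_token": None,  # free model: zero pricing becomes None
        "output_cost_per_token": None,
    }
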
class OpenRouterIntegrationService:
    """Singleton that manages the dynamic OpenRouter model catalogue."""

@@ -300,6 +511,19 @@ class OpenRouterIntegrationService:
        # Shape: {model_name: {"gated": bool, "score": float | None}}
        self._health_cache: dict[str, dict[str, Any]] = {}
        self._enrich_task: asyncio.Task | None = None
        # Raw OpenRouter pricing per model_id, captured at the same time
        # we generate configs. Consumed by ``pricing_registration`` to
        # teach LiteLLM the per-token cost of every dynamic deployment so
        # the success-callback can populate ``response_cost`` correctly.
        self._raw_pricing: dict[str, dict[str, str]] = {}
        # Cached raw catalogue from the most recent fetch. Image / vision
        # emitters reuse this to avoid a second network call per surface.
        self._raw_models: list[dict] = []
        # Image / vision config caches (only populated when the matching
        # opt-in flag is true on initialize). Refreshed in lockstep with
        # the chat catalogue.
        self._image_configs: list[dict] = []
        self._vision_configs: list[dict] = []

    @classmethod
    def get_instance(cls) -> "OpenRouterIntegrationService":

@@ -329,8 +553,32 @@ class OpenRouterIntegrationService:
            self._initialized = True
            return []

        self._raw_models = raw_models
        self._configs = _generate_configs(raw_models, settings)
        self._configs_by_id = {c["id"]: c for c in self._configs}
        self._raw_pricing = _extract_raw_pricing(raw_models)

        # Populate image / vision caches when their opt-in flag is set.
        # Empty otherwise so the accessors return [] without re-running
        # filters every refresh.
        if settings.get("image_generation_enabled"):
            self._image_configs = _generate_image_gen_configs(raw_models, settings)
            logger.info(
                "OpenRouter integration: image-gen emission ON (%d models)",
                len(self._image_configs),
            )
        else:
            self._image_configs = []

        if settings.get("vision_enabled"):
            self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
            logger.info(
                "OpenRouter integration: vision LLM emission ON (%d models)",
                len(self._vision_configs),
            )
        else:
            self._vision_configs = []

        self._initialized = True

        tier_counts = self._tier_counts(self._configs)

@@ -369,6 +617,8 @@ class OpenRouterIntegrationService:

        new_configs = _generate_configs(raw_models, self._settings)
        new_by_id = {c["id"]: c for c in new_configs}
        self._raw_pricing = _extract_raw_pricing(raw_models)
        self._raw_models = raw_models

        from app.config import config as app_config

@@ -382,6 +632,29 @@ class OpenRouterIntegrationService:
        self._configs = new_configs
        self._configs_by_id = new_by_id

        # Image / vision lists are atomic-swapped the same way: filter out
        # the previous dynamic entries from the live config list and append
        # the freshly generated ones. No-ops when the opt-in flag is off.
        if self._settings.get("image_generation_enabled"):
            new_image = _generate_image_gen_configs(raw_models, self._settings)
            static_image = [
                c
                for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
                if not c.get(_OPENROUTER_DYNAMIC_MARKER)
            ]
            app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
            self._image_configs = new_image

        if self._settings.get("vision_enabled"):
            new_vision = _generate_vision_llm_configs(raw_models, self._settings)
            static_vision = [
                c
                for c in app_config.GLOBAL_VISION_LLM_CONFIGS
                if not c.get(_OPENROUTER_DYNAMIC_MARKER)
            ]
            app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
            self._vision_configs = new_vision

        # Catalogue churn invalidates per-config "recently healthy" credit
        # earned by the previous turn's preflight. Drop the whole table so
        # the next turn re-probes against the freshly loaded configs.

@@ -407,6 +680,21 @@ class OpenRouterIntegrationService:
        # so a hand-picked dead OR model is gated like a dynamic one.
        await self._enrich_health_safely(static_configs + new_configs, log_summary=True)

        # Re-register LiteLLM pricing for the freshly fetched catalogue
        # so newly added OR models bill correctly on their first call.
        # Runs before the router rebuild because the router may issue
        # cost-table lookups during deployment registration.
        try:
            from app.services.pricing_registration import (
                register_pricing_from_global_configs,
            )

            register_pricing_from_global_configs()
        except Exception as exc:
            logger.warning(
                "OpenRouter refresh: pricing re-registration skipped (%s)", exc
            )

        # Rebuild the LiteLLM router so freshly fetched configs flow through
        # (dynamic OR premium entries now opt into the pool, free ones stay
        # out; a refresh also needs to pick up any static-config edits and

@@ -635,3 +923,34 @@ class OpenRouterIntegrationService:

    def get_config_by_id(self, config_id: int) -> dict | None:
        return self._configs_by_id.get(config_id)

    def get_image_generation_configs(self) -> list[dict]:
        """Return the dynamic OpenRouter image-generation configs (empty
        list when the ``image_generation_enabled`` flag is off).

        Each entry already has ``billing_tier`` derived per-model from
        OpenRouter's signals and is shaped to drop directly into
        ``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
        """
        return list(self._image_configs)

    def get_vision_llm_configs(self) -> list[dict]:
        """Return the dynamic OpenRouter vision-LLM configs (empty list
        when the ``vision_enabled`` flag is off).

        Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
        so ``pricing_registration`` can teach LiteLLM the cost of these
        models the same way it does for chat — which keeps the billable
        wrapper able to debit accurate micro-USD on a vision call.
        """
        return list(self._vision_configs)

    def get_raw_pricing(self) -> dict[str, dict[str, str]]:
        """Return the cached raw OpenRouter pricing map.

        Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
        values are the strings OpenRouter publishes (USD per token),
        never converted to floats here so the caller can decide how to
        handle malformed or unset entries.
        """
        return dict(self._raw_pricing)

274  surfsense_backend/app/services/pricing_registration.py  Normal file

@@ -0,0 +1,274 @@
"""
|
||||
Pricing registration with LiteLLM.
|
||||
|
||||
Many models reach our LiteLLM callback without LiteLLM knowing their
|
||||
per-token cost — namely:
|
||||
|
||||
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
|
||||
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
|
||||
pricing table).
|
||||
* Static YAML deployments whose ``base_model`` name is operator-defined
|
||||
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
|
||||
not in LiteLLM's table either.
|
||||
|
||||
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
|
||||
and the user gets billed nothing — a fail-safe but wrong answer for a
|
||||
cost-based credit system. This module runs once at startup, after the
|
||||
OpenRouter integration has fetched its catalogue, and registers each
|
||||
known model's pricing with ``litellm.register_model()`` under multiple
|
||||
plausible alias keys (LiteLLM's cost lookup may use any of them
|
||||
depending on whether the call went through the Router, ChatLiteLLM,
|
||||
or a direct ``acompletion``).
|
||||
|
||||
Operators who run a custom Azure deployment whose ``base_model`` name
|
||||
isn't in LiteLLM's table can declare per-token pricing inline in
|
||||
``global_llm_config.yaml`` via ``input_cost_per_token`` and
|
||||
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
|
||||
that declaration the model's calls debit 0 — never overbilled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_float(value: Any) -> float:
|
||||
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
|
||||
if value is None:
|
||||
return 0.0
|
||||
try:
|
||||
f = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
return f if f > 0 else 0.0
|
||||
|
||||
|
||||
def _alias_set_for_openrouter(model_id: str) -> list[str]:
|
||||
"""Return the alias keys to register an OpenRouter model under.
|
||||
|
||||
LiteLLM's cost-callback lookup key varies by call path:
|
||||
- Router with ``model="openrouter/X"`` → kwargs["model"] is
|
||||
typically ``openrouter/X``.
|
||||
- LiteLLM's own provider routing may strip the prefix and pass the
|
||||
bare ``X`` to the cost-table lookup.
|
||||
Registering under both keeps the lookup hermetic regardless of
|
||||
which path the call took.
|
||||
"""
|
||||
aliases = [f"openrouter/{model_id}", model_id]
|
||||
return list(dict.fromkeys(a for a in aliases if a))
|
||||
|
||||
|
||||
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
|
||||
"""Return the alias keys to register a static YAML deployment under.
|
||||
|
||||
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
|
||||
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
|
||||
bare ``model_name`` because callbacks sometimes see whichever was
|
||||
configured first.
|
||||
"""
|
||||
provider_lower = (provider or "").lower()
|
||||
aliases: list[str] = []
|
||||
if base_model:
|
||||
aliases.append(base_model)
|
||||
if provider_lower and base_model:
|
||||
aliases.append(f"{provider_lower}/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(model_name)
|
||||
if provider_lower and model_name and model_name != base_model:
|
||||
aliases.append(f"{provider_lower}/{model_name}")
|
||||
# Azure deployments often surface as "azure/<name>"; normalise the
|
||||
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
|
||||
if provider_lower == "azure_openai":
|
||||
if base_model:
|
||||
aliases.append(f"azure/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(f"azure/{model_name}")
|
||||
return list(dict.fromkeys(a for a in aliases if a))
|
||||
|
||||
|
||||
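Alias fan-out for a custom Azure deployment, for reference (the deployment
and base-model names are illustrative):

    _alias_set_for_yaml("AZURE_OPENAI", "prod-gpt4o", "gpt-4o-2024-08-06")
    # ["gpt-4o-2024-08-06", "azure_openai/gpt-4o-2024-08-06", "prod-gpt4o",
    #  "azure_openai/prod-gpt4o", "azure/gpt-4o-2024-08-06", "azure/prod-gpt4o"]
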
def _register(
    aliases: list[str],
    *,
    input_cost: float,
    output_cost: float,
    provider: str,
    mode: str = "chat",
) -> int:
    """Register a single pricing entry under every alias in ``aliases``.

    Returns the count of aliases successfully registered.
    """
    payload: dict[str, dict[str, Any]] = {}
    for alias in aliases:
        payload[alias] = {
            "input_cost_per_token": input_cost,
            "output_cost_per_token": output_cost,
            "litellm_provider": provider,
            "mode": mode,
        }
    if not payload:
        return 0
    try:
        litellm.register_model(payload)
    except Exception as exc:
        logger.warning(
            "[PricingRegistration] register_model failed for aliases=%s: %s",
            aliases,
            exc,
        )
        return 0
    return len(payload)


def _register_chat_shape_configs(
    configs: list[dict],
    *,
    or_pricing: dict[str, dict[str, str]],
    label: str,
) -> tuple[int, int, int, list[str]]:
    """Common loop that registers per-token pricing for a list of "chat-shape"
    configs (chat or vision LLM — both use ``input_cost_per_token`` /
    ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).

    Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
    """
    registered_models = 0
    registered_aliases = 0
    skipped_no_pricing = 0
    sample_keys: list[str] = []

    for cfg in configs:
        provider = str(cfg.get("provider") or "").upper()
        model_name = str(cfg.get("model_name") or "").strip()
        litellm_params = cfg.get("litellm_params") or {}
        base_model = str(litellm_params.get("base_model") or model_name).strip()

        if provider == "OPENROUTER":
            entry = or_pricing.get(model_name)
            if entry:
                input_cost = _safe_float(entry.get("prompt"))
                output_cost = _safe_float(entry.get("completion"))
            else:
                # Vision configs from ``_generate_vision_llm_configs``
                # carry their pricing inline because the OpenRouter
                # raw-pricing cache is keyed by chat-catalogue model_id;
                # vision flows pick up the inline values here.
                input_cost = _safe_float(cfg.get("input_cost_per_token"))
                output_cost = _safe_float(cfg.get("output_cost_per_token"))
            if input_cost == 0.0 and output_cost == 0.0:
                skipped_no_pricing += 1
                continue
            aliases = _alias_set_for_openrouter(model_name)
            count = _register(
                aliases,
                input_cost=input_cost,
                output_cost=output_cost,
                provider="openrouter",
            )
            if count > 0:
                registered_models += 1
                registered_aliases += count
                if len(sample_keys) < 6:
                    sample_keys.extend(aliases[:2])
            continue

        input_cost = _safe_float(
            cfg.get("input_cost_per_token")
            or litellm_params.get("input_cost_per_token")
        )
        output_cost = _safe_float(
            cfg.get("output_cost_per_token")
            or litellm_params.get("output_cost_per_token")
        )
        if input_cost == 0.0 and output_cost == 0.0:
            skipped_no_pricing += 1
            continue
        aliases = _alias_set_for_yaml(provider, model_name, base_model)
        provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
        count = _register(
            aliases,
            input_cost=input_cost,
            output_cost=output_cost,
            provider=provider_slug,
        )
        if count > 0:
            registered_models += 1
            registered_aliases += count
            if len(sample_keys) < 6:
                sample_keys.extend(aliases[:2])

    logger.info(
        "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
        "%d configs had no pricing data; sample registered keys=%s",
        label,
        registered_models,
        registered_aliases,
        skipped_no_pricing,
        sample_keys,
    )
    return registered_models, registered_aliases, skipped_no_pricing, sample_keys


def register_pricing_from_global_configs() -> None:
    """Register pricing for every known LLM deployment with LiteLLM.

    Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
    so vision calls (during indexing) can resolve cost the same way chat
    calls do — namely:

    1. ``OPENROUTER``: pulls the cached raw pricing from
       ``OpenRouterIntegrationService`` (populated during its own
       startup fetch) and converts the per-token strings to floats. For
       vision configs that carry pricing inline (``input_cost_per_token`` /
       ``output_cost_per_token`` set on the cfg itself) we fall back to
       those values when the OR cache misses the model.
    2. Anything else: looks for operator-declared
       ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
       config block (top-level or nested under ``litellm_params``).

    **Image generation is intentionally NOT registered here.** The cost
    shape for image-gen is per-image (``output_cost_per_image``), not
    per-token, and LiteLLM's ``register_model`` doesn't accept those
    keys via the chat-cost path. OpenRouter image-gen models populate
    ``response_cost`` directly from their response header instead, and
    Azure-native image-gen models are already in LiteLLM's cost map.

    Calls without a resolved pair of costs are skipped, not registered
    with zeros — operators who forget pricing get a "$0 debit" warning
    in ``TokenTrackingCallback`` rather than silently overwriting any
    pricing LiteLLM might know natively.
    """
    from app.config import config as app_config

    chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
    vision_configs: list[dict] = list(
        getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
    )
    if not chat_configs and not vision_configs:
        logger.info("[PricingRegistration] no global configs to register")
        return

    or_pricing: dict[str, dict[str, str]] = {}
    try:
        from app.services.openrouter_integration_service import (
            OpenRouterIntegrationService,
        )

        if OpenRouterIntegrationService.is_initialized():
            or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
    except Exception as exc:
        logger.debug(
            "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
        )

    if chat_configs:
        _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
    if vision_configs:
        _register_chat_shape_configs(
            vision_configs, or_pricing=or_pricing, label="vision"
        )

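What an operator-declared pricing block looks like once the YAML is loaded,
and how it flows through the registrar (key names per the module docstring;
the deployment name and prices are illustrative):

    cfg = {
        "provider": "AZURE_OPENAI",
        "model_name": "prod-gpt4o",
        "litellm_params": {"base_model": "gpt-4o-2024-08-06"},
        "input_cost_per_token": 0.0000025,   # $2.50 per 1M input tokens
        "output_cost_per_token": 0.00001,    # $10.00 per 1M output tokens
    }
    # register_pricing_from_global_configs() registers this entry under the
    # aliases from _alias_set_for_yaml, with litellm_provider="azure".
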
107  surfsense_backend/app/services/provider_api_base.py  Normal file

@@ -0,0 +1,107 @@
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
|
||||
|
||||
LiteLLM falls back to the module-global ``litellm.api_base`` when an
|
||||
individual call doesn't pass one, which silently inherits provider-agnostic
|
||||
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
|
||||
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
|
||||
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
|
||||
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
|
||||
``/chat/completions`` to whatever inherited base it gets, regardless of
|
||||
provider).
|
||||
|
||||
The chat router has had this defense for a while
|
||||
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
|
||||
into a tiny standalone helper so vision and image-gen can share the same
|
||||
source of truth without an inter-service circular import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
|
||||
|
||||
Only providers with a well-known, stable public base URL are listed —
|
||||
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
so their existing config-driven behaviour is preserved."""
|
||||
|
||||
|
||||
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
"""Canonical provider key (uppercase) → base URL.
|
||||
|
||||
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
|
||||
config's ``provider`` field tells us which API it actually is (DeepSeek,
|
||||
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
|
||||
has its own base URL)."""
|
||||
|
||||
|
||||
def resolve_api_base(
|
||||
*,
|
||||
provider: str | None,
|
||||
provider_prefix: str | None,
|
||||
config_api_base: str | None,
|
||||
) -> str | None:
|
||||
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
|
||||
|
||||
Cascade (first non-empty wins):
|
||||
1. The config's own ``api_base`` (whitespace-only treated as missing).
|
||||
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
|
||||
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
|
||||
4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
|
||||
provider integration apply its own default (e.g. AzureOpenAI's
|
||||
deployment-derived URL, custom provider's per-deployment URL).
|
||||
|
||||
Args:
|
||||
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
|
||||
``"DEEPSEEK"``). Case-insensitive.
|
||||
provider_prefix: The LiteLLM model-string prefix the same call
|
||||
site builds for the model id (e.g. ``"openrouter"``,
|
||||
``"groq"``). Case-insensitive.
|
||||
config_api_base: ``api_base`` from the global YAML / DB row /
|
||||
OpenRouter dynamic config. Empty / whitespace-only means
|
||||
"missing" — the resolver still applies the cascade.
|
||||
|
||||
Returns:
|
||||
A URL string, or ``None`` if no default applies for this provider.
|
||||
"""
|
||||
if config_api_base and config_api_base.strip():
|
||||
return config_api_base
|
||||
|
||||
if provider:
|
||||
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
|
||||
if key_default:
|
||||
return key_default
|
||||
|
||||
if provider_prefix:
|
||||
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
|
||||
if prefix_default:
|
||||
return prefix_default
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROVIDER_DEFAULT_API_BASE",
|
||||
"PROVIDER_KEY_DEFAULT_API_BASE",
|
||||
"resolve_api_base",
|
||||
]
|
||||
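The cascade in practice (return values follow directly from the maps above):

    resolve_api_base(provider="OPENROUTER", provider_prefix="openrouter",
                     config_api_base="")
    # -> "https://openrouter.ai/api/v1"  (prefix default; no key entry)

    resolve_api_base(provider="DEEPSEEK", provider_prefix="openai",
                     config_api_base=None)
    # -> "https://api.deepseek.com/v1"   (provider key beats the generic prefix)

    resolve_api_base(provider="OLLAMA", provider_prefix="ollama",
                     config_api_base="   ")
    # -> None                            (BYO-endpoint provider: leave unset)
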
105  surfsense_backend/app/services/quota_checked_vision_llm.py  Normal file

@@ -0,0 +1,105 @@
"""
|
||||
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
|
||||
|
||||
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
|
||||
indexing pipeline (file processors, connector indexers, etl pipeline) can
|
||||
keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
|
||||
— without threading ``user_id`` through every parser. The wrapper looks like
|
||||
a chat model from the outside; on the inside it routes each call through
|
||||
``billable_call`` so the user's premium credit pool is reserved → finalized
|
||||
or released, and a ``TokenUsage`` audit row is written.
|
||||
|
||||
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
|
||||
need quota enforcement) so this class only ever wraps premium configs.
|
||||
|
||||
Why a wrapper instead of plumbing ``user_id`` through every caller:
|
||||
|
||||
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
|
||||
Dropbox, local-folder, file-processor, ETL pipeline) each calling
|
||||
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
|
||||
invasive, error-prone, and easy for a future indexer to forget.
|
||||
* Per the design (issue M), we always debit the *search-space owner*, not
|
||||
the triggering user, so ``user_id`` is fully derivable from the search
|
||||
space the caller is already operating on. The wrapper captures it once
|
||||
at construction time.
|
||||
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
|
||||
call run this coroutine"; subclassing isn't safe across versions because
|
||||
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
|
||||
Composition via attribute proxying (``__getattr__``) is robust to
|
||||
upstream changes — every method other than ``ainvoke`` falls through to
|
||||
the inner LLM unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.billable_calls import QuotaInsufficientError, billable_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotaCheckedVisionLLM:
|
||||
"""Composition wrapper around a langchain chat model that enforces
|
||||
premium credit quota on every ``ainvoke``.
|
||||
|
||||
Anything other than ``ainvoke`` is forwarded to the inner model so
|
||||
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
|
||||
still work — they simply bypass quota enforcement, which is fine
|
||||
because the indexing pipeline only ever calls ``ainvoke`` today.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner_llm: Any,
|
||||
*,
|
||||
user_id: UUID,
|
||||
search_space_id: int,
|
||||
billing_tier: str,
|
||||
base_model: str,
|
||||
quota_reserve_tokens: int | None,
|
||||
usage_type: str = "vision_extraction",
|
||||
) -> None:
|
||||
self._inner = inner_llm
|
||||
self._user_id = user_id
|
||||
self._search_space_id = search_space_id
|
||||
self._billing_tier = billing_tier
|
||||
self._base_model = base_model
|
||||
self._quota_reserve_tokens = quota_reserve_tokens
|
||||
self._usage_type = usage_type
|
||||
|
||||
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Proxied async invoke that runs the underlying call inside
|
||||
``billable_call``.
|
||||
|
||||
Raises:
|
||||
QuotaInsufficientError: when the user has exhausted their
|
||||
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
|
||||
catches this and falls back to the document parser.
|
||||
"""
|
||||
async with billable_call(
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
billing_tier=self._billing_tier,
|
||||
base_model=self._base_model,
|
||||
quota_reserve_tokens=self._quota_reserve_tokens,
|
||||
usage_type=self._usage_type,
|
||||
call_details={"model": self._base_model},
|
||||
):
|
||||
return await self._inner.ainvoke(input, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Forward everything else (``invoke``, ``astream``, ``bind``,
|
||||
``with_structured_output``, …) to the inner model.
|
||||
|
||||
``__getattr__`` is only consulted when the attribute is *not*
|
||||
already found on the proxy, which is exactly the contract we
|
||||
want — methods we override stay on the proxy, the rest fall
|
||||
through.
|
||||
"""
|
||||
return getattr(self._inner, name)
|
||||
|
||||
|
||||
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
|
||||
|
|
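A hedged usage sketch (the model id, ``messages`` and ``MySchema`` are
placeholders, not part of this commit):

    wrapped = QuotaCheckedVisionLLM(
        inner_llm,
        user_id=owner_user_id,
        search_space_id=7,
        billing_tier="premium",
        base_model="openrouter/qwen/qwen2.5-vl-72b-instruct",
        quota_reserve_tokens=4000,
    )
    desc = await wrapped.ainvoke(messages)        # reserve -> call -> settle
    s = wrapped.with_structured_output(MySchema)  # proxied through, unmetered
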
@@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Per-call reservation estimator (USD micro-units)
# ---------------------------------------------------------------------------

# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
# request, and so models without registered pricing reserve at least
# something while the call runs (debited 0 at finalize anyway when their
# cost can't be resolved).
_QUOTA_MIN_RESERVE_MICROS = 100


def estimate_call_reserve_micros(
    *,
    base_model: str,
    quota_reserve_tokens: int | None,
) -> int:
    """Return the number of micro-USD to reserve for one premium call.

    Computes a worst-case upper bound from LiteLLM's per-token pricing
    table:

        reserve_usd ≈ reserve_tokens × (input_cost + output_cost)

    so the math scales with model cost — Claude Opus + 4K reserve_tokens
    naturally reserves ≈ $0.36, while a cheap model reserves only a few
    cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
    so a misconfigured "$1000/M" model can't lock the whole balance on
    one call.

    If ``litellm.get_model_info`` raises (model unknown) we fall back to
    the floor — 100 micros / $0.0001 — which is enough to gate a sane
    request without over-reserving for a model whose pricing the
    operator hasn't declared yet.
    """
    reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
    if reserve_tokens <= 0:
        reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL

    try:
        from litellm import get_model_info

        info = get_model_info(base_model) if base_model else {}
        input_cost = float(info.get("input_cost_per_token") or 0.0)
        output_cost = float(info.get("output_cost_per_token") or 0.0)
    except Exception as exc:
        logger.debug(
            "[quota_reserve] cost lookup failed for base_model=%s: %s",
            base_model,
            exc,
        )
        input_cost = 0.0
        output_cost = 0.0

    if input_cost == 0.0 and output_cost == 0.0:
        return _QUOTA_MIN_RESERVE_MICROS

    reserve_usd = reserve_tokens * (input_cost + output_cost)
    reserve_micros = round(reserve_usd * 1_000_000)
    if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
        reserve_micros = _QUOTA_MIN_RESERVE_MICROS
    if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
        reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
    return reserve_micros


class QuotaScope(StrEnum):
    ANONYMOUS = "anonymous"
    PREMIUM = "premium"

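The arithmetic, worked through for the Opus case the docstring cites
(per-token prices assumed from Anthropic's published $15/M input and
$75/M output rates; the exact model key is illustrative):

    # reserve_usd    = 4000 × (0.000015 + 0.000075) = 0.36 USD
    # reserve_micros = round(0.36 × 1_000_000)      = 360_000
    estimate_call_reserve_micros(
        base_model="claude-opus-4", quota_reserve_tokens=4000
    )  # -> 360_000, assuming it clears QUOTA_MAX_RESERVE_MICROS
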
@@ -444,8 +509,16 @@ class TokenQuotaService:
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        reserve_tokens: int,
        reserve_micros: int,
    ) -> QuotaResult:
        """Reserve ``reserve_micros`` (USD micro-units) from the user's
        premium credit balance.

        ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
        all in micro-USD on this code path; callers (chat stream,
        token-status route, FE display) convert to dollars by dividing
        by 1_000_000.
        """
        from app.db import User

        user = (

@@ -465,11 +538,11 @@ class TokenQuotaService:
                limit=0,
            )

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        limit = user.premium_credit_micros_limit
        used = user.premium_credit_micros_used
        reserved = user.premium_credit_micros_reserved

        effective = used + reserved + reserve_tokens
        effective = used + reserved + reserve_micros
        if effective > limit:
            remaining = max(0, limit - used - reserved)
            await db_session.rollback()

@@ -482,10 +555,10 @@ class TokenQuotaService:
                remaining=remaining,
            )

        user.premium_tokens_reserved = reserved + reserve_tokens
        user.premium_credit_micros_reserved = reserved + reserve_micros
        await db_session.commit()

        new_reserved = reserved + reserve_tokens
        new_reserved = reserved + reserve_micros
        remaining = max(0, limit - used - new_reserved)
        warning_threshold = int(limit * 0.8)

@@ -510,9 +583,12 @@ class TokenQuotaService:
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        actual_tokens: int,
        reserved_tokens: int,
        actual_micros: int,
        reserved_micros: int,
    ) -> QuotaResult:
        """Settle the reservation: release ``reserved_micros`` and debit
        ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
        """
        from app.db import User

        user = (

@@ -529,16 +605,18 @@ class TokenQuotaService:
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )

        user.premium_tokens_reserved = max(
            0, user.premium_tokens_reserved - reserved_tokens
        user.premium_credit_micros_reserved = max(
            0, user.premium_credit_micros_reserved - reserved_micros
        )
        user.premium_credit_micros_used = (
            user.premium_credit_micros_used + actual_micros
        )
        user.premium_tokens_used = user.premium_tokens_used + actual_tokens

        await db_session.commit()

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        limit = user.premium_credit_micros_limit
        used = user.premium_credit_micros_used
        reserved = user.premium_credit_micros_reserved
        remaining = max(0, limit - used - reserved)

        warning_threshold = int(limit * 0.8)

@@ -562,8 +640,13 @@ class TokenQuotaService:
    async def premium_release(
        db_session: AsyncSession,
        user_id: Any,
        reserved_tokens: int,
        reserved_micros: int,
    ) -> None:
        """Release ``reserved_micros`` previously held by ``premium_reserve``.

        Used when a request fails before finalize (so the reservation
        doesn't leak credit).
        """
        from app.db import User

        user = (

@@ -576,8 +659,8 @@ class TokenQuotaService:
            .scalar_one_or_none()
        )
        if user is not None:
            user.premium_tokens_reserved = max(
                0, user.premium_tokens_reserved - reserved_tokens
            user.premium_credit_micros_reserved = max(
                0, user.premium_credit_micros_reserved - reserved_micros
            )
            await db_session.commit()

@@ -598,9 +681,9 @@ class TokenQuotaService:
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        limit = user.premium_credit_micros_limit
        used = user.premium_credit_micros_used
        reserved = user.premium_credit_micros_reserved
        remaining = max(0, limit - used - reserved)

        warning_threshold = int(limit * 0.8)

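End-to-end, the micro-USD lifecycle looks like this (all figures assumed;
``premium_finalize`` stands in for the settle method whose renamed parameters
appear in the hunks above — its actual name is not visible in this diff):

    # Start: limit=5_000_000 ($5.00), used=1_000_000, reserved=0
    await TokenQuotaService.premium_reserve(db, user_id, req_id,
                                            reserve_micros=360_000)
    # reserved=360_000; remaining = 5_000_000 - 1_000_000 - 360_000 = 3_640_000

    await TokenQuotaService.premium_finalize(db, user_id, req_id,
                                             actual_micros=42_180,
                                             reserved_micros=360_000)
    # reserved=0; used=1_042_180; remaining = 3_957_820
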
@@ -16,11 +16,14 @@ from __future__ import annotations

import dataclasses
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID

import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession


@@ -35,6 +38,8 @@ class TokenCallRecord:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost_micros: int = 0
    call_kind: str = "chat"


@dataclass

@@ -49,6 +54,8 @@ class TurnTokenAccumulator:
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        cost_micros: int = 0,
        call_kind: str = "chat",
    ) -> None:
        self.calls.append(
            TokenCallRecord(

@@ -56,20 +63,28 @@ class TurnTokenAccumulator:
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                cost_micros=cost_micros,
                call_kind=call_kind,
            )
        )

    def per_message_summary(self) -> dict[str, dict[str, int]]:
        """Return token counts grouped by model name."""
        """Return token counts (and cost) grouped by model name."""
        by_model: dict[str, dict[str, int]] = {}
        for c in self.calls:
            entry = by_model.setdefault(
                c.model,
                {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost_micros": 0,
                },
            )
            entry["prompt_tokens"] += c.prompt_tokens
            entry["completion_tokens"] += c.completion_tokens
            entry["total_tokens"] += c.total_tokens
            entry["cost_micros"] += c.cost_micros
        return by_model

    @property

@@ -84,6 +99,21 @@ class TurnTokenAccumulator:
    def total_completion_tokens(self) -> int:
        return sum(c.completion_tokens for c in self.calls)

    @property
    def total_cost_micros(self) -> int:
        """Sum of per-call ``cost_micros`` across the entire turn.

        Used by ``stream_new_chat`` to debit a premium turn's actual
        provider cost (in micro-USD) from the user's premium credit
        balance. ``cost_micros`` per call is captured by
        ``TokenTrackingCallback.async_log_success_event`` from
        ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
        with multiple fallback paths so OpenRouter dynamic models and
        custom Azure deployments still bill correctly, provided our
        ``pricing_registration`` ran at startup.
        """
        return sum(c.cost_micros for c in self.calls)

    def serialized_calls(self) -> list[dict[str, Any]]:
        return [dataclasses.asdict(c) for c in self.calls]


@@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(


def start_turn() -> TurnTokenAccumulator:
    """Create a fresh accumulator for the current async context and return it."""
    """Create a fresh accumulator for the current async context and return it.

    NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
    short-lived per-call billable wrappers (image generation REST endpoint,
    vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
    ContextVar reset token to restore the *previous* accumulator on exit and
    avoids leaking call records across reservations (issue B).
    """
    acc = TurnTokenAccumulator()
    _turn_accumulator.set(acc)
    logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))

@@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
    return _turn_accumulator.get()


@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
    """Async context manager that scopes a fresh ``TurnTokenAccumulator``
    for the duration of the ``async with`` block, then *resets* the
    ContextVar to its previous value on exit.

    This is the safe primitive for per-call billable operations
    (image generation, vision LLM extraction, podcasts) that may run
    inside an outer chat turn or be called sequentially from the same
    background worker. Using ``ContextVar.set`` without ``reset`` (as
    :func:`start_turn` does) would leak the inner accumulator into the
    outer scope, causing the outer chat turn to debit cost twice.

    Usage::

        async with scoped_turn() as acc:
            await llm.ainvoke(...)
            # acc.total_cost_micros captures cost from the LiteLLM callback
        # Outer accumulator (if any) is restored here.
    """
    acc = TurnTokenAccumulator()
    token = _turn_accumulator.set(acc)
    logger.debug(
        "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
        id(acc),
        token,
    )
    try:
        yield acc
    finally:
        _turn_accumulator.reset(token)
        logger.debug(
            "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
            id(acc),
            len(acc.calls),
            acc.total_cost_micros,
        )

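Nesting behaviour, sketched (``image_llm`` and ``prompt`` are placeholders):

    outer = start_turn()
    async with scoped_turn() as inner:
        await image_llm.ainvoke(prompt)        # records land on ``inner`` only
    assert get_current_accumulator() is outer  # outer turn resumes untouched
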
def _extract_cost_usd(
|
||||
kwargs: dict[str, Any],
|
||||
response_obj: Any,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
is_image: bool = False,
|
||||
) -> float:
|
||||
"""Best-effort USD cost extraction for a single LLM/image call.
|
||||
|
||||
Tries four sources in priority order and returns the first that
|
||||
yields a positive number; returns 0.0 if all four fail (the call
|
||||
will then debit nothing from the user's balance — fail-safe).
|
||||
|
||||
Sources:
|
||||
1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
|
||||
field, populated for ``Router.acompletion`` since PR #12500.
|
||||
2. ``response_obj._hidden_params["response_cost"]`` — same value
|
||||
exposed on the response itself.
|
||||
3. ``litellm.completion_cost(completion_response=response_obj)``
|
||||
— recompute from the response and LiteLLM's pricing table.
|
||||
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
|
||||
— manual fallback for OpenRouter/custom-Azure models that
|
||||
only resolve via aliases registered by
|
||||
``pricing_registration`` at startup. **Skipped for image
|
||||
responses** — ``cost_per_token`` does not support ``ImageResponse``
|
||||
and would raise; the cost map for image-gen lives in different
|
||||
keys (``output_cost_per_image``) handled by ``completion_cost``.
|
||||
"""
|
||||
cost = kwargs.get("response_cost")
|
||||
if cost is not None:
|
||||
try:
|
||||
value = float(cost)
|
||||
except (TypeError, ValueError):
|
||||
value = 0.0
|
||||
if value > 0:
|
||||
return value
|
||||
|
||||
hidden = getattr(response_obj, "_hidden_params", None) or {}
|
||||
if isinstance(hidden, dict):
|
||||
cost = hidden.get("response_cost")
|
||||
if cost is not None:
|
||||
try:
|
||||
value = float(cost)
|
||||
except (TypeError, ValueError):
|
||||
value = 0.0
|
||||
if value > 0:
|
||||
return value
|
||||
|
||||
try:
|
||||
value = float(litellm.completion_cost(completion_response=response_obj))
|
||||
if value > 0:
|
||||
return value
|
||||
except Exception as exc:
|
||||
if is_image:
|
||||
# Image-gen path: OpenRouter's image responses can omit
|
||||
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
|
||||
# then *raises* (no cost map for OpenRouter image models).
|
||||
# Bail out with a warning rather than falling through to
|
||||
# cost_per_token (which is also incompatible with ImageResponse).
|
||||
logger.warning(
|
||||
"[TokenTracking] completion_cost failed for image model=%s "
|
||||
"(provider may have omitted usage.cost). Debiting 0. "
|
||||
"Cause: %s",
|
||||
model,
|
||||
exc,
|
||||
)
|
||||
return 0.0
|
||||
logger.debug(
|
||||
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
|
||||
)
|
||||
|
||||
if is_image:
|
||||
# Never call cost_per_token for ImageResponse — keys mismatch and
|
||||
# the function is documented chat-only.
|
||||
return 0.0
|
||||
|
||||
if model and (prompt_tokens > 0 or completion_tokens > 0):
|
||||
try:
|
||||
prompt_cost, completion_cost = litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
value = float(prompt_cost) + float(completion_cost)
|
||||
if value > 0:
|
||||
return value
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
|
||||
)
|
||||
|
||||
return 0.0
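
A quick probe of the priority order, using a stand-in response object of our own (run in the module above; the model string and numbers are invented). Because source 2 yields a positive value, sources 3 and 4 are never consulted:

class _FakeResponse:
    # Mimics only the attribute _extract_cost_usd probes for source 2.
    _hidden_params = {"response_cost": 0.00042}

cost = _extract_cost_usd(
    kwargs={},                      # source 1: absent
    response_obj=_FakeResponse(),   # source 2: hit, short-circuits 3 and 4
    model="openrouter/openai/gpt-4o-mini",
    prompt_tokens=120,
    completion_tokens=80,
)
assert cost == 0.00042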


class TokenTrackingCallback(CustomLogger):
    """LiteLLM callback that captures token usage into the turn accumulator."""

@@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
            )
            return

        # Detect image generation responses — they have a different usage
        # shape (ImageUsage with input_tokens/output_tokens) and require a
        # different cost-extraction path. We probe by class name to avoid a
        # hard import dependency on litellm internals.
        response_cls = type(response_obj).__name__
        is_image = response_cls == "ImageResponse"

        usage = getattr(response_obj, "usage", None)
        if not usage:
            logger.debug(
@@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
            )
            return

        prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        total_tokens = getattr(usage, "total_tokens", 0) or 0
        if is_image:
            # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
            # (not prompt_tokens/completion_tokens). Several providers
            # populate only one or neither (e.g. OpenRouter's gpt-image-1
            # passes through `input_tokens` from the prompt but no
            # completion); fall through gracefully to 0.
            prompt_tokens = getattr(usage, "input_tokens", 0) or 0
            completion_tokens = getattr(usage, "output_tokens", 0) or 0
            total_tokens = (
                getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
            )
            call_kind = "image_generation"
        else:
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
            total_tokens = getattr(usage, "total_tokens", 0) or 0
            call_kind = "chat"

        model = kwargs.get("model", "unknown")

        cost_usd = _extract_cost_usd(
            kwargs=kwargs,
            response_obj=response_obj,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            is_image=is_image,
        )
        cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0

        if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
            logger.warning(
                "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
                "kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
                "input_cost_per_token/output_cost_per_token (or rely on response_cost "
                "for image generation).",
                model,
                prompt_tokens,
                completion_tokens,
                call_kind,
            )

        acc.add(
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost_micros=cost_micros,
            call_kind=call_kind,
        )
        logger.info(
            "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
            "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
            "cost=$%.6f (%d micros) (accumulator now has %d calls)",
            model,
            call_kind,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost_usd,
            cost_micros,
            len(acc.calls),
        )

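
The micro-USD conversion above (`round(cost_usd * 1_000_000)`) lets per-call costs be summed as exact integers rather than drifting floats. A small arithmetic check of our own:

costs_usd = [0.000123, 0.000457, 0.009]
costs_micros = [round(c * 1_000_000) for c in costs_usd]   # [123, 457, 9000]
total_micros = sum(costs_micros)                           # integer sum, exact
assert total_micros == 9_580
print(f"${total_micros / 1_000_000:.6f}")                  # $0.009580, display only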
@@ -168,6 +388,7 @@ async def record_token_usage(
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    total_tokens: int = 0,
    cost_micros: int = 0,
    model_breakdown: dict[str, Any] | None = None,
    call_details: dict[str, Any] | None = None,
    thread_id: int | None = None,
@@ -185,6 +406,7 @@ async def record_token_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost_micros=cost_micros,
            model_breakdown=model_breakdown,
            call_details=call_details,
            thread_id=thread_id,
@@ -194,11 +416,12 @@ async def record_token_usage(
        )
        session.add(record)
        logger.debug(
            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
            usage_type,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost_micros,
        )
        return record
    except Exception:

@@ -3,6 +3,8 @@ from typing import Any

from litellm import Router

from app.services.provider_api_base import resolve_api_base

logger = logging.getLogger(__name__)

VISION_AUTO_MODE_ID = 0
@@ -108,10 +110,11 @@ class VisionLLMRouterService:
        if not config.get("model_name") or not config.get("api_key"):
            return None

        provider = config.get("provider", "").upper()
        if config.get("custom_provider"):
            model_string = f"{config['custom_provider']}/{config['model_name']}"
            provider_prefix = config["custom_provider"]
            model_string = f"{provider_prefix}/{config['model_name']}"
        else:
            provider = config.get("provider", "").upper()
            provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
            model_string = f"{provider_prefix}/{config['model_name']}"

@@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
|
|
|||
|
|
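
The body of `resolve_api_base` is outside this hunk; from the call site alone, a hedged reading of its contract (the default table below is invented for illustration, the real one lives in app/services/provider_api_base.py):

def resolve_api_base(
    provider: str,
    provider_prefix: str,
    config_api_base: str | None,
) -> str | None:
    # An explicitly configured base URL always wins.
    if config_api_base:
        return config_api_base
    # Invented per-provider defaults; the real table may key on either name.
    defaults = {"OLLAMA": "http://localhost:11434"}
    return defaults.get(provider) or defaults.get(provider_prefix.upper())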
@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config as app_config
from app.db import Podcast, PodcastStatus
from app.services.billable_calls import (
    QuotaInsufficientError,
    _resolve_agent_billing_for_search_space,
    billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker

logger = logging.getLogger(__name__)
@@ -96,6 +102,31 @@ async def _generate_content_podcast(
    podcast.status = PodcastStatus.GENERATING
    await session.commit()

    try:
        (
            owner_user_id,
            billing_tier,
            base_model,
        ) = await _resolve_agent_billing_for_search_space(
            session,
            search_space_id,
            thread_id=podcast.thread_id,
        )
    except ValueError as resolve_err:
        logger.error(
            "Podcast %s: cannot resolve billing for search_space=%s: %s",
            podcast.id,
            search_space_id,
            resolve_err,
        )
        podcast.status = PodcastStatus.FAILED
        await session.commit()
        return {
            "status": "failed",
            "podcast_id": podcast.id,
            "reason": "billing_resolution_failed",
        }

    graph_config = {
        "configurable": {
            "podcast_title": podcast.title,
@@ -109,9 +140,39 @@ async def _generate_content_podcast(
        db_session=session,
    )

    graph_result = await podcaster_graph.ainvoke(
        initial_state, config=graph_config
    )
    try:
        async with billable_call(
            user_id=owner_user_id,
            search_space_id=search_space_id,
            billing_tier=billing_tier,
            base_model=base_model,
            quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
            usage_type="podcast_generation",
            thread_id=podcast.thread_id,
            call_details={
                "podcast_id": podcast.id,
                "title": podcast.title,
            },
        ):
            graph_result = await podcaster_graph.ainvoke(
                initial_state, config=graph_config
            )
    except QuotaInsufficientError as exc:
        logger.info(
            "Podcast %s denied: out of premium credits "
            "(used=%d/%d remaining=%d)",
            podcast.id,
            exc.used_micros,
            exc.limit_micros,
            exc.remaining_micros,
        )
        podcast.status = PodcastStatus.FAILED
        await session.commit()
        return {
            "status": "failed",
            "podcast_id": podcast.id,
            "reason": "premium_quota_exhausted",
        }

    podcast_transcript = graph_result.get("podcast_transcript", [])
    file_path = graph_result.get("final_podcast_file_path", "")

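
The `billable_call` wrapper used above follows a reserve-then-settle shape: take a hold before the agent graph runs, debit the measured cost on success, and release the hold no matter what. Its implementation lives in app/services/billable_calls.py and is not shown in this diff; a hedged sketch of the pattern with our own names:

from contextlib import asynccontextmanager

@asynccontextmanager
async def reserve_then_settle(reserve_micros: int, ledger: dict):
    # Take the hold up front, before any provider call runs.
    ledger["reserved"] = ledger.get("reserved", 0) + reserve_micros
    try:
        yield ledger
        # Body succeeded: debit whatever actual cost the body recorded.
        ledger["debited"] = ledger.get("debited", 0) + ledger.pop("actual", 0)
    finally:
        # Success or failure, the hold itself is always released.
        ledger["reserved"] -= reserve_micros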
@@ -9,7 +9,13 @@ from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app
from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
from app.services.billable_calls import (
    QuotaInsufficientError,
    _resolve_agent_billing_for_search_space,
    billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker

logger = logging.getLogger(__name__)
@@ -97,6 +103,32 @@ async def _generate_video_presentation(
    video_pres.status = VideoPresentationStatus.GENERATING
    await session.commit()

    try:
        (
            owner_user_id,
            billing_tier,
            base_model,
        ) = await _resolve_agent_billing_for_search_space(
            session,
            search_space_id,
            thread_id=video_pres.thread_id,
        )
    except ValueError as resolve_err:
        logger.error(
            "VideoPresentation %s: cannot resolve billing for "
            "search_space=%s: %s",
            video_pres.id,
            search_space_id,
            resolve_err,
        )
        video_pres.status = VideoPresentationStatus.FAILED
        await session.commit()
        return {
            "status": "failed",
            "video_presentation_id": video_pres.id,
            "reason": "billing_resolution_failed",
        }

    graph_config = {
        "configurable": {
            "video_title": video_pres.title,
@@ -110,9 +142,39 @@ async def _generate_video_presentation(
        db_session=session,
    )

    graph_result = await video_presentation_graph.ainvoke(
        initial_state, config=graph_config
    )
    try:
        async with billable_call(
            user_id=owner_user_id,
            search_space_id=search_space_id,
            billing_tier=billing_tier,
            base_model=base_model,
            quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
            usage_type="video_presentation_generation",
            thread_id=video_pres.thread_id,
            call_details={
                "video_presentation_id": video_pres.id,
                "title": video_pres.title,
            },
        ):
            graph_result = await video_presentation_graph.ainvoke(
                initial_state, config=graph_config
            )
    except QuotaInsufficientError as exc:
        logger.info(
            "VideoPresentation %s denied: out of premium credits "
            "(used=%d/%d remaining=%d)",
            video_pres.id,
            exc.used_micros,
            exc.limit_micros,
            exc.remaining_micros,
        )
        video_pres.status = VideoPresentationStatus.FAILED
        await session.commit()
        return {
            "status": "failed",
            "video_presentation_id": video_pres.id,
            "reason": "premium_quota_exhausted",
        }

    # Serialize slides (parsed content + audio info merged)
    slides_raw = graph_result.get("slides", [])

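
Both handlers above read the same three counters off the exception. Its shape, reconstructed from those attribute accesses (a sketch, not the definitive class; the real one is exported by app/services/billable_calls.py):

class QuotaInsufficientError(Exception):
    """Raised when a premium reserve cannot be satisfied."""

    def __init__(self, used_micros: int, limit_micros: int, remaining_micros: int):
        self.used_micros = used_micros
        self.limit_micros = limit_micros
        self.remaining_micros = remaining_micros
        super().__init__(
            f"premium credits exhausted: used {used_micros}/{limit_micros} micros"
        )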
@@ -2236,8 +2236,10 @@ async def stream_new_chat(

    accumulator = start_turn()

    # Premium quota tracking state
    _premium_reserved = 0
    # Premium credit (USD micro-units) tracking state. Stores the
    # amount reserved up front so we can release it on cancellation
    # and finalize-debit the actual provider cost reported by LiteLLM.
    _premium_reserved_micros = 0
    _premium_request_id: str | None = None

    _emit_stream_error = partial(
@@ -2331,23 +2333,28 @@ async def stream_new_chat(
    if _needs_premium_quota:
        import uuid as _uuid

        from app.config import config as _app_config
        from app.services.token_quota_service import TokenQuotaService
        from app.services.token_quota_service import (
            TokenQuotaService,
            estimate_call_reserve_micros,
        )

        _premium_request_id = _uuid.uuid4().hex[:16]
        reserve_amount = min(
            agent_config.quota_reserve_tokens
            or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
            _app_config.QUOTA_MAX_RESERVE_PER_CALL,
        _agent_litellm_params = agent_config.litellm_params or {}
        _agent_base_model = (
            _agent_litellm_params.get("base_model") or agent_config.model_name or ""
        )
        reserve_amount_micros = estimate_call_reserve_micros(
            base_model=_agent_base_model,
            quota_reserve_tokens=agent_config.quota_reserve_tokens,
        )
        async with shielded_async_session() as quota_session:
            quota_result = await TokenQuotaService.premium_reserve(
                db_session=quota_session,
                user_id=UUID(user_id),
                request_id=_premium_request_id,
                reserve_tokens=reserve_amount,
                reserve_micros=reserve_amount_micros,
            )
        _premium_reserved = reserve_amount
        _premium_reserved_micros = reserve_amount_micros
        if not quota_result.allowed:
            if requested_llm_config_id == 0:
                try:
@@ -2382,7 +2389,7 @@ async def stream_new_chat(
                    yield streaming_service.format_done()
                    return
                _premium_request_id = None
                _premium_reserved = 0
                _premium_reserved_micros = 0
                _log_chat_stream_error(
                    flow=flow,
                    error_kind="premium_quota_exhausted",
@@ -3020,9 +3027,10 @@ async def stream_new_chat(

        usage_summary = accumulator.per_message_summary()
        _perf_log.info(
            "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
            "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
            len(accumulator.calls),
            accumulator.grand_total,
            accumulator.total_cost_micros,
            usage_summary,
        )
        if usage_summary:
@@ -3033,6 +3041,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@@ -3060,7 +3069,11 @@ async def stream_new_chat(
                chat_id, generated_title
            )

        # Finalize premium quota with actual tokens.
        # Finalize premium credit debit with the actual provider cost
        # reported by LiteLLM, summed across every call in the turn.
        # Mirrors the pre-cost behaviour of "premium turn → all calls
        # count" so free sub-agent calls during a premium turn still
        # contribute to the bill (they're $0 in practice anyway).
        if _premium_request_id and user_id:
            try:
                from app.services.token_quota_service import TokenQuotaService
@@ -3070,11 +3083,11 @@ async def stream_new_chat(
                    db_session=quota_session,
                    user_id=UUID(user_id),
                    request_id=_premium_request_id,
                    actual_tokens=accumulator.grand_total,
                    reserved_tokens=_premium_reserved,
                    actual_micros=accumulator.total_cost_micros,
                    reserved_micros=_premium_reserved_micros,
                )
                _premium_request_id = None
                _premium_reserved = 0
                _premium_reserved_micros = 0
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to finalize premium quota for user %s",
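
`premium_finalize` receives both the hold (`reserved_micros`) and the measured cost (`actual_micros`); the settlement arithmetic it presumably performs, with our own worked numbers:

reserved_micros = 50_000    # hold taken by premium_reserve
actual_micros = 12_345      # summed LiteLLM cost for the turn
refund = max(reserved_micros - actual_micros, 0)    # 37_655 returned to balance
overrun = max(actual_micros - reserved_micros, 0)   # 0 extra debit this turn
assert (refund, overrun) == (37_655, 0)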
@@ -3084,9 +3097,10 @@ async def stream_new_chat(

        usage_summary = accumulator.per_message_summary()
        _perf_log.info(
            "[token_usage] normal new_chat: calls=%d total=%d summary=%s",
            "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
            len(accumulator.calls),
            accumulator.grand_total,
            accumulator.total_cost_micros,
            usage_summary,
        )
        if usage_summary:
@@ -3097,6 +3111,7 @@ async def stream_new_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@@ -3190,7 +3205,7 @@ async def stream_new_chat(
        end_turn(str(chat_id))

        # Release premium reservation if not finalized
        if _premium_request_id and _premium_reserved > 0 and user_id:
        if _premium_request_id and _premium_reserved_micros > 0 and user_id:
            try:
                from app.services.token_quota_service import TokenQuotaService

@@ -3198,9 +3213,9 @@ async def stream_new_chat(
                await TokenQuotaService.premium_release(
                    db_session=quota_session,
                    user_id=UUID(user_id),
                    reserved_tokens=_premium_reserved,
                    reserved_micros=_premium_reserved_micros,
                )
                _premium_reserved = 0
                _premium_reserved_micros = 0
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to release premium quota for user %s", user_id
@@ -3369,8 +3384,8 @@ async def stream_resume_chat(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Premium quota reservation (same logic as stream_new_chat)
|
||||
_resume_premium_reserved = 0
|
||||
# Premium credit reservation (same logic as stream_new_chat).
|
||||
_resume_premium_reserved_micros = 0
|
||||
_resume_premium_request_id: str | None = None
|
||||
_resume_needs_premium = (
|
||||
agent_config is not None and user_id and agent_config.is_premium
|
||||
|
|
@@ -3378,23 +3393,30 @@ async def stream_resume_chat(
        if _resume_needs_premium:
            import uuid as _uuid

            from app.config import config as _app_config
            from app.services.token_quota_service import TokenQuotaService
            from app.services.token_quota_service import (
                TokenQuotaService,
                estimate_call_reserve_micros,
            )

            _resume_premium_request_id = _uuid.uuid4().hex[:16]
            reserve_amount = min(
                agent_config.quota_reserve_tokens
                or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
                _app_config.QUOTA_MAX_RESERVE_PER_CALL,
            _resume_litellm_params = agent_config.litellm_params or {}
            _resume_base_model = (
                _resume_litellm_params.get("base_model")
                or agent_config.model_name
                or ""
            )
            reserve_amount_micros = estimate_call_reserve_micros(
                base_model=_resume_base_model,
                quota_reserve_tokens=agent_config.quota_reserve_tokens,
            )
            async with shielded_async_session() as quota_session:
                quota_result = await TokenQuotaService.premium_reserve(
                    db_session=quota_session,
                    user_id=UUID(user_id),
                    request_id=_resume_premium_request_id,
                    reserve_tokens=reserve_amount,
                    reserve_micros=reserve_amount_micros,
                )
            _resume_premium_reserved = reserve_amount
            _resume_premium_reserved_micros = reserve_amount_micros
            if not quota_result.allowed:
                if requested_llm_config_id == 0:
                    try:
@@ -3429,7 +3451,7 @@ async def stream_resume_chat(
                    yield streaming_service.format_done()
                    return
                _resume_premium_request_id = None
                _resume_premium_reserved = 0
                _resume_premium_reserved_micros = 0
                _log_chat_stream_error(
                    flow="resume",
                    error_kind="premium_quota_exhausted",
@@ -3746,9 +3768,10 @@ async def stream_resume_chat(
        if stream_result.is_interrupted:
            usage_summary = accumulator.per_message_summary()
            _perf_log.info(
                "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
                "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
                len(accumulator.calls),
                accumulator.grand_total,
                accumulator.total_cost_micros,
                usage_summary,
            )
            if usage_summary:
@@ -3759,6 +3782,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@@ -3768,7 +3792,9 @@ async def stream_resume_chat(
            yield streaming_service.format_done()
            return

        # Finalize premium quota for resume path
        # Finalize premium credit debit for resume path with the actual
        # provider cost reported by LiteLLM (sum of cost across all
        # calls in the turn).
        if _resume_premium_request_id and user_id:
            try:
                from app.services.token_quota_service import TokenQuotaService
@@ -3778,11 +3804,11 @@ async def stream_resume_chat(
                    db_session=quota_session,
                    user_id=UUID(user_id),
                    request_id=_resume_premium_request_id,
                    actual_tokens=accumulator.grand_total,
                    reserved_tokens=_resume_premium_reserved,
                    actual_micros=accumulator.total_cost_micros,
                    reserved_micros=_resume_premium_reserved_micros,
                )
                _resume_premium_request_id = None
                _resume_premium_reserved = 0
                _resume_premium_reserved_micros = 0
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to finalize premium quota for user %s (resume)",
@@ -3792,9 +3818,10 @@ async def stream_resume_chat(

        usage_summary = accumulator.per_message_summary()
        _perf_log.info(
            "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
            "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
            len(accumulator.calls),
            accumulator.grand_total,
            accumulator.total_cost_micros,
            usage_summary,
        )
        if usage_summary:
@@ -3805,6 +3832,7 @@ async def stream_resume_chat(
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@@ -3855,7 +3883,11 @@ async def stream_resume_chat(
        end_turn(str(chat_id))

        # Release premium reservation if not finalized
        if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
        if (
            _resume_premium_request_id
            and _resume_premium_reserved_micros > 0
            and user_id
        ):
            try:
                from app.services.token_quota_service import TokenQuotaService

@@ -3863,9 +3895,9 @@ async def stream_resume_chat(
                await TokenQuotaService.premium_release(
                    db_session=quota_session,
                    user_id=UUID(user_id),
                    reserved_tokens=_resume_premium_reserved,
                    reserved_micros=_resume_premium_reserved_micros,
                )
                _resume_premium_reserved = 0
                _resume_premium_reserved_micros = 0
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to release premium quota for user %s (resume)", user_id