feat: unified credits and its cost calculations
parent 451a98936e
commit ae9d36d77f
61 changed files with 5835 additions and 272 deletions
@@ -16,11 +16,14 @@ from __future__ import annotations

+import dataclasses
 import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from typing import Any
 from uuid import UUID

 import litellm
 from litellm.integrations.custom_logger import CustomLogger
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -35,6 +38,8 @@ class TokenCallRecord:
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
+    cost_micros: int = 0
+    call_kind: str = "chat"


 @dataclass
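Why integer micro-USD: summing many tiny float costs and debiting balances invites rounding drift, so each call's cost is stored as an integer number of millionths of a dollar. A minimal sketch of the conversion (the same `round(cost_usd * 1_000_000)` expression appears later in the callback):

def usd_to_micros(cost_usd: float) -> int:
    # Mirrors the callback's conversion; 0 means "nothing resolved, debit nothing".
    return round(cost_usd * 1_000_000) if cost_usd > 0 else 0

assert usd_to_micros(0.000123) == 123      # $0.000123 -> 123 micros
assert usd_to_micros(1.5) == 1_500_000     # $1.50    -> 1,500,000 micros
assert usd_to_micros(0.0) == 0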
@@ -49,6 +54,8 @@ class TurnTokenAccumulator:
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
+        cost_micros: int = 0,
+        call_kind: str = "chat",
     ) -> None:
         self.calls.append(
             TokenCallRecord(
@@ -56,20 +63,28 @@
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
+                cost_micros=cost_micros,
+                call_kind=call_kind,
             )
         )

     def per_message_summary(self) -> dict[str, dict[str, int]]:
-        """Return token counts grouped by model name."""
+        """Return token counts (and cost) grouped by model name."""
         by_model: dict[str, dict[str, int]] = {}
         for c in self.calls:
             entry = by_model.setdefault(
                 c.model,
-                {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+                {
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0,
+                    "cost_micros": 0,
+                },
             )
             entry["prompt_tokens"] += c.prompt_tokens
             entry["completion_tokens"] += c.completion_tokens
             entry["total_tokens"] += c.total_tokens
+            entry["cost_micros"] += c.cost_micros
         return by_model

     @property
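As a quick illustration of the new summary shape, two hypothetical calls against the same model (the model name and numbers are invented, not taken from this diff):

acc = TurnTokenAccumulator()
acc.add(model="gpt-4o", prompt_tokens=100, completion_tokens=20,
        total_tokens=120, cost_micros=450)
acc.add(model="gpt-4o", prompt_tokens=50, completion_tokens=10,
        total_tokens=60, cost_micros=225)
assert acc.per_message_summary() == {
    "gpt-4o": {"prompt_tokens": 150, "completion_tokens": 30,
               "total_tokens": 180, "cost_micros": 675}
}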
@@ -84,6 +99,21 @@ class TurnTokenAccumulator:
     def total_completion_tokens(self) -> int:
         return sum(c.completion_tokens for c in self.calls)

+    @property
+    def total_cost_micros(self) -> int:
+        """Sum of per-call ``cost_micros`` across the entire turn.
+
+        Used by ``stream_new_chat`` to debit a premium turn's actual
+        provider cost (in micro-USD) from the user's premium credit
+        balance. ``cost_micros`` per call is captured by
+        ``TokenTrackingCallback.async_log_success_event`` from
+        ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
+        with multiple fallback paths so OpenRouter dynamic models and
+        custom Azure deployments still bill correctly, provided our
+        ``pricing_registration`` ran at startup.
+        """
+        return sum(c.cost_micros for c in self.calls)
+
+    def serialized_calls(self) -> list[dict[str, Any]]:
+        return [dataclasses.asdict(c) for c in self.calls]

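The docstring describes how `stream_new_chat` consumes this property; a hedged sketch of that debit step (the balance column and audit field names are assumptions, not shown in this diff):

# After the turn completes, debit the actual provider cost once.
acc = get_current_accumulator()
if acc is not None:
    user.premium_credit_micros -= acc.total_cost_micros        # assumed column name
    usage_row.call_details = {"calls": acc.serialized_calls()}  # JSON-safe audit trail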
@@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(


 def start_turn() -> TurnTokenAccumulator:
-    """Create a fresh accumulator for the current async context and return it."""
+    """Create a fresh accumulator for the current async context and return it.
+
+    NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
+    short-lived per-call billable wrappers (image generation REST endpoint,
+    vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
+    ContextVar reset token to restore the *previous* accumulator on exit and
+    avoids leaking call records across reservations (issue B).
+    """
     acc = TurnTokenAccumulator()
     _turn_accumulator.set(acc)
     logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
     return _turn_accumulator.get()


+@asynccontextmanager
+async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
+    """Async context manager that scopes a fresh ``TurnTokenAccumulator``
+    for the duration of the ``async with`` block, then *resets* the
+    ContextVar to its previous value on exit.
+
+    This is the safe primitive for per-call billable operations
+    (image generation, vision LLM extraction, podcasts) that may run
+    inside an outer chat turn or be called sequentially from the same
+    background worker. Using ``ContextVar.set`` without ``reset`` (as
+    :func:`start_turn` does) would leak the inner accumulator into the
+    outer scope, causing the outer chat turn to debit cost twice.
+
+    Usage::
+
+        async with scoped_turn() as acc:
+            await llm.ainvoke(...)
+            # acc.total_cost_micros captures cost from the LiteLLM callback
+        # Outer accumulator (if any) is restored here.
+    """
+    acc = TurnTokenAccumulator()
+    token = _turn_accumulator.set(acc)
+    logger.debug(
+        "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
+        id(acc),
+        token,
+    )
+    try:
+        yield acc
+    finally:
+        _turn_accumulator.reset(token)
+        logger.debug(
+            "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
+            id(acc),
+            len(acc.calls),
+            acc.total_cost_micros,
+        )
+
+
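The reset-token discipline is what prevents the double debit mentioned above; a small illustrative sketch of nesting (invented for this note, not part of the diff):

async def handle_turn() -> None:
    outer = start_turn()                 # long-lived chat-turn accumulator
    async with scoped_turn() as inner:
        ...                              # billable calls land on `inner` only
    # ContextVar reset: the chat turn sees none of `inner`'s cost records.
    assert get_current_accumulator() is outer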
+def _extract_cost_usd(
+    kwargs: dict[str, Any],
+    response_obj: Any,
+    model: str,
+    prompt_tokens: int,
+    completion_tokens: int,
+    is_image: bool = False,
+) -> float:
+    """Best-effort USD cost extraction for a single LLM/image call.
+
+    Tries four sources in priority order and returns the first that
+    yields a positive number; returns 0.0 if all four fail (the call
+    will then debit nothing from the user's balance — fail-safe).
+
+    Sources:
+    1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
+       field, populated for ``Router.acompletion`` since PR #12500.
+    2. ``response_obj._hidden_params["response_cost"]`` — same value
+       exposed on the response itself.
+    3. ``litellm.completion_cost(completion_response=response_obj)``
+       — recompute from the response and LiteLLM's pricing table.
+    4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
+       — manual fallback for OpenRouter/custom-Azure models that
+       only resolve via aliases registered by
+       ``pricing_registration`` at startup. **Skipped for image
+       responses** — ``cost_per_token`` does not support ``ImageResponse``
+       and would raise; the cost map for image-gen lives in different
+       keys (``output_cost_per_image``) handled by ``completion_cost``.
+    """
+    cost = kwargs.get("response_cost")
+    if cost is not None:
+        try:
+            value = float(cost)
+        except (TypeError, ValueError):
+            value = 0.0
+        if value > 0:
+            return value
+
+    hidden = getattr(response_obj, "_hidden_params", None) or {}
+    if isinstance(hidden, dict):
+        cost = hidden.get("response_cost")
+        if cost is not None:
+            try:
+                value = float(cost)
+            except (TypeError, ValueError):
+                value = 0.0
+            if value > 0:
+                return value
+
+    try:
+        value = float(litellm.completion_cost(completion_response=response_obj))
+        if value > 0:
+            return value
+    except Exception as exc:
+        if is_image:
+            # Image-gen path: OpenRouter's image responses can omit
+            # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
+            # then *raises* (no cost map for OpenRouter image models).
+            # Bail out with a warning rather than falling through to
+            # cost_per_token (which is also incompatible with ImageResponse).
+            logger.warning(
+                "[TokenTracking] completion_cost failed for image model=%s "
+                "(provider may have omitted usage.cost). Debiting 0. "
+                "Cause: %s",
+                model,
+                exc,
+            )
+            return 0.0
+        logger.debug(
+            "[TokenTracking] completion_cost failed for model=%s: %s", model, exc
+        )
+
+    if is_image:
+        # Never call cost_per_token for ImageResponse — keys mismatch and
+        # the function is documented chat-only.
+        return 0.0
+
+    if model and (prompt_tokens > 0 or completion_tokens > 0):
+        try:
+            prompt_cost, completion_cost = litellm.cost_per_token(
+                model=model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
+            value = float(prompt_cost) + float(completion_cost)
+            if value > 0:
+                return value
+        except Exception as exc:
+            logger.debug(
+                "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
+            )
+
+    return 0.0
+
+
 class TokenTrackingCallback(CustomLogger):
     """LiteLLM callback that captures token usage into the turn accumulator."""

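Fallback source 4 only resolves models present in LiteLLM's pricing table. For custom deployments, prices can be registered at startup with `litellm.register_model`, which is presumably what `pricing_registration` does; the deployment alias and per-token prices below are invented for illustration:

import litellm

# Hypothetical alias for a custom Azure deployment (USD per token).
litellm.register_model(
    {
        "azure/my-gpt-4o-deployment": {
            "input_cost_per_token": 2.5e-06,
            "output_cost_per_token": 1.0e-05,
            "litellm_provider": "azure",
            "mode": "chat",
        }
    }
)
# cost_per_token can now resolve this alias in fallback path 4.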
@@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
             )
             return

+        # Detect image generation responses — they have a different usage
+        # shape (ImageUsage with input_tokens/output_tokens) and require a
+        # different cost-extraction path. We probe by class name to avoid a
+        # hard import dependency on litellm internals.
+        response_cls = type(response_obj).__name__
+        is_image = response_cls == "ImageResponse"
+
         usage = getattr(response_obj, "usage", None)
         if not usage:
             logger.debug(
@@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
             )
             return

-        prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
-        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
-        total_tokens = getattr(usage, "total_tokens", 0) or 0
+        if is_image:
+            # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
+            # (not prompt_tokens/completion_tokens). Several providers
+            # populate only one or neither (e.g. OpenRouter's gpt-image-1
+            # passes through `input_tokens` from the prompt but no
+            # completion); fall through gracefully to 0.
+            prompt_tokens = getattr(usage, "input_tokens", 0) or 0
+            completion_tokens = getattr(usage, "output_tokens", 0) or 0
+            total_tokens = (
+                getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
+            )
+            call_kind = "image_generation"
+        else:
+            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
+            total_tokens = getattr(usage, "total_tokens", 0) or 0
+            call_kind = "chat"

         model = kwargs.get("model", "unknown")

+        cost_usd = _extract_cost_usd(
+            kwargs=kwargs,
+            response_obj=response_obj,
+            model=model,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            is_image=is_image,
+        )
+        cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
+
+        if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+            logger.warning(
+                "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
+                "kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
+                "input_cost_per_token/output_cost_per_token (or rely on response_cost "
+                "for image generation).",
+                model,
+                prompt_tokens,
+                completion_tokens,
+                call_kind,
+            )
+
         acc.add(
             model=model,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
+            cost_micros=cost_micros,
+            call_kind=call_kind,
         )
         logger.info(
-            "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
+            "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
+            "cost=$%.6f (%d micros) (accumulator now has %d calls)",
             model,
+            call_kind,
             prompt_tokens,
             completion_tokens,
             total_tokens,
+            cost_usd,
+            cost_micros,
             len(acc.calls),
         )

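To exercise fallback source 2 in isolation, a toy stand-in for a response object is enough (a test sketch with invented values, not a real LiteLLM type):

class _StubResponse:
    # LiteLLM attaches the computed cost to responses via _hidden_params.
    _hidden_params = {"response_cost": 0.00042}

cost = _extract_cost_usd(
    kwargs={},                 # source 1 misses: no kwargs["response_cost"]
    response_obj=_StubResponse(),
    model="openrouter/some-dynamic-model",
    prompt_tokens=100,
    completion_tokens=20,
)
assert round(cost * 1_000_000) == 420  # 420 micro-USD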
@@ -168,6 +388,7 @@ async def record_token_usage(
     prompt_tokens: int = 0,
     completion_tokens: int = 0,
     total_tokens: int = 0,
+    cost_micros: int = 0,
     model_breakdown: dict[str, Any] | None = None,
     call_details: dict[str, Any] | None = None,
     thread_id: int | None = None,
@@ -185,6 +406,7 @@ async def record_token_usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
+            cost_micros=cost_micros,
             model_breakdown=model_breakdown,
             call_details=call_details,
             thread_id=thread_id,
@@ -194,11 +416,12 @@ async def record_token_usage(
         )
         session.add(record)
         logger.debug(
-            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
+            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
             usage_type,
             prompt_tokens,
             completion_tokens,
             total_tokens,
+            cost_micros,
         )
         return record
     except Exception:
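Putting it together, a hedged sketch of persisting a finished turn. Only the keyword names visible in the hunks are taken from the diff; `session`, the `usage_type` label, and the `total_prompt_tokens` helper are assumptions:

record = await record_token_usage(
    session=session,                          # AsyncSession from the caller
    usage_type="chat_turn",                   # assumed label
    prompt_tokens=acc.total_prompt_tokens,    # assumed sibling of total_completion_tokens
    completion_tokens=acc.total_completion_tokens,
    total_tokens=sum(c.total_tokens for c in acc.calls),
    cost_micros=acc.total_cost_micros,        # new in this commit
    model_breakdown=acc.per_message_summary(),
    call_details={"calls": acc.serialized_calls()},
    thread_id=thread_id,
)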