mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 13:22:41 +02:00
feat: unified credits and its cost calculations
This commit is contained in:
parent
451a98936e
commit
ae9d36d77f
61 changed files with 5835 additions and 272 deletions
|
|
@ -36,6 +36,11 @@ from app.schemas import (
|
|||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
billable_call,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
|
|
@ -92,6 +97,50 @@ def _build_model_string(
|
|||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
session: AsyncSession,
|
||||
config_id: int | None,
|
||||
search_space: SearchSpace,
|
||||
) -> tuple[str, str, int]:
|
||||
"""Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
|
||||
|
||||
The resolution mirrors ``_execute_image_generation``'s lookup tree but
|
||||
only extracts the fields needed for billing — we do this *before*
|
||||
``billable_call`` so the reservation is correctly sized for the
|
||||
config that will actually run, and so we don't open an
|
||||
``ImageGeneration`` row for a request that's about to 402.
|
||||
|
||||
User-owned (positive ID) BYOK configs are always free — they cost
|
||||
the user nothing on our side. Auto mode currently treats as free
|
||||
because the underlying router can dispatch to either premium or
|
||||
free YAML configs and we don't surface the resolved deployment up
|
||||
here yet. Bringing Auto under premium billing would require
|
||||
threading the chosen deployment back from ``ImageGenRouterService``.
|
||||
"""
|
||||
resolved_id = config_id
|
||||
if resolved_id is None:
|
||||
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
||||
if is_image_gen_auto_mode(resolved_id):
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
if resolved_id < 0:
|
||||
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
base_model = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg.get("model_name", ""),
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
reserve_micros = int(
|
||||
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
)
|
||||
return (billing_tier, base_model, reserve_micros)
|
||||
|
||||
# Positive ID = user-owned BYOK image-gen config — always free.
|
||||
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
session: AsyncSession,
|
||||
image_gen: ImageGeneration,
|
||||
|
|
@ -225,6 +274,9 @@ async def get_global_image_gen_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -241,6 +293,8 @@ async def get_global_image_gen_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -454,7 +508,26 @@ async def create_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
"""Create and execute an image generation request.
|
||||
|
||||
Premium configs are gated by the user's shared premium credit pool.
|
||||
The flow is:
|
||||
|
||||
1. Permission check + load the search space (cheap, no provider call).
|
||||
2. Resolve which config will run so we know its billing tier and the
|
||||
worst-case reservation size *before* opening any DB rows.
|
||||
3. Wrap the entire ImageGeneration row insert + provider call in
|
||||
``billable_call``. If quota is denied, ``billable_call`` raises
|
||||
``QuotaInsufficientError`` *before* we flush a row, which we
|
||||
translate to HTTP 402 (no orphaned rows on the user's account,
|
||||
no inserted error rows for "you ran out of credit").
|
||||
4. On success, the actual ``response_cost`` flows through the
|
||||
LiteLLM callback into the accumulator, and ``billable_call``
|
||||
finalizes the debit at exit. Inner ``try/except`` still catches
|
||||
provider errors and stores them on ``error_message`` (HTTP 200
|
||||
with ``error_message`` set is preserved for failed-but-not-quota
|
||||
scenarios — clients already know how to surface those).
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -471,33 +544,70 @@ async def create_image_generation(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
# propagates to the outer ``except QuotaInsufficientError`` handler
|
||||
# below as HTTP 402 — it is intentionally NOT swallowed into
|
||||
# ``error_message`` because that would (1) imply a successful row
|
||||
# exists when none does, and (2) return HTTP 200 to a client
|
||||
# whose request was actively *denied* (issue K).
|
||||
async with billable_call(
|
||||
user_id=search_space.user_id,
|
||||
search_space_id=data.search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=reserve_micros,
|
||||
usage_type="image_generation",
|
||||
call_details={"model": base_model, "prompt": data.prompt[:100]},
|
||||
):
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except QuotaInsufficientError as exc:
|
||||
# The user's premium credit pool is empty. No DB row is created
|
||||
# because ``billable_call`` denies before yielding (issue K).
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error_code": "premium_quota_exhausted",
|
||||
"usage_type": exc.usage_type,
|
||||
"used_micros": exc.used_micros,
|
||||
"limit_micros": exc.limit_micros,
|
||||
"remaining_micros": exc.remaining_micros,
|
||||
"message": (
|
||||
"Out of premium credits for image generation. "
|
||||
"Purchase additional credits or switch to a free model."
|
||||
),
|
||||
},
|
||||
) from exc
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -1366,7 +1366,11 @@ async def append_message(
|
|||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
# Persist token usage if provided (for assistant messages).
|
||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||
# forwarded by the FE through the appendMessage round-trip so
|
||||
# the historical TokenUsage row matches the credit debit applied
|
||||
# at finalize time.
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
await record_token_usage(
|
||||
|
|
@ -1377,6 +1381,7 @@ async def append_message(
|
|||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
|
|
|
|||
|
|
@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
|
|||
metadata = _get_metadata(checkout_session)
|
||||
user_id = metadata.get("user_id")
|
||||
quantity = int(metadata.get("quantity", "0"))
|
||||
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
|
||||
# Read the new metadata key first, fall back to the legacy one so
|
||||
# in-flight checkout sessions created before the cost-credits
|
||||
# release still fulfil correctly (the unit is numerically the
|
||||
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
|
||||
credit_micros_per_unit = int(
|
||||
metadata.get("credit_micros_per_unit")
|
||||
or metadata.get("tokens_per_unit", "0")
|
||||
)
|
||||
|
||||
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
|
||||
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
|
||||
logger.error(
|
||||
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
||||
checkout_session_id,
|
||||
|
|
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=quantity,
|
||||
tokens_granted=quantity * tokens_per_unit,
|
||||
credit_micros_granted=quantity * credit_micros_per_unit,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
|
|||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
)
|
||||
user.premium_tokens_limit = (
|
||||
max(user.premium_tokens_used, user.premium_tokens_limit)
|
||||
+ purchase.tokens_granted
|
||||
# Top up the user's credit balance by the granted micro-USD amount.
|
||||
# ``max(used, limit)`` clamps the case where the legacy code wrote a
|
||||
# used value above the limit (e.g. underbilling rounding) so adding
|
||||
# ``credit_micros_granted`` always lifts the limit by the full pack
|
||||
# size rather than disappearing into past overuse.
|
||||
user.premium_credit_micros_limit = (
|
||||
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
|
||||
+ purchase.credit_micros_granted
|
||||
)
|
||||
|
||||
await db_session.commit()
|
||||
|
|
@ -532,12 +544,18 @@ async def create_token_checkout_session(
|
|||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a Stripe Checkout Session for buying premium token packs."""
|
||||
"""Create a Stripe Checkout Session for buying premium credit packs.
|
||||
|
||||
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
|
||||
credit (default 1_000_000 = $1.00). The user's balance is debited
|
||||
at the actual provider cost reported by LiteLLM at finalize time,
|
||||
so $1 of credit always buys $1 worth of provider usage at cost.
|
||||
"""
|
||||
_ensure_token_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_token_price_id()
|
||||
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
||||
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
|
||||
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||
|
|
@ -556,8 +574,8 @@ async def create_token_checkout_session(
|
|||
"metadata": {
|
||||
"user_id": str(user.id),
|
||||
"quantity": str(body.quantity),
|
||||
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
|
||||
"purchase_type": "premium_tokens",
|
||||
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
|
||||
"purchase_type": "premium_credit",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -583,7 +601,7 @@ async def create_token_checkout_session(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=body.quantity,
|
||||
tokens_granted=tokens_granted,
|
||||
credit_micros_granted=credit_micros_granted,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -598,14 +616,19 @@ async def create_token_checkout_session(
|
|||
async def get_token_status(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return token-buying availability and current premium quota for frontend."""
|
||||
used = user.premium_tokens_used
|
||||
limit = user.premium_tokens_limit
|
||||
"""Return token-buying availability and current premium credit quota for frontend.
|
||||
|
||||
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
|
||||
when displaying. The route name is preserved for back-compat with
|
||||
pinned client deployments.
|
||||
"""
|
||||
used = user.premium_credit_micros_used
|
||||
limit = user.premium_credit_micros_limit
|
||||
return TokenStripeStatusResponse(
|
||||
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
||||
premium_tokens_used=used,
|
||||
premium_tokens_limit=limit,
|
||||
premium_tokens_remaining=max(0, limit - used),
|
||||
premium_credit_micros_used=used,
|
||||
premium_credit_micros_limit=limit,
|
||||
premium_credit_micros_remaining=max(0, limit - used),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,9 @@ async def get_global_vision_llm_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -98,6 +101,10 @@ async def get_global_vision_llm_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue