mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
refactor(chat): add streaming/flows/shared/premium_quota.py
Centralizes the premium-credits lifecycle for chat turns: * needs_premium_quota: gate check (premium user + non-fallback config). * PremiumReservation: dataclass capturing reservation state + token totals. * reserve_premium / finalize_premium / release_premium: idempotent reservation, commit, and rollback used by the orchestrators. Add-only; legacy stream_new_chat.py keeps its inline quota handling until cutover.
This commit is contained in:
parent
e9a98ecafb
commit
40300d300a
1 changed files with 132 additions and 0 deletions
|
|
@ -0,0 +1,132 @@
|
|||
"""Premium credit (USD micro-units) reserve / finalize / release lifecycle.
|
||||
|
||||
Both ``stream_new_chat`` and ``stream_resume_chat`` reserve premium credits up
|
||||
front (so a single LLM call can't run away with the budget), then finalize the
|
||||
actual provider cost reported by LiteLLM when the turn completes successfully,
|
||||
or release the reservation on the cancellation / interrupted-without-finalize
|
||||
paths.
|
||||
|
||||
State is held by the orchestrator as a simple ``PremiumReservation`` tuple
|
||||
so reservation, fallback-on-denied, finalize, and release can all be reasoned
|
||||
about from one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid as _uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.db import shielded_async_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
|
||||
@dataclass
|
||||
class PremiumReservation:
|
||||
"""Active premium-credit reservation for one turn.
|
||||
|
||||
``request_id`` is the per-reservation idempotency key (also passed to
|
||||
``finalize``/``release`` so racing branches resolve to the same row).
|
||||
``reserved_micros`` is the up-front estimate; ``finalize`` debits the
|
||||
actual cost, ``release`` returns it untouched.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
reserved_micros: int
|
||||
allowed: bool
|
||||
|
||||
|
||||
def needs_premium_quota(
|
||||
agent_config: AgentConfig | None, user_id: str | None
|
||||
) -> bool:
|
||||
return bool(agent_config is not None and user_id and agent_config.is_premium)
|
||||
|
||||
|
||||
async def reserve_premium(
|
||||
*,
|
||||
agent_config: AgentConfig,
|
||||
user_id: str,
|
||||
) -> PremiumReservation:
|
||||
"""Reserve estimated micros up front; returns the reservation handle."""
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
|
||||
request_id = _uuid.uuid4().hex[:16]
|
||||
litellm_params = agent_config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
) or agent_config.model_name or ""
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=request_id,
|
||||
reserve_micros=reserve_amount_micros,
|
||||
)
|
||||
return PremiumReservation(
|
||||
request_id=request_id,
|
||||
reserved_micros=reserve_amount_micros,
|
||||
allowed=quota_result.allowed,
|
||||
)
|
||||
|
||||
|
||||
async def finalize_premium(
|
||||
*,
|
||||
reservation: PremiumReservation,
|
||||
user_id: str,
|
||||
accumulator: TokenAccumulator,
|
||||
) -> None:
|
||||
"""Finalize debit using the actual provider cost reported by LiteLLM.
|
||||
|
||||
Best-effort: failures here must not bubble up to the SSE stream — the user
|
||||
has already received their tokens; we log and move on.
|
||||
"""
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=reservation.request_id,
|
||||
actual_micros=accumulator.total_cost_micros,
|
||||
reserved_micros=reservation.reserved_micros,
|
||||
)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def release_premium(
|
||||
*,
|
||||
reservation: PremiumReservation,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
"""Release the reservation on cancellation paths; never raises."""
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_micros=reservation.reserved_micros,
|
||||
)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s", user_id
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue