mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
Implement system-managed model catalog, subscription tier enforcement, atomic token quota tracking, and frontend cloud/self-hosted conditional rendering. Apply all 20 BMAD code review patches including security fixes (cross-user API key hijack), race condition mitigation (atomic SQL UPDATE), and SSE mid-stream quota error handling. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
189 lines
6.5 KiB
Python
189 lines
6.5 KiB
Python
"""
|
|
Service for managing user LLM token quotas (cloud subscription mode).
|
|
|
|
Mirrors PageLimitService pattern for consistency.
|
|
"""
|
|
|
|
from datetime import UTC, date, datetime, timedelta
|
|
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
class TokenQuotaExceededError(Exception):
|
|
"""
|
|
Exception raised when a user exceeds their monthly token quota.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
message: str = "Monthly token quota exceeded. Please upgrade your plan.",
|
|
tokens_used: int = 0,
|
|
monthly_token_limit: int = 0,
|
|
tokens_requested: int = 0,
|
|
):
|
|
self.tokens_used = tokens_used
|
|
self.monthly_token_limit = monthly_token_limit
|
|
self.tokens_requested = tokens_requested
|
|
super().__init__(message)
|
|
|
|
|
|
class TokenQuotaService:
|
|
"""Service for checking and updating user LLM token quotas."""
|
|
|
|
def __init__(self, session: AsyncSession):
|
|
self.session = session
|
|
|
|
async def _maybe_reset_monthly_tokens(self, user) -> None:
|
|
"""
|
|
Reset tokens_used_this_month to 0 if token_reset_date has passed.
|
|
|
|
Called before any quota check or update so that a new billing cycle
|
|
starts transparently without requiring a cron job or webhook trigger.
|
|
|
|
The token_reset_date is a Date column. We compare against UTC today.
|
|
|
|
NOTE: This method does NOT commit — the caller manages the transaction.
|
|
"""
|
|
today = datetime.now(UTC).date()
|
|
|
|
if not user.token_reset_date:
|
|
# First time — set reset date 30 days from now
|
|
user.token_reset_date = today + timedelta(days=30)
|
|
user.tokens_used_this_month = 0
|
|
return
|
|
|
|
reset_date = user.token_reset_date
|
|
# Handle if somehow stored as a string (legacy data)
|
|
if isinstance(reset_date, str):
|
|
try:
|
|
reset_date = date.fromisoformat(reset_date)
|
|
except ValueError:
|
|
reset_date = today + timedelta(days=30)
|
|
|
|
if today >= reset_date:
|
|
# New billing cycle — reset usage and advance reset date by 30 days
|
|
new_reset = reset_date + timedelta(days=30)
|
|
user.tokens_used_this_month = 0
|
|
user.token_reset_date = new_reset
|
|
|
|
async def check_token_quota(
|
|
self, user_id: str, estimated_tokens: int = 0
|
|
) -> tuple[bool, int, int]:
|
|
"""
|
|
Check if user has remaining token quota this month.
|
|
|
|
Args:
|
|
user_id: The user's UUID (string)
|
|
estimated_tokens: Optional pre-estimated input token count
|
|
|
|
Returns:
|
|
Tuple of (has_capacity, tokens_used, monthly_token_limit)
|
|
|
|
Raises:
|
|
TokenQuotaExceededError: If user would exceed their monthly limit
|
|
"""
|
|
from app.db import User
|
|
|
|
result = await self.session.execute(select(User).where(User.id == user_id))
|
|
user = result.unique().scalar_one_or_none()
|
|
|
|
if not user:
|
|
raise ValueError(f"User with ID {user_id} not found")
|
|
|
|
await self._maybe_reset_monthly_tokens(user)
|
|
await self.session.flush() # Persist any reset changes within the transaction
|
|
|
|
tokens_used = user.tokens_used_this_month or 0
|
|
token_limit = user.monthly_token_limit or 0
|
|
|
|
# Strict boundary: >= means at-limit is also exceeded
|
|
if tokens_used + estimated_tokens >= token_limit and token_limit > 0:
|
|
raise TokenQuotaExceededError(
|
|
message=(
|
|
f"Monthly token quota exceeded. "
|
|
f"Used: {tokens_used:,}/{token_limit:,} tokens. "
|
|
f"Estimated request: {estimated_tokens:,} tokens. "
|
|
f"Please upgrade your subscription plan."
|
|
),
|
|
tokens_used=tokens_used,
|
|
monthly_token_limit=token_limit,
|
|
tokens_requested=estimated_tokens,
|
|
)
|
|
|
|
return True, tokens_used, token_limit
|
|
|
|
async def update_token_usage(
|
|
self, user_id: str, tokens_to_add: int, allow_exceed: bool = True
|
|
) -> int:
|
|
"""
|
|
Atomically add tokens consumed to the user's monthly usage.
|
|
|
|
Uses a single SQL UPDATE with arithmetic expression to prevent
|
|
race conditions when multiple streams finish concurrently.
|
|
|
|
Args:
|
|
user_id: The user's UUID (string)
|
|
tokens_to_add: Actual tokens consumed (input + output)
|
|
allow_exceed: If True (default), records usage even if it pushes
|
|
past the limit. Set False to enforce hard cap at
|
|
update time (pre-check should already have fired).
|
|
|
|
Returns:
|
|
New total tokens_used_this_month value
|
|
"""
|
|
from app.db import User
|
|
|
|
if tokens_to_add <= 0:
|
|
# Nothing to deduct — fetch current usage and return
|
|
result = await self.session.execute(
|
|
select(User.tokens_used_this_month).where(User.id == user_id)
|
|
)
|
|
row = result.first()
|
|
if not row:
|
|
raise ValueError(f"User with ID {user_id} not found")
|
|
return row[0] or 0
|
|
|
|
# Atomic UPDATE: tokens_used = tokens_used + N (no read-modify-write)
|
|
stmt = (
|
|
update(User)
|
|
.where(User.id == user_id)
|
|
.values(tokens_used_this_month=User.tokens_used_this_month + tokens_to_add)
|
|
.returning(User.tokens_used_this_month)
|
|
)
|
|
result = await self.session.execute(stmt)
|
|
row = result.first()
|
|
|
|
if not row:
|
|
raise ValueError(f"User with ID {user_id} not found")
|
|
|
|
new_usage = row[0]
|
|
await self.session.commit()
|
|
|
|
return new_usage
|
|
|
|
async def get_token_usage(self, user_id: str) -> tuple[int, int]:
|
|
"""
|
|
Get user's current token usage and monthly limit.
|
|
|
|
Also triggers monthly reset check so the returned values
|
|
are always for the current billing cycle.
|
|
|
|
Args:
|
|
user_id: The user's UUID (string)
|
|
|
|
Returns:
|
|
Tuple of (tokens_used_this_month, monthly_token_limit)
|
|
"""
|
|
from app.db import User
|
|
|
|
result = await self.session.execute(select(User).where(User.id == user_id))
|
|
user = result.unique().scalar_one_or_none()
|
|
|
|
if not user:
|
|
raise ValueError(f"User with ID {user_id} not found")
|
|
|
|
await self._maybe_reset_monthly_tokens(user)
|
|
await self.session.flush()
|
|
|
|
return (user.tokens_used_this_month or 0, user.monthly_token_limit or 0)
|