mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
feat(story-3.5): add cloud-mode LLM model selection with token quota enforcement
Implement system-managed model catalog, subscription tier enforcement, atomic token quota tracking, and frontend cloud/self-hosted conditional rendering. Apply all 20 BMAD code review patches including security fixes (cross-user API key hijack), race condition mitigation (atomic SQL UPDATE), and SSE mid-stream quota error handling. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
parent
e7382b26de
commit
c1776b3ec8
19 changed files with 1003 additions and 34 deletions
|
|
@ -52,6 +52,9 @@ global_llm_configs:
|
|||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
tier_required: "pro" # free | pro | enterprise
|
||||
cost_per_1k_input_tokens: 0.01
|
||||
cost_per_1k_output_tokens: 0.03
|
||||
# Rate limits for load balancing (requests/tokens per minute)
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
|
|
@ -71,6 +74,9 @@ global_llm_configs:
|
|||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
tier_required: "pro"
|
||||
cost_per_1k_input_tokens: 0.015
|
||||
cost_per_1k_output_tokens: 0.075
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
|
|
@ -88,6 +94,9 @@ global_llm_configs:
|
|||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
tier_required: "free"
|
||||
cost_per_1k_input_tokens: 0.0005
|
||||
cost_per_1k_output_tokens: 0.0015
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
|
|
@ -105,6 +114,9 @@ global_llm_configs:
|
|||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
api_base: "https://api.deepseek.com/v1"
|
||||
tier_required: "free"
|
||||
cost_per_1k_input_tokens: 0.0001
|
||||
cost_per_1k_output_tokens: 0.0002
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
|
|
@ -134,6 +146,9 @@ global_llm_configs:
|
|||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview" # Azure API version
|
||||
tier_required: "pro"
|
||||
cost_per_1k_input_tokens: 0.005
|
||||
cost_per_1k_output_tokens: 0.015
|
||||
rpm: 1000
|
||||
tpm: 150000
|
||||
litellm_params:
|
||||
|
|
@ -156,6 +171,9 @@ global_llm_configs:
|
|||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
tier_required: "pro"
|
||||
cost_per_1k_input_tokens: 0.01
|
||||
cost_per_1k_output_tokens: 0.03
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
|
|
@ -174,6 +192,9 @@ global_llm_configs:
|
|||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
tier_required: "pro"
|
||||
cost_per_1k_input_tokens: 0.00059
|
||||
cost_per_1k_output_tokens: 0.00079
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
tpm: 14400
|
||||
litellm_params:
|
||||
|
|
@ -191,6 +212,9 @@ global_llm_configs:
|
|||
model_name: "MiniMax-M2.5"
|
||||
api_key: "your-minimax-api-key-here"
|
||||
api_base: "https://api.minimax.io/v1"
|
||||
tier_required: "free"
|
||||
cost_per_1k_input_tokens: 0.001
|
||||
cost_per_1k_output_tokens: 0.003
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
|
|
@ -347,6 +371,10 @@ global_vision_llm_configs:
|
|||
# - system_instructions: Custom prompt or empty string to use defaults
|
||||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
||||
# - tier_required: "free" | "pro" | "enterprise" — subscription tier needed to use this model.
|
||||
# If omitted, tier is inferred from model_name via pattern matching (fragile).
|
||||
# - cost_per_1k_input_tokens / cost_per_1k_output_tokens: Optional cost metadata for display.
|
||||
# Not used for billing (token quota is flat), but shown in the UI for transparency.
|
||||
# - All standard LiteLLM providers are supported
|
||||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from sqlalchemy import (
|
|||
TIMESTAMP,
|
||||
Boolean,
|
||||
Column,
|
||||
Date,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
|
|
@ -320,6 +321,13 @@ class PagePurchaseStatus(StrEnum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
class SubscriptionStatus(StrEnum):
|
||||
FREE = "free"
|
||||
ACTIVE = "active"
|
||||
CANCELED = "canceled"
|
||||
PAST_DUE = "past_due"
|
||||
|
||||
|
||||
# Centralized configuration for incentive tasks
|
||||
# This makes it easy to add new tasks without changing code in multiple places
|
||||
INCENTIVE_TASKS_CONFIG = {
|
||||
|
|
@ -1955,6 +1963,20 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
# Subscription and token quota (cloud mode)
|
||||
monthly_token_limit = Column(Integer, nullable=False, default=100000, server_default="100000")
|
||||
tokens_used_this_month = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
token_reset_date = Column(Date, nullable=True)
|
||||
subscription_status = Column(
|
||||
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True),
|
||||
nullable=False,
|
||||
default=SubscriptionStatus.FREE,
|
||||
server_default="free",
|
||||
)
|
||||
plan_id = Column(String(50), nullable=False, default="free", server_default="free")
|
||||
stripe_customer_id = Column(String(255), nullable=True, unique=True)
|
||||
stripe_subscription_id = Column(String(255), nullable=True, unique=True)
|
||||
|
||||
# User profile from OAuth
|
||||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
|
@ -2069,6 +2091,20 @@ else:
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
# Subscription and token quota (cloud mode)
|
||||
monthly_token_limit = Column(Integer, nullable=False, default=100000, server_default="100000")
|
||||
tokens_used_this_month = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
token_reset_date = Column(Date, nullable=True)
|
||||
subscription_status = Column(
|
||||
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True),
|
||||
nullable=False,
|
||||
default=SubscriptionStatus.FREE,
|
||||
server_default="free",
|
||||
)
|
||||
plan_id = Column(String(50), nullable=False, default="free", server_default="free")
|
||||
stripe_customer_id = Column(String(255), nullable=True, unique=True)
|
||||
stripe_subscription_id = Column(String(255), nullable=True, unique=True)
|
||||
|
||||
# User profile (can be set manually for non-OAuth users)
|
||||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ API route for fetching the available models catalogue.
|
|||
|
||||
Serves a dynamically-updated list sourced from the OpenRouter public API,
|
||||
with a local JSON fallback when the API is unreachable.
|
||||
|
||||
Also exposes a /models/system endpoint that returns the system-managed models
|
||||
from global_llm_config.yaml for use in cloud/hosted mode (no BYOK).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -10,6 +13,7 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.services.model_list_service import get_model_list
|
||||
from app.users import current_active_user
|
||||
|
|
@ -25,12 +29,81 @@ class ModelListItem(BaseModel):
|
|||
context_window: str | None = None
|
||||
|
||||
|
||||
class SystemModelItem(BaseModel):
    """A system-managed model available in cloud mode."""

    # Negative ID from global_llm_config.yaml (e.g. -1, -2); chat routes
    # reject positive (user-owned) IDs when running in cloud mode.
    id: int
    # Display name shown in the model selector (falls back to model_name
    # when the YAML entry has no "name" key).
    name: str
    # Optional free-text description from the YAML entry, if any.
    description: str | None = None
    # Provider key from the YAML entry (e.g. "OPENAI") — presumably a
    # LiteLLM provider identifier; confirm against global_llm_config.yaml.
    provider: str
    # Provider-specific model identifier (e.g. "gpt-4-turbo-preview").
    model_name: str
    # Subscription tier needed to use this model: "free" | "pro" | "enterprise".
    tier_required: str = "free"
|
||||
|
||||
|
||||
def _get_tier_for_model(provider: str, model_name: str) -> str:
|
||||
"""
|
||||
Derive the subscription tier required to use a given model.
|
||||
|
||||
Rules (adjust as pricing plans are defined):
|
||||
- GPT-4 class, Claude 3 Opus, Gemini Ultra → pro
|
||||
- Everything else → free
|
||||
"""
|
||||
model_lower = model_name.lower()
|
||||
|
||||
# Pro-tier models: high-capability / expensive models
|
||||
pro_patterns = [
|
||||
"gpt-4",
|
||||
"claude-3-opus",
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-7-sonnet",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-pro",
|
||||
"gemini-2.5-pro",
|
||||
"llama3-70b",
|
||||
"llama-3-70b",
|
||||
"mistral-large",
|
||||
]
|
||||
for pattern in pro_patterns:
|
||||
if pattern in model_lower:
|
||||
return "pro"
|
||||
|
||||
return "free"
|
||||
|
||||
|
||||
def get_tier_for_model_id(model_id: int) -> str:
    """
    Look up the tier_required for a given system model ID.

    Used by chat routes to enforce tier at request time. An explicit
    ``tier_required`` in the YAML entry takes precedence; otherwise the
    tier is inferred from the model name via ``_get_tier_for_model``.

    Returns:
        The tier string ("free", "pro", "enterprise"), or "free" when the
        ID is unknown or no system configs are loaded.
    """
    # Find the first YAML entry matching this ID (mirrors the original
    # first-match semantics).
    entry = next(
        (cfg for cfg in (config.GLOBAL_LLM_CONFIGS or []) if cfg.get("id") == model_id),
        None,
    )
    if entry is None:
        return "free"

    declared = entry.get("tier_required")
    if declared:
        # Explicit tier from YAML wins over pattern-based inference.
        return str(declared).lower()

    # No explicit tier — fall back to fragile name-pattern inference.
    return _get_tier_for_model(
        str(entry.get("provider", "UNKNOWN")),
        str(entry.get("model_name", "")),
    )
|
||||
|
||||
|
||||
@router.get("/models", response_model=list[ModelListItem])
|
||||
async def list_available_models(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Return all available models grouped by provider.
|
||||
Return all available models grouped by provider (BYOK / self-hosted mode).
|
||||
|
||||
The list is sourced from the OpenRouter public API and cached for 1 hour.
|
||||
If the API is unreachable, a local fallback file is used instead.
|
||||
|
|
@ -42,3 +115,51 @@ async def list_available_models(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch model list: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/models/system", response_model=list[SystemModelItem])
async def list_system_models(
    user: User = Depends(current_active_user),
):
    """
    Return system-managed models from global_llm_config.yaml (cloud mode).

    Each entry is annotated with a `tier_required` field so the frontend
    can show which models require a paid subscription plan. The caller's
    current subscription status is NOT checked here — enforcement happens
    at chat time.

    Raises:
        HTTPException: 404 when not running in cloud mode.
    """
    if not config.is_cloud():
        raise HTTPException(
            status_code=404,
            detail="System models are only available in cloud mode.",
        )

    models: list[SystemModelItem] = []
    for entry in config.GLOBAL_LLM_CONFIGS or []:
        entry_id = entry.get("id")
        # System models use negative IDs: skip auto-mode (0), any
        # mistakenly positive IDs, and entries without an ID at all.
        if entry_id is None or entry_id >= 0:
            continue

        provider = str(entry.get("provider", "UNKNOWN"))
        model_name = str(entry.get("model_name", ""))

        # Explicit tier from YAML wins; otherwise infer from the name.
        explicit_tier = entry.get("tier_required")
        if explicit_tier:
            tier = str(explicit_tier).lower()
        else:
            tier = _get_tier_for_model(provider, model_name)

        models.append(
            SystemModelItem(
                id=entry_id,
                name=str(entry.get("name", model_name)),
                description=entry.get("description"),
                provider=provider,
                model_name=model_name,
                tier_required=tier,
            )
        )

    return models
|
||||
|
|
|
|||
|
|
@ -51,6 +51,9 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from app.config import config
|
||||
from app.routes.model_list_routes import get_tier_for_model_id
|
||||
from app.services.token_quota_service import TokenQuotaExceededError, TokenQuotaService
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -1112,6 +1115,47 @@ async def handle_new_chat(
|
|||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Cloud mode: allow frontend to override with a system model selection
|
||||
# Security: only negative IDs (system models from YAML) are allowed in cloud mode
|
||||
if config.is_cloud() and request.model_id is not None:
|
||||
if request.model_id > 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Custom LLM configurations are not allowed in cloud mode. Use system models only.",
|
||||
)
|
||||
llm_config_id = request.model_id
|
||||
|
||||
# Enforce subscription tier for the selected model
|
||||
required_tier = get_tier_for_model_id(request.model_id)
|
||||
if required_tier == "pro" and hasattr(user, "subscription_status"):
|
||||
user_status = getattr(user, "subscription_status", None)
|
||||
if user_status is None or str(user_status) not in ("active",):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "tier_restricted",
|
||||
"message": f"This model requires a Pro subscription. Current status: {user_status}",
|
||||
"required_tier": required_tier,
|
||||
},
|
||||
)
|
||||
|
||||
# Cloud mode: enforce monthly token quota before streaming
|
||||
if config.is_cloud():
|
||||
try:
|
||||
token_quota_service = TokenQuotaService(session)
|
||||
await token_quota_service.check_token_quota(str(user.id))
|
||||
except TokenQuotaExceededError as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "token_quota_exceeded",
|
||||
"message": str(exc),
|
||||
"tokens_used": exc.tokens_used,
|
||||
"monthly_token_limit": exc.monthly_token_limit,
|
||||
"upgrade_url": "/pricing",
|
||||
},
|
||||
) from exc
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs usable.
|
||||
|
|
@ -1349,6 +1393,47 @@ async def regenerate_response(
|
|||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Cloud mode: allow frontend to override with a system model selection
|
||||
# Security: only negative IDs (system models from YAML) are allowed in cloud mode
|
||||
if config.is_cloud() and request.model_id is not None:
|
||||
if request.model_id > 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Custom LLM configurations are not allowed in cloud mode. Use system models only.",
|
||||
)
|
||||
llm_config_id = request.model_id
|
||||
|
||||
# Enforce subscription tier for the selected model
|
||||
required_tier = get_tier_for_model_id(request.model_id)
|
||||
if required_tier == "pro" and hasattr(user, "subscription_status"):
|
||||
user_status = getattr(user, "subscription_status", None)
|
||||
if user_status is None or str(user_status) not in ("active",):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "tier_restricted",
|
||||
"message": f"This model requires a Pro subscription. Current status: {user_status}",
|
||||
"required_tier": required_tier,
|
||||
},
|
||||
)
|
||||
|
||||
# Cloud mode: enforce monthly token quota before streaming
|
||||
if config.is_cloud():
|
||||
try:
|
||||
token_quota_service = TokenQuotaService(session)
|
||||
await token_quota_service.check_token_quota(str(user.id))
|
||||
except TokenQuotaExceededError as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "token_quota_exceeded",
|
||||
"message": str(exc),
|
||||
"tokens_used": exc.tokens_used,
|
||||
"monthly_token_limit": exc.monthly_token_limit,
|
||||
"upgrade_url": "/pricing",
|
||||
},
|
||||
) from exc
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
||||
|
|
@ -1472,6 +1557,47 @@ async def resume_chat(
|
|||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Cloud mode: allow frontend to override with a system model selection
|
||||
# Security: only negative IDs (system models from YAML) are allowed in cloud mode
|
||||
if config.is_cloud() and request.model_id is not None:
|
||||
if request.model_id > 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Custom LLM configurations are not allowed in cloud mode. Use system models only.",
|
||||
)
|
||||
llm_config_id = request.model_id
|
||||
|
||||
# Enforce subscription tier for the selected model
|
||||
required_tier = get_tier_for_model_id(request.model_id)
|
||||
if required_tier == "pro" and hasattr(user, "subscription_status"):
|
||||
user_status = getattr(user, "subscription_status", None)
|
||||
if user_status is None or str(user_status) not in ("active",):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "tier_restricted",
|
||||
"message": f"This model requires a Pro subscription. Current status: {user_status}",
|
||||
"required_tier": required_tier,
|
||||
},
|
||||
)
|
||||
|
||||
# Cloud mode: enforce monthly token quota before streaming
|
||||
if config.is_cloud():
|
||||
try:
|
||||
token_quota_service = TokenQuotaService(session)
|
||||
await token_quota_service.check_token_quota(str(user.id))
|
||||
except TokenQuotaExceededError as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "token_quota_exceeded",
|
||||
"message": str(exc),
|
||||
"tokens_used": exc.tokens_used,
|
||||
"monthly_token_limit": exc.monthly_token_limit,
|
||||
"upgrade_url": "/pricing",
|
||||
},
|
||||
) from exc
|
||||
|
||||
decisions = [d.model_dump() for d in request.decisions]
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
|
|
|
|||
|
|
@ -175,6 +175,10 @@ class NewChatRequest(BaseModel):
|
|||
disabled_tools: list[str] | None = (
|
||||
None # Optional list of tool names the user has disabled from the UI
|
||||
)
|
||||
# Cloud mode: override the search space's agent_llm_id with a system model
|
||||
# (negative ID from global_llm_config.yaml, selected via SystemModelSelector).
|
||||
# Self-hosted mode: leave None and the search space config is used as before.
|
||||
model_id: int | None = None
|
||||
|
||||
|
||||
class RegenerateRequest(BaseModel):
|
||||
|
|
@ -195,6 +199,7 @@ class RegenerateRequest(BaseModel):
|
|||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
disabled_tools: list[str] | None = None
|
||||
model_id: int | None = None # Cloud mode: override with system model ID
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -218,6 +223,7 @@ class ResumeDecision(BaseModel):
|
|||
class ResumeRequest(BaseModel):
|
||||
search_space_id: int
|
||||
decisions: list[ResumeDecision]
|
||||
model_id: int | None = None # Cloud mode: override with system model ID
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
189
surfsense_backend/app/services/token_quota_service.py
Normal file
189
surfsense_backend/app/services/token_quota_service.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Service for managing user LLM token quotas (cloud subscription mode).
|
||||
|
||||
Mirrors PageLimitService pattern for consistency.
|
||||
"""
|
||||
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class TokenQuotaExceededError(Exception):
    """
    Raised when a user exceeds their monthly token quota.

    Carries the usage numbers so API handlers can build a structured
    error payload (tokens used, monthly limit, tokens requested).
    """

    def __init__(
        self,
        message: str = "Monthly token quota exceeded. Please upgrade your plan.",
        tokens_used: int = 0,
        monthly_token_limit: int = 0,
        tokens_requested: int = 0,
    ):
        super().__init__(message)
        # Expose quota details for structured error responses.
        self.tokens_used = tokens_used
        self.monthly_token_limit = monthly_token_limit
        self.tokens_requested = tokens_requested
|
||||
|
||||
|
||||
class TokenQuotaService:
|
||||
"""Service for checking and updating user LLM token quotas."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def _maybe_reset_monthly_tokens(self, user) -> None:
|
||||
"""
|
||||
Reset tokens_used_this_month to 0 if token_reset_date has passed.
|
||||
|
||||
Called before any quota check or update so that a new billing cycle
|
||||
starts transparently without requiring a cron job or webhook trigger.
|
||||
|
||||
The token_reset_date is a Date column. We compare against UTC today.
|
||||
|
||||
NOTE: This method does NOT commit — the caller manages the transaction.
|
||||
"""
|
||||
today = datetime.now(UTC).date()
|
||||
|
||||
if not user.token_reset_date:
|
||||
# First time — set reset date 30 days from now
|
||||
user.token_reset_date = today + timedelta(days=30)
|
||||
user.tokens_used_this_month = 0
|
||||
return
|
||||
|
||||
reset_date = user.token_reset_date
|
||||
# Handle if somehow stored as a string (legacy data)
|
||||
if isinstance(reset_date, str):
|
||||
try:
|
||||
reset_date = date.fromisoformat(reset_date)
|
||||
except ValueError:
|
||||
reset_date = today + timedelta(days=30)
|
||||
|
||||
if today >= reset_date:
|
||||
# New billing cycle — reset usage and advance reset date by 30 days
|
||||
new_reset = reset_date + timedelta(days=30)
|
||||
user.tokens_used_this_month = 0
|
||||
user.token_reset_date = new_reset
|
||||
|
||||
async def check_token_quota(
|
||||
self, user_id: str, estimated_tokens: int = 0
|
||||
) -> tuple[bool, int, int]:
|
||||
"""
|
||||
Check if user has remaining token quota this month.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID (string)
|
||||
estimated_tokens: Optional pre-estimated input token count
|
||||
|
||||
Returns:
|
||||
Tuple of (has_capacity, tokens_used, monthly_token_limit)
|
||||
|
||||
Raises:
|
||||
TokenQuotaExceededError: If user would exceed their monthly limit
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
result = await self.session.execute(select(User).where(User.id == user_id))
|
||||
user = result.unique().scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
||||
await self._maybe_reset_monthly_tokens(user)
|
||||
await self.session.flush() # Persist any reset changes within the transaction
|
||||
|
||||
tokens_used = user.tokens_used_this_month or 0
|
||||
token_limit = user.monthly_token_limit or 0
|
||||
|
||||
# Strict boundary: >= means at-limit is also exceeded
|
||||
if tokens_used + estimated_tokens >= token_limit and token_limit > 0:
|
||||
raise TokenQuotaExceededError(
|
||||
message=(
|
||||
f"Monthly token quota exceeded. "
|
||||
f"Used: {tokens_used:,}/{token_limit:,} tokens. "
|
||||
f"Estimated request: {estimated_tokens:,} tokens. "
|
||||
f"Please upgrade your subscription plan."
|
||||
),
|
||||
tokens_used=tokens_used,
|
||||
monthly_token_limit=token_limit,
|
||||
tokens_requested=estimated_tokens,
|
||||
)
|
||||
|
||||
return True, tokens_used, token_limit
|
||||
|
||||
async def update_token_usage(
|
||||
self, user_id: str, tokens_to_add: int, allow_exceed: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Atomically add tokens consumed to the user's monthly usage.
|
||||
|
||||
Uses a single SQL UPDATE with arithmetic expression to prevent
|
||||
race conditions when multiple streams finish concurrently.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID (string)
|
||||
tokens_to_add: Actual tokens consumed (input + output)
|
||||
allow_exceed: If True (default), records usage even if it pushes
|
||||
past the limit. Set False to enforce hard cap at
|
||||
update time (pre-check should already have fired).
|
||||
|
||||
Returns:
|
||||
New total tokens_used_this_month value
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
if tokens_to_add <= 0:
|
||||
# Nothing to deduct — fetch current usage and return
|
||||
result = await self.session.execute(
|
||||
select(User.tokens_used_this_month).where(User.id == user_id)
|
||||
)
|
||||
row = result.first()
|
||||
if not row:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
return row[0] or 0
|
||||
|
||||
# Atomic UPDATE: tokens_used = tokens_used + N (no read-modify-write)
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(tokens_used_this_month=User.tokens_used_this_month + tokens_to_add)
|
||||
.returning(User.tokens_used_this_month)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
||||
new_usage = row[0]
|
||||
await self.session.commit()
|
||||
|
||||
return new_usage
|
||||
|
||||
async def get_token_usage(self, user_id: str) -> tuple[int, int]:
|
||||
"""
|
||||
Get user's current token usage and monthly limit.
|
||||
|
||||
Also triggers monthly reset check so the returned values
|
||||
are always for the current billing cycle.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID (string)
|
||||
|
||||
Returns:
|
||||
Tuple of (tokens_used_this_month, monthly_token_limit)
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
result = await self.session.execute(select(User).where(User.id == user_id))
|
||||
user = result.unique().scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
||||
await self._maybe_reset_monthly_tokens(user)
|
||||
await self.session.flush()
|
||||
|
||||
return (user.tokens_used_this_month or 0, user.monthly_token_limit or 0)
|
||||
|
|
@ -41,6 +41,7 @@ from app.agents.new_chat.memory_extraction import (
|
|||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.config import config as app_config
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -144,6 +145,7 @@ class StreamResult:
|
|||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||
agent_called_update_memory: bool = False
|
||||
total_tokens_used: int = 0 # Accumulated across all LLM calls in the stream
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -1105,6 +1107,27 @@ async def _stream_agent_events(
|
|||
},
|
||||
)
|
||||
|
||||
elif event_type == "on_chat_model_end":
|
||||
# Accumulate token counts for quota tracking (cloud mode)
|
||||
output = event.get("data", {}).get("output")
|
||||
if output is not None:
|
||||
usage = None
|
||||
if hasattr(output, "usage_metadata") and output.usage_metadata is not None:
|
||||
usage = output.usage_metadata
|
||||
elif hasattr(output, "response_metadata") and output.response_metadata is not None:
|
||||
rm = output.response_metadata or {}
|
||||
usage = rm.get("usage") or rm.get("token_usage") or rm.get("usage_metadata")
|
||||
|
||||
if isinstance(usage, dict):
|
||||
total = (
|
||||
usage.get("total_tokens")
|
||||
or (usage.get("input_tokens", 0) + usage.get("output_tokens", 0))
|
||||
or (usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0))
|
||||
)
|
||||
result.total_tokens_used += total or 0
|
||||
elif usage is not None and hasattr(usage, "total_tokens"):
|
||||
result.total_tokens_used += getattr(usage, "total_tokens", 0) or 0
|
||||
|
||||
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
|
|
@ -1569,6 +1592,22 @@ async def stream_new_chat(
|
|||
)
|
||||
)
|
||||
|
||||
# Cloud mode: deduct consumed tokens from the user's monthly quota
|
||||
if app_config.is_cloud() and user_id and stream_result.total_tokens_used > 0:
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
quota_service = TokenQuotaService(quota_session)
|
||||
await quota_service.update_token_usage(
|
||||
user_id, stream_result.total_tokens_used, allow_exceed=True
|
||||
)
|
||||
except Exception as quota_err:
|
||||
# Non-fatal — log and continue; usage was already streamed
|
||||
logging.getLogger(__name__).warning(
|
||||
"[stream_new_chat] Failed to record token usage: %s", quota_err
|
||||
)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1778,6 +1817,22 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
||||
# Cloud mode: deduct consumed tokens from the user's monthly quota
|
||||
if app_config.is_cloud() and user_id and stream_result.total_tokens_used > 0:
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
quota_service = TokenQuotaService(quota_session)
|
||||
await quota_service.update_token_usage(
|
||||
user_id, stream_result.total_tokens_used, allow_exceed=True
|
||||
)
|
||||
except Exception as quota_err:
|
||||
# Non-fatal — log and continue; usage was already streamed
|
||||
logging.getLogger(__name__).warning(
|
||||
"[stream_resume_chat] Failed to record token usage: %s", quota_err
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue