mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
feat(story-3.5): add cloud-mode LLM model selection with token quota enforcement
Implement system-managed model catalog, subscription tier enforcement, atomic token quota tracking, and frontend cloud/self-hosted conditional rendering. Apply all 20 BMAD code review patches including security fixes (cross-user API key hijack), race condition mitigation (atomic SQL UPDATE), and SSE mid-stream quota error handling. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
parent
e7382b26de
commit
c1776b3ec8
19 changed files with 1003 additions and 34 deletions
|
|
@ -3,6 +3,9 @@ API route for fetching the available models catalogue.
|
|||
|
||||
Serves a dynamically-updated list sourced from the OpenRouter public API,
|
||||
with a local JSON fallback when the API is unreachable.
|
||||
|
||||
Also exposes a /models/system endpoint that returns the system-managed models
|
||||
from global_llm_config.yaml for use in cloud/hosted mode (no BYOK).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -10,6 +13,7 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.services.model_list_service import get_model_list
|
||||
from app.users import current_active_user
|
||||
|
|
@ -25,12 +29,81 @@ class ModelListItem(BaseModel):
|
|||
context_window: str | None = None
|
||||
|
||||
|
||||
class SystemModelItem(BaseModel):
    """A system-managed model available in cloud mode.

    Entries originate from ``global_llm_config.yaml`` and carry negative IDs
    so they can never collide with user-created (positive-ID) LLM configs.
    """

    # Negative ID from global_llm_config.yaml (e.g. -1, -2)
    id: int
    name: str
    description: str | None = None
    provider: str
    model_name: str
    # Subscription tier needed to use the model: "free" | "pro" | "enterprise"
    tier_required: str = "free"
|
||||
|
||||
|
||||
def _get_tier_for_model(provider: str, model_name: str) -> str:
|
||||
"""
|
||||
Derive the subscription tier required to use a given model.
|
||||
|
||||
Rules (adjust as pricing plans are defined):
|
||||
- GPT-4 class, Claude 3 Opus, Gemini Ultra → pro
|
||||
- Everything else → free
|
||||
"""
|
||||
model_lower = model_name.lower()
|
||||
|
||||
# Pro-tier models: high-capability / expensive models
|
||||
pro_patterns = [
|
||||
"gpt-4",
|
||||
"claude-3-opus",
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-7-sonnet",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-pro",
|
||||
"gemini-2.5-pro",
|
||||
"llama3-70b",
|
||||
"llama-3-70b",
|
||||
"mistral-large",
|
||||
]
|
||||
for pattern in pro_patterns:
|
||||
if pattern in model_lower:
|
||||
return "pro"
|
||||
|
||||
return "free"
|
||||
|
||||
|
||||
def get_tier_for_model_id(model_id: int) -> str:
    """
    Look up the tier_required for a given system model ID.

    Used by chat routes to enforce tier at request time.
    Prefers explicit `tier_required` from YAML; falls back to pattern matching.

    Returns:
        The tier string ("free", "pro", "enterprise") or "free" if not found.
    """
    configs = config.GLOBAL_LLM_CONFIGS or []

    # First config whose ID matches wins; unknown IDs default to "free".
    matching = next((entry for entry in configs if entry.get("id") == model_id), None)
    if matching is None:
        return "free"

    # Prefer the explicit tier declared in the YAML config when present.
    yaml_tier = matching.get("tier_required")
    if yaml_tier:
        return str(yaml_tier).lower()

    # Fall back to pattern-based inference from provider/model name.
    return _get_tier_for_model(
        str(matching.get("provider", "UNKNOWN")),
        str(matching.get("model_name", "")),
    )
|
||||
|
||||
|
||||
@router.get("/models", response_model=list[ModelListItem])
|
||||
async def list_available_models(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Return all available models grouped by provider.
|
||||
Return all available models grouped by provider (BYOK / self-hosted mode).
|
||||
|
||||
The list is sourced from the OpenRouter public API and cached for 1 hour.
|
||||
If the API is unreachable, a local fallback file is used instead.
|
||||
|
|
@ -42,3 +115,51 @@ async def list_available_models(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch model list: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/models/system", response_model=list[SystemModelItem])
async def list_system_models(
    user: User = Depends(current_active_user),
):
    """
    Return system-managed models from global_llm_config.yaml (cloud mode).

    Models are annotated with a `tier_required` field so the frontend can
    show which models require a paid subscription plan. The caller's current
    subscription status is NOT checked here — enforcement happens at chat time.

    Only available in cloud mode.
    """
    if not config.is_cloud():
        raise HTTPException(
            status_code=404,
            detail="System models are only available in cloud mode.",
        )

    configs = config.GLOBAL_LLM_CONFIGS or []

    result: list[SystemModelItem] = []
    for entry in configs:
        entry_id = entry.get("id")
        # Only negative IDs are real system models: id 0 is auto-mode and
        # positive IDs belong to user-created configs — skip both.
        if entry_id is None or entry_id >= 0:
            continue

        provider_name = str(entry.get("provider", "UNKNOWN"))
        llm_name = str(entry.get("model_name", ""))

        # Explicit YAML tier wins; otherwise infer from the model name.
        declared_tier = entry.get("tier_required")
        if declared_tier:
            tier = str(declared_tier).lower()
        else:
            tier = _get_tier_for_model(provider_name, llm_name)

        result.append(
            SystemModelItem(
                id=entry_id,
                name=str(entry.get("name", llm_name)),
                description=entry.get("description"),
                provider=provider_name,
                model_name=llm_name,
                tier_required=tier,
            )
        )

    return result
|
||||
|
|
|
|||
|
|
@ -51,6 +51,9 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from app.config import config
|
||||
from app.routes.model_list_routes import get_tier_for_model_id
|
||||
from app.services.token_quota_service import TokenQuotaExceededError, TokenQuotaService
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -1112,6 +1115,47 @@ async def handle_new_chat(
|
|||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Cloud mode: allow frontend to override with a system model selection
|
||||
# Security: only negative IDs (system models from YAML) are allowed in cloud mode
|
||||
if config.is_cloud() and request.model_id is not None:
|
||||
if request.model_id > 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Custom LLM configurations are not allowed in cloud mode. Use system models only.",
|
||||
)
|
||||
llm_config_id = request.model_id
|
||||
|
||||
# Enforce subscription tier for the selected model
|
||||
required_tier = get_tier_for_model_id(request.model_id)
|
||||
if required_tier == "pro" and hasattr(user, "subscription_status"):
|
||||
user_status = getattr(user, "subscription_status", None)
|
||||
if user_status is None or str(user_status) not in ("active",):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "tier_restricted",
|
||||
"message": f"This model requires a Pro subscription. Current status: {user_status}",
|
||||
"required_tier": required_tier,
|
||||
},
|
||||
)
|
||||
|
||||
# Cloud mode: enforce monthly token quota before streaming
|
||||
if config.is_cloud():
|
||||
try:
|
||||
token_quota_service = TokenQuotaService(session)
|
||||
await token_quota_service.check_token_quota(str(user.id))
|
||||
except TokenQuotaExceededError as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "token_quota_exceeded",
|
||||
"message": str(exc),
|
||||
"tokens_used": exc.tokens_used,
|
||||
"monthly_token_limit": exc.monthly_token_limit,
|
||||
"upgrade_url": "/pricing",
|
||||
},
|
||||
) from exc
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs usable.
|
||||
|
|
@ -1349,6 +1393,47 @@ async def regenerate_response(
|
|||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Cloud mode: allow frontend to override with a system model selection
|
||||
# Security: only negative IDs (system models from YAML) are allowed in cloud mode
|
||||
if config.is_cloud() and request.model_id is not None:
|
||||
if request.model_id > 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Custom LLM configurations are not allowed in cloud mode. Use system models only.",
|
||||
)
|
||||
llm_config_id = request.model_id
|
||||
|
||||
# Enforce subscription tier for the selected model
|
||||
required_tier = get_tier_for_model_id(request.model_id)
|
||||
if required_tier == "pro" and hasattr(user, "subscription_status"):
|
||||
user_status = getattr(user, "subscription_status", None)
|
||||
if user_status is None or str(user_status) not in ("active",):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "tier_restricted",
|
||||
"message": f"This model requires a Pro subscription. Current status: {user_status}",
|
||||
"required_tier": required_tier,
|
||||
},
|
||||
)
|
||||
|
||||
# Cloud mode: enforce monthly token quota before streaming
|
||||
if config.is_cloud():
|
||||
try:
|
||||
token_quota_service = TokenQuotaService(session)
|
||||
await token_quota_service.check_token_quota(str(user.id))
|
||||
except TokenQuotaExceededError as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "token_quota_exceeded",
|
||||
"message": str(exc),
|
||||
"tokens_used": exc.tokens_used,
|
||||
"monthly_token_limit": exc.monthly_token_limit,
|
||||
"upgrade_url": "/pricing",
|
||||
},
|
||||
) from exc
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
||||
|
|
@ -1472,6 +1557,47 @@ async def resume_chat(
|
|||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Cloud mode: allow frontend to override with a system model selection
|
||||
# Security: only negative IDs (system models from YAML) are allowed in cloud mode
|
||||
if config.is_cloud() and request.model_id is not None:
|
||||
if request.model_id > 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Custom LLM configurations are not allowed in cloud mode. Use system models only.",
|
||||
)
|
||||
llm_config_id = request.model_id
|
||||
|
||||
# Enforce subscription tier for the selected model
|
||||
required_tier = get_tier_for_model_id(request.model_id)
|
||||
if required_tier == "pro" and hasattr(user, "subscription_status"):
|
||||
user_status = getattr(user, "subscription_status", None)
|
||||
if user_status is None or str(user_status) not in ("active",):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "tier_restricted",
|
||||
"message": f"This model requires a Pro subscription. Current status: {user_status}",
|
||||
"required_tier": required_tier,
|
||||
},
|
||||
)
|
||||
|
||||
# Cloud mode: enforce monthly token quota before streaming
|
||||
if config.is_cloud():
|
||||
try:
|
||||
token_quota_service = TokenQuotaService(session)
|
||||
await token_quota_service.check_token_quota(str(user.id))
|
||||
except TokenQuotaExceededError as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": "token_quota_exceeded",
|
||||
"message": str(exc),
|
||||
"tokens_used": exc.tokens_used,
|
||||
"monthly_token_limit": exc.monthly_token_limit,
|
||||
"upgrade_url": "/pricing",
|
||||
},
|
||||
) from exc
|
||||
|
||||
decisions = [d.model_dump() for d in request.decisions]
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue