feat: no-login chat experience and premium token quotas
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-15 17:02:00 -07:00
parent 87452bb315
commit ff4e0f9b62
68 changed files with 5914 additions and 121 deletions

View file

@ -161,6 +161,7 @@ async def create_surfsense_deep_agent(
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
@ -463,6 +464,7 @@ async def create_surfsense_deep_agent(
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
anon_session_id=anon_session_id,
),
SurfSenseFilesystemMiddleware(
search_space_id=search_space_id,

View file

@ -109,6 +109,12 @@ class AgentConfig:
# Auto mode flag
is_auto_mode: bool = False
# Token quota and policy
billing_tier: str = "free"
is_premium: bool = False
anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None
@classmethod
def from_auto_mode(cls) -> "AgentConfig":
"""
@ -130,6 +136,10 @@ class AgentConfig:
config_id=AUTO_MODE_ID,
config_name="Auto (Fastest)",
is_auto_mode=True,
billing_tier="free",
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
)
@classmethod
@ -158,6 +168,10 @@ class AgentConfig:
config_id=config.id,
config_name=config.name,
is_auto_mode=False,
billing_tier="free",
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
)
@classmethod
@ -195,6 +209,10 @@ class AgentConfig:
config_id=yaml_config.get("id"),
config_name=yaml_config.get("name"),
is_auto_mode=False,
billing_tier=yaml_config.get("billing_tier", "free"),
is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
)

View file

@ -819,6 +819,34 @@ async def build_scoped_filesystem(
return files, doc_id_to_path
def _build_anon_scoped_filesystem(
    documents: Sequence[dict[str, Any]],
) -> dict[str, dict[str, Any]]:
    """Build a scoped filesystem for anonymous documents without DB queries.

    Anonymous uploads have no folders, so all files go under /documents.

    Args:
        documents: Search-result-shaped dicts. Each may carry a nested
            ``"document"`` dict (with ``title``/``id``) and an optional
            ``"matched_chunk_ids"`` list.

    Returns:
        Mapping of ``/documents/<name>`` paths to file-entry dicts. Note:
        the ``"content"`` value is a *list of lines* (``str.split("\n")``),
        while the other entry fields are strings — hence the value type is
        ``dict[str, Any]``, not ``dict[str, str]`` (the previous annotation
        was wrong).
    """
    files: dict[str, dict[str, Any]] = {}
    for document in documents:
        doc_meta = document.get("document") or {}
        title = str(doc_meta.get("title") or "untitled")
        file_name = _safe_filename(title)
        path = f"/documents/{file_name}"
        if path in files:
            # Disambiguate duplicate titles by appending the document id.
            doc_id = doc_meta.get("id", "dup")
            stem = file_name.removesuffix(".xml")
            path = f"/documents/{stem} ({doc_id}).xml"
        matched_ids = set(document.get("matched_chunk_ids") or [])
        xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
        files[path] = {
            "content": xml_content.split("\n"),
            "encoding": "utf-8",
            # No DB row backing anonymous docs, so no real timestamps exist.
            "created_at": "",
            "modified_at": "",
        }
    return files
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
@ -833,6 +861,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
available_document_types: list[str] | None = None,
top_k: int = 10,
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
) -> None:
self.llm = llm
self.search_space_id = search_space_id
@ -840,6 +869,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.available_document_types = available_document_types
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
self.anon_session_id = anon_session_id
async def _plan_search_inputs(
self,
@ -913,6 +943,50 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
pass
return asyncio.run(self.abefore_agent(state, runtime))
async def _load_anon_document(self) -> dict[str, Any] | None:
    """Load the anonymous user's uploaded document from Redis.

    Returns a search-result-shaped dict (sentinel ids of -1) so downstream
    filesystem builders can treat it like a normal KB hit, or ``None`` when
    there is no anonymous session, no stored document, or Redis fails.
    """
    if not self.anon_session_id:
        return None
    try:
        # Imported lazily so authenticated sessions never touch the Redis
        # client or app config from this middleware.
        import redis.asyncio as aioredis

        from app.config import config

        redis_client = aioredis.from_url(
            config.REDIS_APP_URL, decode_responses=True
        )
        try:
            redis_key = f"anon:doc:{self.anon_session_id}"
            data = await redis_client.get(redis_key)
            if not data:
                return None
            doc = json.loads(data)
            # Sentinel ids (-1) mark this as a synthetic, non-DB document.
            # The single chunk mirrors the full content so chunk-matching
            # code downstream still has something to reference.
            return {
                "document_id": -1,
                "content": doc.get("content", ""),
                "score": 1.0,
                "chunks": [
                    {
                        "chunk_id": -1,
                        "content": doc.get("content", ""),
                    }
                ],
                "matched_chunk_ids": [-1],
                "document": {
                    "id": -1,
                    "title": doc.get("filename", "uploaded_document"),
                    "document_type": "FILE",
                    "metadata": {"source": "anonymous_upload"},
                },
                "source": "FILE",
                "_user_mentioned": True,
            }
        finally:
            # Always close the ad-hoc client to avoid leaking connections.
            await redis_client.aclose()
    except Exception as exc:
        # Best effort: anonymous chat degrades to "no document" on any
        # Redis/JSON failure rather than breaking the turn.
        logger.warning("Failed to load anonymous document from Redis: %s", exc)
    return None
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
@ -937,6 +1011,35 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
# --- Anonymous session: load Redis doc and skip DB queries ---
if self.anon_session_id:
merged: list[dict[str, Any]] = []
anon_doc = await self._load_anon_document()
if anon_doc:
merged.append(anon_doc)
if merged:
new_files = _build_anon_scoped_filesystem(merged)
mentioned_paths = set(new_files.keys())
else:
new_files = {}
mentioned_paths = set()
ai_msg, tool_msg = _build_synthetic_ls(
existing_files,
new_files,
mentioned_paths=mentioned_paths,
)
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] anon completed in %.3fs new_files=%d",
asyncio.get_event_loop().time() - t0,
len(new_files),
)
return {"files": new_files, "messages": [ai_msg, tool_msg]}
# --- Authenticated session: full KB search ---
(
planned_query,
start_date,
@ -954,8 +1057,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
document_ids=self.mentioned_document_ids,
search_space_id=self.search_space_id,
)
# Clear after first turn so they are not re-fetched on subsequent
# messages within the same agent instance.
self.mentioned_document_ids = []
# --- 2. Run KB search (recency browse or hybrid) ---
@ -983,26 +1084,24 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
seen_doc_ids: set[int] = set()
merged: list[dict[str, Any]] = []
merged_auth: list[dict[str, Any]] = []
for doc in mentioned_results:
doc_id = (doc.get("document") or {}).get("id")
if doc_id is not None:
seen_doc_ids.add(doc_id)
merged.append(doc)
merged_auth.append(doc)
for doc in search_results:
doc_id = (doc.get("document") or {}).get("id")
if doc_id is not None and doc_id in seen_doc_ids:
continue
merged.append(doc)
merged_auth.append(doc)
# --- 4. Build scoped filesystem ---
new_files, doc_id_to_path = await build_scoped_filesystem(
documents=merged,
documents=merged_auth,
search_space_id=self.search_space_id,
)
# Identify which paths belong to user-mentioned documents using
# the authoritative doc_id -> path mapping (no title guessing).
mentioned_doc_ids = {
(d.get("document") or {}).get("id") for d in mentioned_results
}

View file

@ -13,8 +13,6 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from limits.storage import MemoryStorage
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
@ -36,6 +34,7 @@ from app.config import (
)
from app.db import User, create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
from app.rate_limiter import limiter
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
@ -54,17 +53,7 @@ rate_limit_logger = logging.getLogger("surfsense.rate_limit")
# Uses the same Redis instance as Celery for zero additional infrastructure.
# Protects auth endpoints from brute force and user enumeration attacks.
# SlowAPI limiter — provides default rate limits (1024/min) for ALL routes
# via the ASGI middleware. This is the general safety net.
# in_memory_fallback ensures requests are still served (with per-worker
# in-memory limiting) when Redis is unreachable, instead of hanging.
limiter = Limiter(
key_func=get_remote_address,
storage_uri=config.REDIS_APP_URL,
default_limits=["1024/minute"],
in_memory_fallback_enabled=True,
in_memory_fallback=[MemoryStorage()],
)
# limiter is imported from app.rate_limiter (shared module to avoid circular imports)
def _get_request_id(request: Request) -> str:
@ -126,6 +115,39 @@ def _surfsense_error_handler(request: Request, exc: SurfSenseError) -> JSONRespo
def _http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
"""Wrap FastAPI/Starlette HTTPExceptions into the standard envelope."""
rid = _get_request_id(request)
# Structured dict details (e.g. {"code": "CAPTCHA_REQUIRED", "message": "..."})
# are preserved so the frontend can parse them.
if isinstance(exc.detail, dict):
err_code = exc.detail.get("code", _status_to_code(exc.status_code))
message = exc.detail.get("message", str(exc.detail))
if exc.status_code >= 500:
_error_logger.error(
"[%s] %s - HTTPException %d: %s",
rid,
request.url.path,
exc.status_code,
message,
)
message = GENERIC_5XX_MESSAGE
err_code = "INTERNAL_ERROR"
body = {
"error": {
"code": err_code,
"message": message,
"status": exc.status_code,
"request_id": rid,
"timestamp": datetime.now(UTC).isoformat(),
"report_url": ISSUES_URL,
},
"detail": exc.detail,
}
return JSONResponse(
status_code=exc.status_code,
content=body,
headers={"X-Request-ID": rid},
)
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
if exc.status_code >= 500:
_error_logger.error(
@ -663,6 +685,13 @@ if config.AUTH_TYPE == "GOOGLE":
return response
# Anonymous (no-login) chat routes — mounted at /api/v1/public/anon-chat
from app.routes.anonymous_chat_routes import ( # noqa: E402
router as anonymous_chat_router,
)
app.include_router(anonymous_chat_router)
app.include_router(crud_router, prefix="/api/v1", tags=["crud"])

View file

@ -187,4 +187,11 @@ celery_app.conf.beat_schedule = {
"expires": 60,
},
},
"reconcile-pending-stripe-token-purchases": {
"task": "reconcile_pending_stripe_token_purchases",
"schedule": crontab(**stripe_reconciliation_schedule_params),
"options": {
"expires": 60,
},
},
}

View file

@ -42,7 +42,25 @@ def load_global_llm_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
return data.get("global_llm_configs", [])
configs = data.get("global_llm_configs", [])
seen_slugs: dict[str, int] = {}
for cfg in configs:
cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False)
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"]
if slug in seen_slugs:
print(
f"Warning: Duplicate seo_slug '{slug}' in global LLM configs "
f"(ids {seen_slugs[slug]} and {cfg.get('id')})"
)
else:
seen_slugs[slug] = cfg.get("id", 0)
return configs
except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}")
return []
@ -307,6 +325,36 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
# Premium token quota settings
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "5000000"))
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
)
# Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "1000000"))
ANON_TOKEN_WARNING_THRESHOLD = int(
os.getenv("ANON_TOKEN_WARNING_THRESHOLD", "800000")
)
ANON_TOKEN_QUOTA_TTL_DAYS = int(os.getenv("ANON_TOKEN_QUOTA_TTL_DAYS", "30"))
ANON_MAX_UPLOAD_SIZE_MB = int(os.getenv("ANON_MAX_UPLOAD_SIZE_MB", "5"))
# Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
# Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
os.getenv("ANON_CAPTCHA_REQUEST_THRESHOLD", "5")
)
# Cloudflare Turnstile CAPTCHA
TURNSTILE_ENABLED = os.getenv("TURNSTILE_ENABLED", "FALSE").upper() == "TRUE"
TURNSTILE_SECRET_KEY = os.getenv("TURNSTILE_SECRET_KEY", "")
# Auth
AUTH_TYPE = os.getenv("AUTH_TYPE")
REGISTRATION_ENABLED = os.getenv("REGISTRATION_ENABLED", "TRUE").upper() == "TRUE"

View file

@ -48,6 +48,11 @@ global_llm_configs:
- id: -1
name: "Global GPT-4 Turbo"
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "gpt-4-turbo"
quota_reserve_tokens: 4000
provider: "OPENAI"
model_name: "gpt-4-turbo-preview"
api_key: "sk-your-openai-api-key-here"
@ -67,6 +72,11 @@ global_llm_configs:
- id: -2
name: "Global Claude 3 Opus"
description: "Anthropic's most capable model with citations"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "claude-3-opus"
quota_reserve_tokens: 4000
provider: "ANTHROPIC"
model_name: "claude-3-opus-20240229"
api_key: "sk-ant-your-anthropic-api-key-here"
@ -84,6 +94,11 @@ global_llm_configs:
- id: -3
name: "Global GPT-3.5 Turbo (Fast)"
description: "Fast responses without citations for quick queries"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "gpt-3.5-turbo-fast"
quota_reserve_tokens: 2000
provider: "OPENAI"
model_name: "gpt-3.5-turbo"
api_key: "sk-your-openai-api-key-here"
@ -101,6 +116,11 @@ global_llm_configs:
- id: -4
name: "Global DeepSeek Chat (Chinese)"
description: "DeepSeek optimized for Chinese language responses"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "deepseek-chat-chinese"
quota_reserve_tokens: 4000
provider: "DEEPSEEK"
model_name: "deepseek-chat"
api_key: "your-deepseek-api-key-here"
@ -128,6 +148,11 @@ global_llm_configs:
- id: -5
name: "Global Azure GPT-4o"
description: "Azure OpenAI GPT-4o deployment"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "azure-gpt-4o"
quota_reserve_tokens: 4000
provider: "AZURE"
# model_name format for Azure: azure/<your-deployment-name>
model_name: "azure/gpt-4o-deployment"
@ -151,6 +176,11 @@ global_llm_configs:
- id: -6
name: "Global Azure GPT-4 Turbo"
description: "Azure OpenAI GPT-4 Turbo deployment"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "azure-gpt-4-turbo"
quota_reserve_tokens: 4000
provider: "AZURE"
model_name: "azure/gpt-4-turbo-deployment"
api_key: "your-azure-api-key-here"
@ -170,6 +200,11 @@ global_llm_configs:
- id: -7
name: "Global Groq Llama 3"
description: "Ultra-fast Llama 3 70B via Groq"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "groq-llama-3"
quota_reserve_tokens: 8000
provider: "GROQ"
model_name: "llama3-70b-8192"
api_key: "your-groq-api-key-here"
@ -187,6 +222,11 @@ global_llm_configs:
- id: -8
name: "Global MiniMax M2.5"
description: "MiniMax M2.5 with 204K context window and competitive pricing"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "minimax-m2.5"
quota_reserve_tokens: 4000
provider: "MINIMAX"
model_name: "MiniMax-M2.5"
api_key: "your-minimax-api-key-here"
@ -365,3 +405,13 @@ global_vision_llm_configs:
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
#
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated.
# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public.
# - seo_title: Optional HTML title tag override for the model's /free/<slug> page.
# - seo_description: Optional meta description override for the model's /free/<slug> page.
# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement.
# Independent of litellm_params.max_tokens. Used by the token quota service.

View file

@ -12,6 +12,7 @@ from sqlalchemy import (
ARRAY,
JSON,
TIMESTAMP,
BigInteger,
Boolean,
Column,
Enum as SQLAlchemyEnum,
@ -318,6 +319,12 @@ class PagePurchaseStatus(StrEnum):
FAILED = "failed"
class PremiumTokenPurchaseStatus(StrEnum):
    """Lifecycle states of a Stripe premium-token checkout session."""

    PENDING = "pending"      # checkout created; awaiting confirmation/reconciliation
    COMPLETED = "completed"  # purchase finished; tokens granted
    FAILED = "failed"        # purchase did not complete
# Centralized configuration for incentive tasks
# This makes it easy to add new tasks without changing code in multiple places
INCENTIVE_TASKS_CONFIG = {
@ -1739,6 +1746,38 @@ class PagePurchase(Base, TimestampMixin):
user = relationship("User", back_populates="page_purchases")
class PremiumTokenPurchase(Base, TimestampMixin):
    """Tracks Stripe checkout sessions used to grant additional premium token credits."""

    __tablename__ = "premium_token_purchases"
    __allow_unmapped__ = True

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Purchasing user; rows are deleted with the user (ondelete CASCADE).
    user_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Unique per Stripe checkout session — the DB constraint guards against
    # granting the same purchase twice.
    stripe_checkout_session_id = Column(
        String(255), nullable=False, unique=True, index=True
    )
    stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
    # Units bought vs. the resulting token credit (BigInteger: token counts
    # can exceed 32-bit range).
    quantity = Column(Integer, nullable=False)
    tokens_granted = Column(BigInteger, nullable=False)
    # NOTE(review): amount_total is presumably in the currency's smallest
    # unit (Stripe convention) — confirm against the checkout handler.
    amount_total = Column(Integer, nullable=True)
    currency = Column(String(10), nullable=True)
    # Indexed so the pending-purchase reconciliation task can scan cheaply.
    status = Column(
        SQLAlchemyEnum(PremiumTokenPurchaseStatus),
        nullable=False,
        default=PremiumTokenPurchaseStatus.PENDING,
        index=True,
    )
    completed_at = Column(TIMESTAMP(timezone=True), nullable=True)

    user = relationship("User", back_populates="premium_token_purchases")
class SearchSpaceRole(BaseModel, TimestampMixin):
"""
Custom roles that can be defined per search space.
@ -2009,6 +2048,11 @@ if config.AUTH_TYPE == "GOOGLE":
back_populates="user",
cascade="all, delete-orphan",
)
premium_token_purchases = relationship(
"PremiumTokenPurchase",
back_populates="user",
cascade="all, delete-orphan",
)
# Page usage tracking for ETL services
pages_limit = Column(
@ -2019,6 +2063,19 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column(
BigInteger,
nullable=False,
default=config.PREMIUM_TOKEN_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT),
)
premium_tokens_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
premium_tokens_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
# User profile from OAuth
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
@ -2123,6 +2180,11 @@ else:
back_populates="user",
cascade="all, delete-orphan",
)
premium_token_purchases = relationship(
"PremiumTokenPurchase",
back_populates="user",
cascade="all, delete-orphan",
)
# Page usage tracking for ETL services
pages_limit = Column(
@ -2133,6 +2195,19 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column(
BigInteger,
nullable=False,
default=config.PREMIUM_TOKEN_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT),
)
premium_tokens_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
premium_tokens_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
# User profile (can be set manually for non-OAuth users)
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)

View file

@ -0,0 +1,15 @@
"""Shared SlowAPI limiter instance used by app.py and route modules."""
from limits.storage import MemoryStorage
from slowapi import Limiter
from slowapi.util import get_remote_address
from app.config import config
limiter = Limiter(
key_func=get_remote_address,
storage_uri=config.REDIS_APP_URL,
default_limits=["1024/minute"],
in_memory_fallback_enabled=True,
in_memory_fallback=[MemoryStorage()],
)

View file

@ -0,0 +1,610 @@
"""Public API endpoints for anonymous (no-login) chat."""
from __future__ import annotations
import logging
import secrets
import uuid
from pathlib import PurePosixPath
from typing import Any
from fastapi import APIRouter, HTTPException, Request, Response, UploadFile, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from app.config import config
from app.etl_pipeline.file_classifier import (
DIRECT_CONVERT_EXTENSIONS,
PLAINTEXT_EXTENSIONS,
)
from app.rate_limiter import limiter
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/public/anon-chat", tags=["anonymous-chat"])
ANON_COOKIE_NAME = "surfsense_anon_session"
ANON_COOKIE_MAX_AGE = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_or_create_session_id(request: Request, response: Response) -> str:
    """Return the anonymous session id from the cookie, minting one if absent.

    NOTE(review): despite earlier wording, the cookie value is NOT
    cryptographically signed — it is a random ``secrets.token_urlsafe(32)``
    token, and the only validation is the length check below (43 chars is
    exactly what token_urlsafe(32) produces).
    """
    session_id = request.cookies.get(ANON_COOKIE_NAME)
    if session_id and len(session_id) == 43:
        return session_id
    session_id = secrets.token_urlsafe(32)
    response.set_cookie(
        key=ANON_COOKIE_NAME,
        value=session_id,
        max_age=ANON_COOKIE_MAX_AGE,
        httponly=True,  # not readable from client-side JS
        samesite="lax",
        # Only mark Secure when actually served over HTTPS (keeps local dev working).
        secure=request.url.scheme == "https",
        path="/",
    )
    return session_id
def _get_client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
return (
forwarded.split(",")[0].strip()
if forwarded
else (request.client.host if request.client else "unknown")
)
# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class AnonChatRequest(BaseModel):
    """Request body for the anonymous streaming chat endpoint."""

    # SEO slug that identifies the model in the public catalog.
    model_slug: str = Field(..., max_length=100)
    # Chat history; at least one message is required.
    messages: list[dict[str, Any]] = Field(..., min_length=1)
    # Tool names the client opted out of (server still applies its own allow-list).
    disabled_tools: list[str] | None = None
    # Cloudflare Turnstile token; required once the CAPTCHA threshold is reached.
    turnstile_token: str | None = None
class AnonQuotaResponse(BaseModel):
    """Token-quota snapshot for an anonymous session."""

    used: int
    limit: int
    remaining: int
    # String value of the quota service's status enum.
    status: str
    # Usage level at which the UI should start warning the user.
    warning_threshold: int
    # True when the client must solve a Turnstile CAPTCHA before chatting again.
    captcha_required: bool = False
class AnonModelResponse(BaseModel):
    """Public view of a globally configured LLM for the no-login catalog."""

    id: int
    name: str
    description: str | None = None
    provider: str
    model_name: str
    # "free" or "premium"; is_premium is the derived convenience flag.
    billing_tier: str = "free"
    is_premium: bool = False
    # SEO landing-page fields (/free/<seo_slug>); only meaningful when seo_enabled.
    seo_slug: str | None = None
    seo_enabled: bool = False
    seo_title: str | None = None
    seo_description: str | None = None
    # Tokens reserved ahead of each LLM call for quota enforcement.
    quota_reserve_tokens: int | None = None
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@router.get("/models", response_model=list[AnonModelResponse])
async def list_anonymous_models():
"""Return all models enabled for anonymous access."""
if not config.NOLOGIN_MODE_ENABLED:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No-login mode is not enabled.",
)
models = []
for cfg in config.GLOBAL_LLM_CONFIGS:
if cfg.get("anonymous_enabled", False):
models.append(
AnonModelResponse(
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",
seo_slug=cfg.get("seo_slug"),
seo_enabled=cfg.get("seo_enabled", False),
seo_title=cfg.get("seo_title"),
seo_description=cfg.get("seo_description"),
quota_reserve_tokens=cfg.get("quota_reserve_tokens"),
)
)
return models
@router.get("/models/{slug}", response_model=AnonModelResponse)
async def get_anonymous_model(slug: str):
"""Return a single model by its SEO slug."""
if not config.NOLOGIN_MODE_ENABLED:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No-login mode is not enabled.",
)
for cfg in config.GLOBAL_LLM_CONFIGS:
if cfg.get("anonymous_enabled", False) and cfg.get("seo_slug") == slug:
return AnonModelResponse(
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",
seo_slug=cfg.get("seo_slug"),
seo_enabled=cfg.get("seo_enabled", False),
seo_title=cfg.get("seo_title"),
seo_description=cfg.get("seo_description"),
quota_reserve_tokens=cfg.get("quota_reserve_tokens"),
)
raise HTTPException(status_code=404, detail="Model not found")
@router.get("/quota", response_model=AnonQuotaResponse)
@limiter.limit("30/minute")
async def get_anonymous_quota(request: Request, response: Response):
"""Return current token usage for the anonymous session.
Reports the *stricter* of session and IP buckets so that opening a
new browser on the same IP doesn't show a misleadingly fresh quota.
"""
if not config.NOLOGIN_MODE_ENABLED:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No-login mode is not enabled.",
)
from app.services.token_quota_service import (
TokenQuotaService,
compute_anon_identity_key,
compute_ip_quota_key,
)
client_ip = _get_client_ip(request)
session_id = _get_or_create_session_id(request, response)
session_key = compute_anon_identity_key(session_id)
session_result = await TokenQuotaService.anon_get_usage(session_key)
ip_key = compute_ip_quota_key(client_ip)
ip_result = await TokenQuotaService.anon_get_usage(ip_key)
# Use whichever bucket has higher usage — that's the real constraint
result = ip_result if ip_result.used > session_result.used else session_result
captcha_needed = False
if config.TURNSTILE_ENABLED:
req_count = await TokenQuotaService.anon_get_request_count(client_ip)
captcha_needed = req_count >= config.ANON_CAPTCHA_REQUEST_THRESHOLD
return AnonQuotaResponse(
used=result.used,
limit=result.limit,
remaining=result.remaining,
status=result.status.value,
warning_threshold=config.ANON_TOKEN_WARNING_THRESHOLD,
captcha_required=captcha_needed,
)
@router.post("/stream")
@limiter.limit("15/minute")
async def stream_anonymous_chat(
body: AnonChatRequest,
request: Request,
response: Response,
):
"""Stream a chat response for an anonymous user with quota enforcement."""
if not config.NOLOGIN_MODE_ENABLED:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No-login mode is not enabled.",
)
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
)
from app.services.token_quota_service import (
TokenQuotaService,
compute_anon_identity_key,
compute_ip_quota_key,
)
from app.services.turnstile_service import verify_turnstile_token
# Find the model config by slug
model_cfg = None
for cfg in config.GLOBAL_LLM_CONFIGS:
if (
cfg.get("anonymous_enabled", False)
and cfg.get("seo_slug") == body.model_slug
):
model_cfg = cfg
break
if model_cfg is None:
raise HTTPException(
status_code=404, detail="Model not found or not available for anonymous use"
)
client_ip = _get_client_ip(request)
# --- Concurrent stream limit ---
slot_acquired = await TokenQuotaService.anon_acquire_stream_slot(
client_ip, max_concurrent=config.ANON_MAX_CONCURRENT_STREAMS
)
if not slot_acquired:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail={
"code": "ANON_TOO_MANY_STREAMS",
"message": f"Max {config.ANON_MAX_CONCURRENT_STREAMS} concurrent chats allowed. Please wait for a response to finish.",
},
)
try:
# --- CAPTCHA enforcement (check count without incrementing; count
# is bumped only after a successful response in _generate) ---
if config.TURNSTILE_ENABLED:
req_count = await TokenQuotaService.anon_get_request_count(client_ip)
if req_count >= config.ANON_CAPTCHA_REQUEST_THRESHOLD:
if not body.turnstile_token:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": "CAPTCHA_REQUIRED",
"message": "Please complete the CAPTCHA to continue chatting.",
},
)
valid = await verify_turnstile_token(body.turnstile_token, client_ip)
if not valid:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": "CAPTCHA_INVALID",
"message": "CAPTCHA verification failed. Please try again.",
},
)
await TokenQuotaService.anon_reset_request_count(client_ip)
# Build identity keys
session_id = _get_or_create_session_id(request, response)
session_key = compute_anon_identity_key(session_id)
ip_key = compute_ip_quota_key(client_ip)
# Reserve tokens
reserve_amount = min(
model_cfg.get("quota_reserve_tokens", config.QUOTA_MAX_RESERVE_PER_CALL),
config.QUOTA_MAX_RESERVE_PER_CALL,
)
request_id = uuid.uuid4().hex[:16]
quota_result = await TokenQuotaService.anon_reserve(
session_key=session_key,
ip_key=ip_key,
request_id=request_id,
reserve_tokens=reserve_amount,
)
if not quota_result.allowed:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail={
"code": "ANON_QUOTA_EXCEEDED",
"message": "You've used all your free tokens. Create an account for 5M more!",
"used": quota_result.used,
"limit": quota_result.limit,
},
)
# Create agent config from YAML
agent_config = AgentConfig.from_yaml_config(model_cfg)
llm = create_chat_litellm_from_agent_config(agent_config)
if not llm:
await TokenQuotaService.anon_release(session_key, ip_key, request_id)
raise HTTPException(status_code=500, detail="Failed to create LLM instance")
# Server-side tool allow-list enforcement
anon_allowed_tools = {"web_search"}
client_disabled = set(body.disabled_tools) if body.disabled_tools else set()
enabled_for_agent = anon_allowed_tools - client_disabled
except HTTPException:
await TokenQuotaService.anon_release_stream_slot(client_ip)
raise
async def _generate():
    """Stream the anonymous chat turn as Vercel-protocol SSE frames.

    Settles the token reservation made by the enclosing endpoint: on a
    successful stream the reservation keyed by ``request_id`` is finalized
    against the actual token usage; on any error it is released. The per-IP
    concurrent-stream slot is always released in the ``finally`` block.
    Closure variables (session_id, request_id, session_key, ip_key, llm,
    agent_config, enabled_for_agent, body, client_ip) come from the
    enclosing endpoint.
    """
    from langchain_core.messages import HumanMessage

    from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
    from app.agents.new_chat.checkpointer import get_checkpointer
    from app.db import shielded_async_session
    from app.services.connector_service import ConnectorService
    from app.services.new_streaming_service import VercelStreamingService
    from app.services.token_tracking_service import start_turn
    from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events

    accumulator = start_turn()
    streaming_service = VercelStreamingService()
    try:
        async with shielded_async_session() as session:
            # Anonymous chats have no search space; search_space_id=None/0.
            connector_service = ConnectorService(session, search_space_id=None)
            checkpointer = await get_checkpointer()
            # Unique per (session, request) so anonymous turns never collide
            # in the checkpointer.
            anon_thread_id = f"anon-{session_id}-{request_id}"
            agent = await create_surfsense_deep_agent(
                llm=llm,
                search_space_id=0,
                db_session=session,
                connector_service=connector_service,
                checkpointer=checkpointer,
                user_id=None,
                thread_id=None,
                agent_config=agent_config,
                enabled_tools=list(enabled_for_agent),
                disabled_tools=None,
                anon_session_id=session_id,
            )
            # Only the most recent user message is used as the query.
            user_query = ""
            for msg in reversed(body.messages):
                if msg.get("role") == "user":
                    user_query = msg.get("content", "")
                    break
            langchain_messages = [HumanMessage(content=user_query)]
            input_state = {
                "messages": langchain_messages,
                "search_space_id": 0,
            }
            langgraph_config = {
                "configurable": {"thread_id": anon_thread_id},
                "recursion_limit": 40,
            }
            yield streaming_service.format_message_start()
            yield streaming_service.format_start_step()
            # Emit an immediate "thinking" step so the UI shows progress
            # before the first agent event arrives.
            initial_step_id = "thinking-1"
            query_preview = user_query[:80] + (
                "..." if len(user_query) > 80 else ""
            )
            initial_items = [f"Processing: {query_preview}"]
            yield streaming_service.format_thinking_step(
                step_id=initial_step_id,
                title="Understanding your request",
                status="in_progress",
                items=initial_items,
            )
            stream_result = StreamResult()
            async for sse in _stream_agent_events(
                agent=agent,
                config=langgraph_config,
                input_data=input_state,
                streaming_service=streaming_service,
                result=stream_result,
                step_prefix="thinking",
                initial_step_id=initial_step_id,
                initial_step_title="Understanding your request",
                initial_step_items=initial_items,
            ):
                yield sse
            # Finalize quota with actual tokens
            actual_tokens = accumulator.grand_total
            finalize_result = await TokenQuotaService.anon_finalize(
                session_key=session_key,
                ip_key=ip_key,
                request_id=request_id,
                actual_tokens=actual_tokens,
            )
            # Count this as 1 completed response for CAPTCHA threshold
            if config.TURNSTILE_ENABLED:
                await TokenQuotaService.anon_increment_request_count(client_ip)
            # Push updated quota state to the client so the UI can show
            # remaining free tokens.
            yield streaming_service.format_data(
                "anon-quota",
                {
                    "used": finalize_result.used,
                    "limit": finalize_result.limit,
                    "remaining": finalize_result.remaining,
                    "status": finalize_result.status.value,
                },
            )
            if accumulator.per_message_summary():
                yield streaming_service.format_data(
                    "token-usage",
                    {
                        "usage": accumulator.per_message_summary(),
                        "prompt_tokens": accumulator.total_prompt_tokens,
                        "completion_tokens": accumulator.total_completion_tokens,
                        "total_tokens": accumulator.grand_total,
                    },
                )
            yield streaming_service.format_finish_step()
            yield streaming_service.format_finish()
            yield streaming_service.format_done()
    except Exception as e:
        logger.exception("Anonymous chat stream error")
        # The turn failed; return the reserved tokens to the budget.
        await TokenQuotaService.anon_release(session_key, ip_key, request_id)
        yield streaming_service.format_error(f"Error during chat: {e!s}")
        yield streaming_service.format_done()
    finally:
        await TokenQuotaService.anon_release_stream_slot(client_ip)
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Anonymous Document Upload (1-doc limit, plaintext/direct-convert only)
# ---------------------------------------------------------------------------
ANON_ALLOWED_EXTENSIONS = PLAINTEXT_EXTENSIONS | DIRECT_CONVERT_EXTENSIONS
ANON_DOC_REDIS_PREFIX = "anon:doc:"
class AnonDocResponse(BaseModel):
    """Response payload returned after an anonymous document upload."""

    filename: str  # original filename as sent by the client
    size_bytes: int  # raw upload size in bytes
    status: str = "uploaded"  # constant marker; upload completes synchronously
@router.post("/upload", response_model=AnonDocResponse)
@limiter.limit("5/minute")
async def upload_anonymous_document(
file: UploadFile,
request: Request,
response: Response,
):
"""Upload a single document for anonymous chat (1-doc limit per session)."""
if not config.NOLOGIN_MODE_ENABLED:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No-login mode is not enabled.",
)
session_id = _get_or_create_session_id(request, response)
if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided")
ext = PurePosixPath(file.filename).suffix.lower()
if ext not in ANON_ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=415,
detail=(
"File type not supported for anonymous upload. "
"Create an account to upload PDFs, Word documents, images, audio, and 20+ more file types. "
"Allowed extensions: text, code, CSV, HTML files."
),
)
max_size = config.ANON_MAX_UPLOAD_SIZE_MB * 1024 * 1024
content = await file.read()
if len(content) > max_size:
raise HTTPException(
status_code=413,
detail=f"File too large. Max size is {config.ANON_MAX_UPLOAD_SIZE_MB} MB.",
)
import json as _json
import redis.asyncio as aioredis
redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
redis_key = f"{ANON_DOC_REDIS_PREFIX}{session_id}"
try:
existing = await redis_client.exists(redis_key)
if existing:
raise HTTPException(
status_code=409,
detail="Document limit reached. Create an account to upload more.",
)
text_content: str
if ext in PLAINTEXT_EXTENSIONS:
text_content = content.decode("utf-8", errors="replace")
elif ext in DIRECT_CONVERT_EXTENSIONS:
if ext in {".csv", ".tsv"}:
text_content = content.decode("utf-8", errors="replace")
else:
try:
from markdownify import markdownify
text_content = markdownify(
content.decode("utf-8", errors="replace")
)
except ImportError:
text_content = content.decode("utf-8", errors="replace")
else:
text_content = content.decode("utf-8", errors="replace")
doc_data = _json.dumps(
{
"filename": file.filename,
"size_bytes": len(content),
"content": text_content,
}
)
ttl_seconds = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
await redis_client.set(redis_key, doc_data, ex=ttl_seconds)
finally:
await redis_client.aclose()
return AnonDocResponse(
filename=file.filename,
size_bytes=len(content),
)
@router.get("/document")
async def get_anonymous_document(request: Request, response: Response):
"""Get metadata of the uploaded document for the anonymous session."""
if not config.NOLOGIN_MODE_ENABLED:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No-login mode is not enabled.",
)
session_id = _get_or_create_session_id(request, response)
import json as _json
import redis.asyncio as aioredis
redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
redis_key = f"{ANON_DOC_REDIS_PREFIX}{session_id}"
try:
data = await redis_client.get(redis_key)
if not data:
raise HTTPException(status_code=404, detail="No document uploaded")
doc = _json.loads(data)
return {
"filename": doc["filename"],
"size_bytes": doc["size_bytes"],
}
finally:
await redis_client.aclose()

View file

@ -76,6 +76,14 @@ async def get_global_new_llm_configs(
"citations_enabled": True,
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
"is_premium": False,
"anonymous_enabled": False,
"seo_enabled": False,
"seo_slug": None,
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
}
)
@ -97,6 +105,14 @@ async def get_global_new_llm_configs(
),
"citations_enabled": cfg.get("citations_enabled", True),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
"is_premium": cfg.get("billing_tier", "free") == "premium",
"anonymous_enabled": cfg.get("anonymous_enabled", False),
"seo_enabled": cfg.get("seo_enabled", False),
"seo_slug": cfg.get("seo_slug"),
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
}
safe_configs.append(safe_config)

View file

@ -13,13 +13,24 @@ from sqlalchemy.ext.asyncio import AsyncSession
from stripe import SignatureVerificationError, StripeClient, StripeError
from app.config import config
from app.db import PagePurchase, PagePurchaseStatus, User, get_async_session
from app.db import (
PagePurchase,
PagePurchaseStatus,
PremiumTokenPurchase,
PremiumTokenPurchaseStatus,
User,
get_async_session,
)
from app.schemas.stripe import (
CreateCheckoutSessionRequest,
CreateCheckoutSessionResponse,
CreateTokenCheckoutSessionRequest,
CreateTokenCheckoutSessionResponse,
PagePurchaseHistoryResponse,
StripeStatusResponse,
StripeWebhookResponse,
TokenPurchaseHistoryResponse,
TokenStripeStatusResponse,
)
from app.users import current_active_user
@ -151,6 +162,26 @@ async def _mark_purchase_failed(
return StripeWebhookResponse()
async def _mark_token_purchase_failed(
    db_session: AsyncSession, checkout_session_id: str
) -> StripeWebhookResponse:
    """Mark a still-PENDING premium-token purchase as FAILED.

    The row is locked FOR UPDATE so concurrent webhook deliveries serialize;
    purchases in any other state are left untouched.
    """
    stmt = (
        select(PremiumTokenPurchase)
        .where(PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id)
        .with_for_update()
    )
    purchase = (await db_session.execute(stmt)).scalar_one_or_none()
    if purchase is None or purchase.status != PremiumTokenPurchaseStatus.PENDING:
        return StripeWebhookResponse()
    purchase.status = PremiumTokenPurchaseStatus.FAILED
    await db_session.commit()
    return StripeWebhookResponse()
async def _fulfill_completed_purchase(
db_session: AsyncSession, checkout_session: Any
) -> StripeWebhookResponse:
@ -201,6 +232,86 @@ async def _fulfill_completed_purchase(
return StripeWebhookResponse()
async def _fulfill_completed_token_purchase(
    db_session: AsyncSession, checkout_session: Any
) -> StripeWebhookResponse:
    """Grant premium tokens to the user after a confirmed Stripe payment.

    Idempotent webhook handler: the purchase row is locked FOR UPDATE and a
    COMPLETED row is never fulfilled twice. If the local PENDING row was
    never committed (e.g. the checkout-creation request failed mid-flight),
    it is reconstructed from the checkout session's metadata.
    """
    checkout_session_id = str(checkout_session.id)
    # Lock the purchase row so concurrent deliveries of the same webhook
    # serialize on it.
    purchase = (
        await db_session.execute(
            select(PremiumTokenPurchase)
            .where(
                PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id
            )
            .with_for_update()
        )
    ).scalar_one_or_none()
    if purchase is None:
        # No local record: rebuild one from the metadata set at checkout
        # creation time (user_id / quantity / tokens_per_unit).
        metadata = _get_metadata(checkout_session)
        user_id = metadata.get("user_id")
        quantity = int(metadata.get("quantity", "0"))
        tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
        if not user_id or quantity <= 0 or tokens_per_unit <= 0:
            logger.error(
                "Skipping token fulfillment for session %s: incomplete metadata %s",
                checkout_session_id,
                metadata,
            )
            return StripeWebhookResponse()
        purchase = PremiumTokenPurchase(
            user_id=uuid.UUID(user_id),
            stripe_checkout_session_id=checkout_session_id,
            stripe_payment_intent_id=_normalize_optional_string(
                getattr(checkout_session, "payment_intent", None)
            ),
            quantity=quantity,
            tokens_granted=quantity * tokens_per_unit,
            amount_total=getattr(checkout_session, "amount_total", None),
            currency=getattr(checkout_session, "currency", None),
            status=PremiumTokenPurchaseStatus.PENDING,
        )
        db_session.add(purchase)
        await db_session.flush()
    if purchase.status == PremiumTokenPurchaseStatus.COMPLETED:
        # Already fulfilled by an earlier delivery; acknowledge and stop.
        return StripeWebhookResponse()
    # Lock the user row before mutating the token limit.
    user = (
        (
            await db_session.execute(
                select(User).where(User.id == purchase.user_id).with_for_update(of=User)
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if user is None:
        logger.error(
            "Skipping token fulfillment for session %s: user %s not found",
            purchase.stripe_checkout_session_id,
            purchase.user_id,
        )
        return StripeWebhookResponse()
    purchase.status = PremiumTokenPurchaseStatus.COMPLETED
    purchase.completed_at = datetime.now(UTC)
    purchase.amount_total = getattr(checkout_session, "amount_total", None)
    purchase.currency = getattr(checkout_session, "currency", None)
    purchase.stripe_payment_intent_id = _normalize_optional_string(
        getattr(checkout_session, "payment_intent", None)
    )
    # max(used, limit) guarantees the full grant is usable even if usage
    # already exceeds the old limit.
    user.premium_tokens_limit = (
        max(user.premium_tokens_used, user.premium_tokens_limit)
        + purchase.tokens_granted
    )
    await db_session.commit()
    return StripeWebhookResponse()
@router.post("/create-checkout-session", response_model=CreateCheckoutSessionResponse)
async def create_checkout_session(
body: CreateCheckoutSessionRequest,
@ -333,6 +444,10 @@ async def stripe_webhook(
)
return StripeWebhookResponse()
metadata = _get_metadata(checkout_session)
purchase_type = metadata.get("purchase_type", "page_packs")
if purchase_type == "premium_tokens":
return await _fulfill_completed_token_purchase(db_session, checkout_session)
return await _fulfill_completed_purchase(db_session, checkout_session)
if event.type in {
@ -340,6 +455,12 @@ async def stripe_webhook(
"checkout.session.expired",
}:
checkout_session = event.data.object
metadata = _get_metadata(checkout_session)
purchase_type = metadata.get("purchase_type", "page_packs")
if purchase_type == "premium_tokens":
return await _mark_token_purchase_failed(
db_session, str(checkout_session.id)
)
return await _mark_purchase_failed(db_session, str(checkout_session.id))
return StripeWebhookResponse()
@ -369,3 +490,146 @@ async def get_page_purchases(
)
return PagePurchaseHistoryResponse(purchases=purchases)
# =============================================================================
# Premium Token Purchase Routes
# =============================================================================
def _ensure_token_buying_enabled() -> None:
    """Raise 503 when the premium-token purchase feature flag is off."""
    if config.STRIPE_TOKEN_BUYING_ENABLED:
        return
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Premium token purchases are temporarily unavailable.",
    )
def _get_token_checkout_urls(search_space_id: int) -> tuple[str, str]:
    """Build the (success, cancel) redirect URLs for a token checkout.

    Raises 503 when NEXT_FRONTEND_URL is not configured.
    """
    frontend_url = config.NEXT_FRONTEND_URL
    if not frontend_url:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="NEXT_FRONTEND_URL is not configured.",
        )
    base = frontend_url.rstrip("/")
    return (
        f"{base}/dashboard/{search_space_id}/purchase-success",
        f"{base}/dashboard/{search_space_id}/purchase-cancel",
    )
def _get_required_token_price_id() -> str:
    """Return the configured Stripe price id for token packs, or raise 503."""
    price_id = config.STRIPE_PREMIUM_TOKEN_PRICE_ID
    if price_id:
        return price_id
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="STRIPE_PREMIUM_TOKEN_PRICE_ID is not configured.",
    )
@router.post("/create-token-checkout-session")
async def create_token_checkout_session(
body: CreateTokenCheckoutSessionRequest,
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
"""Create a Stripe Checkout Session for buying premium token packs."""
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
params={
"mode": "payment",
"success_url": success_url,
"cancel_url": cancel_url,
"line_items": [
{
"price": price_id,
"quantity": body.quantity,
}
],
"client_reference_id": str(user.id),
"customer_email": user.email,
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
"purchase_type": "premium_tokens",
},
}
)
except StripeError as exc:
logger.exception("Failed to create token checkout session for user %s", user.id)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unable to create Stripe checkout session.",
) from exc
checkout_url = getattr(checkout_session, "url", None)
if not checkout_url:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Stripe checkout session did not return a URL.",
)
db_session.add(
PremiumTokenPurchase(
user_id=user.id,
stripe_checkout_session_id=str(checkout_session.id),
stripe_payment_intent_id=_normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
tokens_granted=tokens_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
)
)
await db_session.commit()
return CreateTokenCheckoutSessionResponse(checkout_url=checkout_url)
@router.get("/token-status")
async def get_token_status(
user: User = Depends(current_active_user),
):
"""Return token-buying availability and current premium quota for frontend."""
used = user.premium_tokens_used
limit = user.premium_tokens_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
premium_tokens_used=used,
premium_tokens_limit=limit,
premium_tokens_remaining=max(0, limit - used),
)
@router.get("/token-purchases")
async def get_token_purchases(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
offset: int = 0,
limit: int = 50,
):
"""Return the authenticated user's premium token purchase history."""
limit = min(limit, 100)
purchases = (
(
await db_session.execute(
select(PremiumTokenPurchase)
.where(PremiumTokenPurchase.user_id == user.id)
.order_by(PremiumTokenPurchase.created_at.desc())
.offset(offset)
.limit(limit)
)
)
.scalars()
.all()
)
return TokenPurchaseHistoryResponse(purchases=purchases)

View file

@ -164,6 +164,15 @@ class GlobalNewLLMConfigRead(BaseModel):
is_global: bool = True # Always true for global configs
is_auto_mode: bool = False # True only for Auto mode (ID 0)
billing_tier: str = "free"
is_premium: bool = False
anonymous_enabled: bool = False
seo_enabled: bool = False
seo_slug: str | None = None
seo_title: str | None = None
seo_description: str | None = None
quota_reserve_tokens: int | None = None
# =============================================================================
# LLM Preferences Schemas (for role assignments)

View file

@ -54,3 +54,48 @@ class StripeWebhookResponse(BaseModel):
"""Generic acknowledgement for Stripe webhook delivery."""
received: bool = True
class CreateTokenCheckoutSessionRequest(BaseModel):
    """Request body for creating a premium token purchase checkout session."""

    quantity: int = Field(ge=1, le=100)  # number of token packs (1..100)
    search_space_id: int = Field(ge=1)  # used to build success/cancel redirect URLs
class CreateTokenCheckoutSessionResponse(BaseModel):
    """Response containing the Stripe-hosted checkout URL."""

    checkout_url: str  # URL the frontend redirects the user to
class TokenPurchaseRead(BaseModel):
    """Serialized premium token purchase record."""

    id: uuid.UUID
    stripe_checkout_session_id: str  # Stripe checkout session identifier
    stripe_payment_intent_id: str | None = None  # set once Stripe reports it
    quantity: int  # number of packs purchased
    tokens_granted: int  # total tokens credited (quantity * tokens-per-unit)
    amount_total: int | None = None  # amount in the currency's smallest unit, per Stripe
    currency: str | None = None
    status: str  # pending / completed / failed lifecycle state
    completed_at: datetime | None = None  # set when fulfillment succeeds
    created_at: datetime

    # Allow construction directly from the ORM row.
    model_config = ConfigDict(from_attributes=True)
class TokenPurchaseHistoryResponse(BaseModel):
    """Response containing the user's premium token purchases."""

    purchases: list[TokenPurchaseRead]  # newest-first page of purchase records
class TokenStripeStatusResponse(BaseModel):
    """Response describing token-buying availability and current quota."""

    token_buying_enabled: bool  # mirrors STRIPE_TOKEN_BUYING_ENABLED
    premium_tokens_used: int = 0
    premium_tokens_limit: int = 0
    premium_tokens_remaining: int = 0  # max(0, limit - used)

View file

@ -135,9 +135,12 @@ class LLMRouterService:
logger.debug("LLM Router already initialized, skipping")
return
# Build model list from global configs
auto_configs = [
c for c in global_configs if c.get("billing_tier", "free") != "premium"
]
model_list = []
for config in global_configs:
for config in auto_configs:
deployment = cls._config_to_deployment(config)
if deployment:
model_list.append(deployment)

View file

@ -0,0 +1,621 @@
"""
Atomic token quota service for anonymous and registered users.
Provides reserve/finalize/release/get_usage operations with race-safe
implementation using Redis Lua scripts (anonymous) and Postgres row locks
(registered premium).
"""
from __future__ import annotations
import hashlib
import logging
from enum import StrEnum
from typing import Any
import redis.asyncio as aioredis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
logger = logging.getLogger(__name__)
class QuotaScope(StrEnum):
    """Identity scope a quota applies to: anonymous (Redis-backed) or premium (Postgres-backed)."""

    ANONYMOUS = "anonymous"
    PREMIUM = "premium"
class QuotaStatus(StrEnum):
    """Coarse quota state: OK, past the warning threshold, or exhausted."""

    OK = "ok"
    WARNING = "warning"
    BLOCKED = "blocked"
class QuotaResult:
    """Snapshot of a quota operation's outcome (reserve/finalize/release/usage)."""

    __slots__ = ("allowed", "limit", "remaining", "reserved", "status", "used")

    def __init__(
        self,
        allowed: bool,
        status: QuotaStatus,
        used: int,
        limit: int,
        reserved: int = 0,
        remaining: int = 0,
    ):
        # Whether the requested operation may proceed.
        self.allowed = allowed
        # OK / WARNING / BLOCKED state for UI messaging.
        self.status = status
        # Tokens already consumed against the limit.
        self.used = used
        # Hard token ceiling for this identity.
        self.limit = limit
        # Tokens currently held by in-flight reservations.
        self.reserved = reserved
        # Tokens still available, as computed by the caller.
        self.remaining = remaining

    def to_dict(self) -> dict[str, Any]:
        """Serialize for API payloads; the status enum collapses to its string value."""
        field_order = ("allowed", "status", "used", "limit", "reserved", "remaining")
        payload = {name: getattr(self, name) for name in field_order}
        payload["status"] = self.status.value
        return payload
# ---------------------------------------------------------------------------
# Redis Lua scripts for atomic anonymous quota operations
# ---------------------------------------------------------------------------
# KEYS[1] = quota key (e.g. "anon_quota:session:<session_id>")
# ARGV[1] = reserve_tokens
# ARGV[2] = limit
# ARGV[3] = warning_threshold
# ARGV[4] = request_id
# ARGV[5] = ttl_seconds
# Returns: [allowed(0/1), status("ok"/"warning"/"blocked"), used, reserved]
_RESERVE_LUA = """
local key = KEYS[1]
local reserve = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local warning = tonumber(ARGV[3])
local req_id = ARGV[4]
local ttl = tonumber(ARGV[5])
local used = tonumber(redis.call('HGET', key, 'used') or '0')
local reserved = tonumber(redis.call('HGET', key, 'reserved') or '0')
local effective = used + reserved + reserve
if effective > limit then
return {0, 'blocked', used, reserved}
end
redis.call('HINCRBY', key, 'reserved', reserve)
redis.call('HSET', key, 'req:' .. req_id, reserve)
redis.call('EXPIRE', key, ttl)
local new_reserved = reserved + reserve
local status = 'ok'
if (used + new_reserved) >= warning then
status = 'warning'
end
return {1, status, used, new_reserved}
"""
# KEYS[1] = quota key
# ARGV[1] = request_id
# ARGV[2] = actual_tokens
# ARGV[3] = warning_threshold
# Returns: [used, reserved, status]
_FINALIZE_LUA = """
local key = KEYS[1]
local req_id = ARGV[1]
local actual = tonumber(ARGV[2])
local warning = tonumber(ARGV[3])
local orig_reserve = tonumber(redis.call('HGET', key, 'req:' .. req_id) or '0')
if orig_reserve == 0 then
return {tonumber(redis.call('HGET', key, 'used') or '0'), tonumber(redis.call('HGET', key, 'reserved') or '0'), 'ok'}
end
redis.call('HDEL', key, 'req:' .. req_id)
redis.call('HINCRBY', key, 'reserved', -orig_reserve)
redis.call('HINCRBY', key, 'used', actual)
local used = tonumber(redis.call('HGET', key, 'used') or '0')
local reserved = tonumber(redis.call('HGET', key, 'reserved') or '0')
local status = 'ok'
if used >= warning then
status = 'warning'
end
return {used, reserved, status}
"""
# KEYS[1] = quota key
# ARGV[1] = request_id
# Returns: 1 if released, 0 if not found
_RELEASE_LUA = """
local key = KEYS[1]
local req_id = ARGV[1]
local orig_reserve = tonumber(redis.call('HGET', key, 'req:' .. req_id) or '0')
if orig_reserve == 0 then
return 0
end
redis.call('HDEL', key, 'req:' .. req_id)
redis.call('HINCRBY', key, 'reserved', -orig_reserve)
return 1
"""
def _get_anon_redis() -> aioredis.Redis:
    """Open a fresh async Redis client (decoded responses) for quota keys."""
    redis_url = config.REDIS_APP_URL
    return aioredis.from_url(redis_url, decode_responses=True)
def compute_anon_identity_key(
    session_id: str,
    ip_hash: str | None = None,
) -> str:
    """Return the Redis hash key tracking quota for one anonymous session.

    The signed session cookie is the primary identity. ``ip_hash`` is
    accepted for interface symmetry but unused here: IP usage lives under a
    separate key (see ``compute_ip_quota_key``) so cookie-reset evasion is
    still caught.
    """
    return "anon_quota:session:" + session_id
def compute_ip_quota_key(ip_address: str) -> str:
    """Build IP-only quota key. UA is excluded so rotating User-Agent cannot bypass limits."""
    digest = hashlib.sha256(ip_address.encode()).hexdigest()
    return f"anon_quota:ip:{digest[:16]}"
# ---------------------------------------------------------------------------
# Concurrent stream limiter (per-IP)
# ---------------------------------------------------------------------------
# Atomic acquire: returns 1 if slot acquired, 0 if at capacity.
# KEYS[1] = stream counter key ARGV[1] = max_concurrent ARGV[2] = safety_ttl
_ACQUIRE_STREAM_LUA = """
local key = KEYS[1]
local max_c = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local cur = tonumber(redis.call('GET', key) or '0')
if cur >= max_c then
return 0
end
redis.call('INCR', key)
redis.call('EXPIRE', key, ttl)
return 1
"""
# Atomic release: DECR with floor at 0
_RELEASE_STREAM_LUA = """
local key = KEYS[1]
local cur = tonumber(redis.call('GET', key) or '0')
if cur <= 0 then
return 0
end
redis.call('DECR', key)
return 1
"""
def compute_stream_slot_key(ip_address: str) -> str:
    """Return the Redis counter key for this IP's concurrent stream slots."""
    digest = hashlib.sha256(ip_address.encode()).hexdigest()[:16]
    return "anon:streams:" + digest
def compute_request_count_key(ip_address: str) -> str:
    """Return the Redis counter key for this IP's completed-request count."""
    digest = hashlib.sha256(ip_address.encode()).hexdigest()[:16]
    return "anon:reqcount:" + digest
class TokenQuotaService:
    """Unified quota service for anonymous (Redis) and premium (Postgres) scopes.

    Anonymous quotas use atomic Redis Lua scripts keyed by session and IP;
    premium quotas use ``SELECT ... FOR UPDATE`` row locks on the user row.
    Reservations are keyed by ``request_id`` so each in-flight request can be
    finalized (charged for actual usage) or released (rolled back) exactly once.
    """

    # ------------------------------------------------------------------
    # Concurrent stream limiter
    # ------------------------------------------------------------------
    @staticmethod
    async def anon_acquire_stream_slot(
        ip_address: str,
        max_concurrent: int = 2,
        safety_ttl: int = 300,
    ) -> bool:
        """Atomically claim a concurrent-stream slot for this IP.

        Returns True when a slot was acquired, False at capacity.
        ``safety_ttl`` bounds leakage if a release is ever missed.
        """
        key = compute_stream_slot_key(ip_address)
        r = _get_anon_redis()
        try:
            result = await r.eval(
                _ACQUIRE_STREAM_LUA, 1, key, str(max_concurrent), str(safety_ttl)
            )
            return bool(result)
        finally:
            await r.aclose()

    @staticmethod
    async def anon_release_stream_slot(ip_address: str) -> None:
        """Release one stream slot for this IP (the counter is floored at 0)."""
        key = compute_stream_slot_key(ip_address)
        r = _get_anon_redis()
        try:
            await r.eval(_RELEASE_STREAM_LUA, 1, key)
        finally:
            await r.aclose()

    # ------------------------------------------------------------------
    # Per-IP request counter (for CAPTCHA triggering)
    # ------------------------------------------------------------------
    @staticmethod
    async def anon_increment_request_count(ip_address: str, ttl: int = 86400) -> int:
        """Increment and return current request count for this IP. TTL resets daily."""
        key = compute_request_count_key(ip_address)
        r = _get_anon_redis()
        try:
            pipe = r.pipeline()
            pipe.incr(key)
            pipe.expire(key, ttl)
            results = await pipe.execute()
            return int(results[0])
        finally:
            await r.aclose()

    @staticmethod
    async def anon_get_request_count(ip_address: str) -> int:
        """Return this IP's current completed-request count (0 when unset)."""
        key = compute_request_count_key(ip_address)
        r = _get_anon_redis()
        try:
            val = await r.get(key)
            return int(val) if val else 0
        finally:
            await r.aclose()

    @staticmethod
    async def anon_reset_request_count(ip_address: str) -> None:
        """Delete this IP's request counter (callers reset it around CAPTCHA checks)."""
        key = compute_request_count_key(ip_address)
        r = _get_anon_redis()
        try:
            await r.delete(key)
        finally:
            await r.aclose()

    # ------------------------------------------------------------------
    # Anonymous (Redis-backed)
    # ------------------------------------------------------------------
    @staticmethod
    async def anon_reserve(
        session_key: str,
        ip_key: str | None,
        request_id: str,
        reserve_tokens: int,
    ) -> QuotaResult:
        """Atomically reserve tokens against both the session and IP budgets.

        All-or-nothing: if either budget rejects the reservation, neither key
        keeps one. The reservation is recorded under ``request_id`` for later
        ``anon_finalize`` / ``anon_release``.
        """
        limit = config.ANON_TOKEN_LIMIT
        warning = config.ANON_TOKEN_WARNING_THRESHOLD
        ttl = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
        r = _get_anon_redis()
        try:
            result = await r.eval(
                _RESERVE_LUA,
                1,
                session_key,
                str(reserve_tokens),
                str(limit),
                str(warning),
                request_id,
                str(ttl),
            )
            allowed = bool(result[0])
            status_str = result[1] if isinstance(result[1], str) else result[1].decode()
            used = int(result[2])
            reserved = int(result[3])
            # Only touch the IP budget once the session reservation succeeded.
            # Reserving on the IP key for an already-rejected request (the old
            # behavior) leaked a reservation that was never finalized or
            # released, slowly exhausting the IP budget.
            if allowed and ip_key:
                ip_result = await r.eval(
                    _RESERVE_LUA,
                    1,
                    ip_key,
                    str(reserve_tokens),
                    str(limit),
                    str(warning),
                    request_id,
                    str(ttl),
                )
                if not bool(ip_result[0]):
                    # IP budget exhausted: roll back the session reservation.
                    await r.eval(_RELEASE_LUA, 1, session_key, request_id)
                    allowed = False
                    status_str = "blocked"
                    used = max(used, int(ip_result[2]))
            status = QuotaStatus(status_str)
            remaining = max(0, limit - used - reserved)
            return QuotaResult(
                allowed=allowed,
                status=status,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        finally:
            await r.aclose()

    @staticmethod
    async def anon_finalize(
        session_key: str,
        ip_key: str | None,
        request_id: str,
        actual_tokens: int,
    ) -> QuotaResult:
        """Convert the reservation for ``request_id`` into actual usage.

        Idempotent per request_id (the Lua script no-ops when the reservation
        is gone). The session key's counters drive the returned result; the
        IP key is finalized for its side effects only.
        """
        warning = config.ANON_TOKEN_WARNING_THRESHOLD
        limit = config.ANON_TOKEN_LIMIT
        r = _get_anon_redis()
        try:
            result = await r.eval(
                _FINALIZE_LUA,
                1,
                session_key,
                request_id,
                str(actual_tokens),
                str(warning),
            )
            used = int(result[0])
            reserved = int(result[1])
            status_str = result[2] if isinstance(result[2], str) else result[2].decode()
            if ip_key:
                await r.eval(
                    _FINALIZE_LUA,
                    1,
                    ip_key,
                    request_id,
                    str(actual_tokens),
                    str(warning),
                )
            status = QuotaStatus(status_str)
            remaining = max(0, limit - used - reserved)
            return QuotaResult(
                allowed=True,
                status=status,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        finally:
            await r.aclose()

    @staticmethod
    async def anon_release(
        session_key: str,
        ip_key: str | None,
        request_id: str,
    ) -> None:
        """Roll back the reservation for ``request_id`` on both keys (no charge)."""
        r = _get_anon_redis()
        try:
            await r.eval(_RELEASE_LUA, 1, session_key, request_id)
            if ip_key:
                await r.eval(_RELEASE_LUA, 1, ip_key, request_id)
        finally:
            await r.aclose()

    @staticmethod
    async def anon_get_usage(session_key: str) -> QuotaResult:
        """Read-only snapshot of the session's quota state (no reservation)."""
        limit = config.ANON_TOKEN_LIMIT
        warning = config.ANON_TOKEN_WARNING_THRESHOLD
        r = _get_anon_redis()
        try:
            data = await r.hgetall(session_key)
            used = int(data.get("used", 0))
            reserved = int(data.get("reserved", 0))
            remaining = max(0, limit - used - reserved)
            if used >= limit:
                status = QuotaStatus.BLOCKED
            elif used >= warning:
                status = QuotaStatus.WARNING
            else:
                status = QuotaStatus.OK
            return QuotaResult(
                allowed=used < limit,
                status=status,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        finally:
            await r.aclose()

    # ------------------------------------------------------------------
    # Premium (Postgres-backed)
    # ------------------------------------------------------------------
    @staticmethod
    async def premium_reserve(
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        reserve_tokens: int,
    ) -> QuotaResult:
        """Reserve premium tokens under a FOR UPDATE row lock on the user.

        ``request_id`` is unused here (callers track the reserved amount
        themselves) but kept for interface symmetry with the anonymous path.
        Commits on success; rolls back when the budget is exhausted.
        """
        from app.db import User

        user = (
            (
                await db_session.execute(
                    select(User).where(User.id == user_id).with_for_update(of=User)
                )
            )
            .unique()
            .scalar_one_or_none()
        )
        if user is None:
            return QuotaResult(
                allowed=False,
                status=QuotaStatus.BLOCKED,
                used=0,
                limit=0,
            )
        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        effective = used + reserved + reserve_tokens
        if effective > limit:
            remaining = max(0, limit - used - reserved)
            # Release the row lock without persisting anything.
            await db_session.rollback()
            return QuotaResult(
                allowed=False,
                status=QuotaStatus.BLOCKED,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        user.premium_tokens_reserved = reserved + reserve_tokens
        await db_session.commit()
        new_reserved = reserved + reserve_tokens
        remaining = max(0, limit - used - new_reserved)
        # Warning threshold is 80% of the limit for the premium scope.
        warning_threshold = int(limit * 0.8)
        if (used + new_reserved) >= limit:
            status = QuotaStatus.BLOCKED
        elif (used + new_reserved) >= warning_threshold:
            status = QuotaStatus.WARNING
        else:
            status = QuotaStatus.OK
        return QuotaResult(
            allowed=True,
            status=status,
            used=used,
            limit=limit,
            reserved=new_reserved,
            remaining=remaining,
        )

    @staticmethod
    async def premium_finalize(
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        actual_tokens: int,
        reserved_tokens: int,
    ) -> QuotaResult:
        """Convert ``reserved_tokens`` into ``actual_tokens`` of real usage.

        ``request_id`` is unused (interface symmetry). The user row is locked
        FOR UPDATE; the reserved counter is floored at zero.
        """
        from app.db import User

        user = (
            (
                await db_session.execute(
                    select(User).where(User.id == user_id).with_for_update(of=User)
                )
            )
            .unique()
            .scalar_one_or_none()
        )
        if user is None:
            return QuotaResult(
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )
        user.premium_tokens_reserved = max(
            0, user.premium_tokens_reserved - reserved_tokens
        )
        user.premium_tokens_used = user.premium_tokens_used + actual_tokens
        await db_session.commit()
        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        remaining = max(0, limit - used - reserved)
        warning_threshold = int(limit * 0.8)
        if used >= limit:
            status = QuotaStatus.BLOCKED
        elif used >= warning_threshold:
            status = QuotaStatus.WARNING
        else:
            status = QuotaStatus.OK
        return QuotaResult(
            allowed=True,
            status=status,
            used=used,
            limit=limit,
            reserved=reserved,
            remaining=remaining,
        )

    @staticmethod
    async def premium_release(
        db_session: AsyncSession,
        user_id: Any,
        reserved_tokens: int,
    ) -> None:
        """Roll back ``reserved_tokens`` without charging usage (floored at 0)."""
        from app.db import User

        user = (
            (
                await db_session.execute(
                    select(User).where(User.id == user_id).with_for_update(of=User)
                )
            )
            .unique()
            .scalar_one_or_none()
        )
        if user is not None:
            user.premium_tokens_reserved = max(
                0, user.premium_tokens_reserved - reserved_tokens
            )
            await db_session.commit()

    @staticmethod
    async def premium_get_usage(
        db_session: AsyncSession,
        user_id: Any,
    ) -> QuotaResult:
        """Read-only snapshot of the user's premium quota (no lock, no write)."""
        from app.db import User

        user = (
            (await db_session.execute(select(User).where(User.id == user_id)))
            .unique()
            .scalar_one_or_none()
        )
        if user is None:
            return QuotaResult(
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )
        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        remaining = max(0, limit - used - reserved)
        warning_threshold = int(limit * 0.8)
        if used >= limit:
            status = QuotaStatus.BLOCKED
        elif used >= warning_threshold:
            status = QuotaStatus.WARNING
        else:
            status = QuotaStatus.OK
        return QuotaResult(
            allowed=used < limit,
            status=status,
            used=used,
            limit=limit,
            reserved=reserved,
            remaining=remaining,
        )

View file

@ -0,0 +1,52 @@
"""Cloudflare Turnstile CAPTCHA verification service."""
from __future__ import annotations
import logging
import httpx
from app.config import config
logger = logging.getLogger(__name__)
TURNSTILE_VERIFY_URL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
async def verify_turnstile_token(token: str, remote_ip: str | None = None) -> bool:
    """Check a Turnstile response token against Cloudflare's siteverify API.

    Acts as a simple boolean gate for callers: returns True when the
    challenge was solved, and also True (fail-open) when Turnstile is
    disabled or the secret key is unset. Returns False — never raises —
    on invalid tokens, network errors, or malformed responses.
    """
    if not config.TURNSTILE_ENABLED:
        return True

    secret = config.TURNSTILE_SECRET_KEY
    if not secret:
        # Fail-open when misconfigured, but make it visible in the logs.
        logger.warning("TURNSTILE_SECRET_KEY is not set; skipping verification")
        return True

    form_data: dict[str, str] = {"secret": secret, "response": token}
    if remote_ip:
        form_data["remoteip"] = remote_ip

    try:
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.post(TURNSTILE_VERIFY_URL, data=form_data)
            response.raise_for_status()
            body = response.json()
            verified = body.get("success", False)
            if not verified:
                logger.info(
                    "Turnstile verification failed: %s",
                    body.get("error-codes", []),
                )
            return bool(verified)
    except Exception:
        # Any transport/HTTP/parse failure is treated as "not verified".
        logger.exception("Turnstile verification request failed")
        return False

View file

@ -1,4 +1,4 @@
"""Reconcile pending Stripe page purchases that might miss webhook fulfillment."""
"""Reconcile pending Stripe purchases that might miss webhook fulfillment."""
from __future__ import annotations
@ -11,7 +11,12 @@ from stripe import StripeClient, StripeError
from app.celery_app import celery_app
from app.config import config
from app.db import PagePurchase, PagePurchaseStatus
from app.db import (
PagePurchase,
PagePurchaseStatus,
PremiumTokenPurchase,
PremiumTokenPurchaseStatus,
)
from app.routes import stripe_routes
from app.tasks.celery_tasks import get_celery_session_maker
@ -126,7 +131,108 @@ async def _reconcile_pending_page_purchases() -> None:
await db_session.rollback()
logger.info(
"Stripe reconciliation completed. fulfilled=%s failed=%s checked=%s",
"Stripe page reconciliation completed. fulfilled=%s failed=%s checked=%s",
fulfilled_count,
failed_count,
len(pending_purchases),
)
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
def reconcile_pending_stripe_token_purchases_task():
    """Recover paid token purchases that were left pending due to missed webhook handling.

    Synchronous Celery entry point that drives the async reconciliation
    coroutine to completion on a private event loop.
    """
    # asyncio.run() creates a fresh event loop, runs the coroutine, and
    # performs full cleanup (cancelling leftovers and shutting down async
    # generators) — the manual new_event_loop()/set_event_loop()/close()
    # pattern skipped that shutdown step.
    asyncio.run(_reconcile_pending_token_purchases())
async def _reconcile_pending_token_purchases() -> None:
    """Reconcile stale pending token purchases against Stripe source of truth."""
    # No Stripe client configured -> nothing to reconcile.
    stripe_client = get_stripe_client()
    if stripe_client is None:
        return
    # Clamp config to sane minimums: non-negative lookback, batch of at least 1.
    lookback_minutes = max(config.STRIPE_RECONCILIATION_LOOKBACK_MINUTES, 0)
    batch_size = max(config.STRIPE_RECONCILIATION_BATCH_SIZE, 1)
    # Only consider purchases created before the cutoff so in-flight
    # checkouts still within the webhook window are left alone.
    cutoff = datetime.now(UTC) - timedelta(minutes=lookback_minutes)
    async with get_celery_session_maker()() as db_session:
        # Oldest-first batch of PENDING token purchases past the cutoff.
        pending_purchases = (
            (
                await db_session.execute(
                    select(PremiumTokenPurchase)
                    .where(
                        PremiumTokenPurchase.status
                        == PremiumTokenPurchaseStatus.PENDING,
                        PremiumTokenPurchase.created_at <= cutoff,
                    )
                    .order_by(PremiumTokenPurchase.created_at.asc())
                    .limit(batch_size)
                )
            )
            .scalars()
            .all()
        )
        if not pending_purchases:
            logger.debug(
                "Stripe token reconciliation found no pending purchases older than %s minutes.",
                lookback_minutes,
            )
            return
        logger.info(
            "Stripe token reconciliation checking %s pending purchases (cutoff=%s, batch=%s).",
            len(pending_purchases),
            lookback_minutes,
            batch_size,
        )
        fulfilled_count = 0
        failed_count = 0
        for purchase in pending_purchases:
            checkout_session_id = purchase.stripe_checkout_session_id
            try:
                # Stripe checkout session is the source of truth for payment state.
                checkout_session = stripe_client.v1.checkout.sessions.retrieve(
                    checkout_session_id
                )
            except StripeError:
                # Retrieval failure: roll back and move on; this purchase will
                # be retried on the next reconciliation run.
                logger.exception(
                    "Stripe token reconciliation failed to retrieve checkout session %s",
                    checkout_session_id,
                )
                await db_session.rollback()
                continue
            # getattr() guards against SDK objects missing these attributes.
            payment_status = getattr(checkout_session, "payment_status", None)
            session_status = getattr(checkout_session, "status", None)
            try:
                if payment_status in {"paid", "no_payment_required"}:
                    # Paid but still pending locally -> fulfill now.
                    await stripe_routes._fulfill_completed_token_purchase(
                        db_session, checkout_session
                    )
                    fulfilled_count += 1
                elif session_status == "expired":
                    # Checkout expired without payment -> mark failed.
                    await stripe_routes._mark_token_purchase_failed(
                        db_session, str(checkout_session.id)
                    )
                    failed_count += 1
            except Exception:
                # Best-effort loop: log, roll back, continue with next purchase.
                logger.exception(
                    "Stripe token reconciliation failed while processing checkout session %s",
                    checkout_session_id,
                )
                await db_session.rollback()
        logger.info(
            "Stripe token reconciliation completed. fulfilled=%s failed=%s checked=%s",
            fulfilled_count,
            failed_count,
            len(pending_purchases),
            # NOTE(review): this call appears truncated in the visible source;
            # the closing ")" is expected to follow immediately after.

View file

@ -1175,6 +1175,10 @@ async def stream_new_chat(
accumulator = start_turn()
# Premium quota tracking state
_premium_reserved = 0
_premium_request_id: str | None = None
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
@ -1220,6 +1224,34 @@ async def stream_new_chat(
llm_config_id,
)
# Premium quota reservation
if agent_config and agent_config.is_premium and user_id:
import uuid as _uuid
from app.config import config as _app_config
from app.services.token_quota_service import TokenQuotaService
_premium_request_id = _uuid.uuid4().hex[:16]
reserve_amount = min(
agent_config.quota_reserve_tokens
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
)
async with shielded_async_session() as quota_session:
quota_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
reserve_tokens=reserve_amount,
)
_premium_reserved = reserve_amount
if not quota_result.allowed:
yield streaming_service.format_error(
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
)
yield streaming_service.format_done()
return
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
yield streaming_service.format_done()
@ -1626,6 +1658,26 @@ async def stream_new_chat(
chat_id, generated_title
)
# Finalize premium quota with actual tokens
if _premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
actual_tokens=accumulator.grand_total,
reserved_tokens=_premium_reserved,
)
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s",
user_id,
exc_info=True,
)
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
@ -1700,6 +1752,23 @@ async def stream_new_chat(
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
# Release premium reservation if not finalized
if _premium_request_id and _premium_reserved > 0 and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=UUID(user_id),
reserved_tokens=_premium_reserved,
)
_premium_reserved = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to release premium quota for user %s", user_id
)
try:
await session.rollback()
await clear_ai_responding(session, chat_id)