mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 20:03:30 +02:00
feat: no login experience and prem tokens
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
This commit is contained in:
parent
87452bb315
commit
ff4e0f9b62
68 changed files with 5914 additions and 121 deletions
|
|
@ -161,6 +161,7 @@ async def create_surfsense_deep_agent(
|
|||
firecrawl_api_key: str | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -463,6 +464,7 @@ async def create_surfsense_deep_agent(
|
|||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
anon_session_id=anon_session_id,
|
||||
),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
|
|
@ -109,6 +109,12 @@ class AgentConfig:
|
|||
# Auto mode flag
|
||||
is_auto_mode: bool = False
|
||||
|
||||
# Token quota and policy
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
anonymous_enabled: bool = False
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_auto_mode(cls) -> "AgentConfig":
|
||||
"""
|
||||
|
|
@ -130,6 +136,10 @@ class AgentConfig:
|
|||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Fastest)",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -158,6 +168,10 @@ class AgentConfig:
|
|||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -195,6 +209,10 @@ class AgentConfig:
|
|||
config_id=yaml_config.get("id"),
|
||||
config_name=yaml_config.get("name"),
|
||||
is_auto_mode=False,
|
||||
billing_tier=yaml_config.get("billing_tier", "free"),
|
||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -819,6 +819,34 @@ async def build_scoped_filesystem(
|
|||
return files, doc_id_to_path
|
||||
|
||||
|
||||
def _build_anon_scoped_filesystem(
|
||||
documents: Sequence[dict[str, Any]],
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Build a scoped filesystem for anonymous documents without DB queries.
|
||||
|
||||
Anonymous uploads have no folders, so all files go under /documents.
|
||||
"""
|
||||
files: dict[str, dict[str, str]] = {}
|
||||
for document in documents:
|
||||
doc_meta = document.get("document") or {}
|
||||
title = str(doc_meta.get("title") or "untitled")
|
||||
file_name = _safe_filename(title)
|
||||
path = f"/documents/{file_name}"
|
||||
if path in files:
|
||||
doc_id = doc_meta.get("id", "dup")
|
||||
stem = file_name.removesuffix(".xml")
|
||||
path = f"/documents/{stem} ({doc_id}).xml"
|
||||
matched_ids = set(document.get("matched_chunk_ids") or [])
|
||||
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
||||
files[path] = {
|
||||
"content": xml_content.split("\n"),
|
||||
"encoding": "utf-8",
|
||||
"created_at": "",
|
||||
"modified_at": "",
|
||||
}
|
||||
return files
|
||||
|
||||
|
||||
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
|
||||
|
||||
|
|
@ -833,6 +861,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
available_document_types: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
self.search_space_id = search_space_id
|
||||
|
|
@ -840,6 +869,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
self.available_document_types = available_document_types
|
||||
self.top_k = top_k
|
||||
self.mentioned_document_ids = mentioned_document_ids or []
|
||||
self.anon_session_id = anon_session_id
|
||||
|
||||
async def _plan_search_inputs(
|
||||
self,
|
||||
|
|
@ -913,6 +943,50 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
pass
|
||||
return asyncio.run(self.abefore_agent(state, runtime))
|
||||
|
||||
async def _load_anon_document(self) -> dict[str, Any] | None:
|
||||
"""Load the anonymous user's uploaded document from Redis."""
|
||||
if not self.anon_session_id:
|
||||
return None
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from app.config import config
|
||||
|
||||
redis_client = aioredis.from_url(
|
||||
config.REDIS_APP_URL, decode_responses=True
|
||||
)
|
||||
try:
|
||||
redis_key = f"anon:doc:{self.anon_session_id}"
|
||||
data = await redis_client.get(redis_key)
|
||||
if not data:
|
||||
return None
|
||||
doc = json.loads(data)
|
||||
return {
|
||||
"document_id": -1,
|
||||
"content": doc.get("content", ""),
|
||||
"score": 1.0,
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_id": -1,
|
||||
"content": doc.get("content", ""),
|
||||
}
|
||||
],
|
||||
"matched_chunk_ids": [-1],
|
||||
"document": {
|
||||
"id": -1,
|
||||
"title": doc.get("filename", "uploaded_document"),
|
||||
"document_type": "FILE",
|
||||
"metadata": {"source": "anonymous_upload"},
|
||||
},
|
||||
"source": "FILE",
|
||||
"_user_mentioned": True,
|
||||
}
|
||||
finally:
|
||||
await redis_client.aclose()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load anonymous document from Redis: %s", exc)
|
||||
return None
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
|
|
@ -937,6 +1011,35 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
||||
existing_files = state.get("files")
|
||||
|
||||
# --- Anonymous session: load Redis doc and skip DB queries ---
|
||||
if self.anon_session_id:
|
||||
merged: list[dict[str, Any]] = []
|
||||
anon_doc = await self._load_anon_document()
|
||||
if anon_doc:
|
||||
merged.append(anon_doc)
|
||||
|
||||
if merged:
|
||||
new_files = _build_anon_scoped_filesystem(merged)
|
||||
mentioned_paths = set(new_files.keys())
|
||||
else:
|
||||
new_files = {}
|
||||
mentioned_paths = set()
|
||||
|
||||
ai_msg, tool_msg = _build_synthetic_ls(
|
||||
existing_files,
|
||||
new_files,
|
||||
mentioned_paths=mentioned_paths,
|
||||
)
|
||||
if t0 is not None:
|
||||
_perf_log.info(
|
||||
"[kb_fs_middleware] anon completed in %.3fs new_files=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
len(new_files),
|
||||
)
|
||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
||||
|
||||
# --- Authenticated session: full KB search ---
|
||||
(
|
||||
planned_query,
|
||||
start_date,
|
||||
|
|
@ -954,8 +1057,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
document_ids=self.mentioned_document_ids,
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
# Clear after first turn so they are not re-fetched on subsequent
|
||||
# messages within the same agent instance.
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
# --- 2. Run KB search (recency browse or hybrid) ---
|
||||
|
|
@ -983,26 +1084,24 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
|
||||
seen_doc_ids: set[int] = set()
|
||||
merged: list[dict[str, Any]] = []
|
||||
merged_auth: list[dict[str, Any]] = []
|
||||
for doc in mentioned_results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if doc_id is not None:
|
||||
seen_doc_ids.add(doc_id)
|
||||
merged.append(doc)
|
||||
merged_auth.append(doc)
|
||||
for doc in search_results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if doc_id is not None and doc_id in seen_doc_ids:
|
||||
continue
|
||||
merged.append(doc)
|
||||
merged_auth.append(doc)
|
||||
|
||||
# --- 4. Build scoped filesystem ---
|
||||
new_files, doc_id_to_path = await build_scoped_filesystem(
|
||||
documents=merged,
|
||||
documents=merged_auth,
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
|
||||
# Identify which paths belong to user-mentioned documents using
|
||||
# the authoritative doc_id -> path mapping (no title guessing).
|
||||
mentioned_doc_ids = {
|
||||
(d.get("document") or {}).get("id") for d in mentioned_results
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from limits.storage import MemoryStorage
|
||||
from slowapi import Limiter
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
from slowapi.util import get_remote_address
|
||||
|
|
@ -36,6 +34,7 @@ from app.config import (
|
|||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
from app.rate_limiter import limiter
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
|
|
@ -54,17 +53,7 @@ rate_limit_logger = logging.getLogger("surfsense.rate_limit")
|
|||
# Uses the same Redis instance as Celery for zero additional infrastructure.
|
||||
# Protects auth endpoints from brute force and user enumeration attacks.
|
||||
|
||||
# SlowAPI limiter — provides default rate limits (1024/min) for ALL routes
|
||||
# via the ASGI middleware. This is the general safety net.
|
||||
# in_memory_fallback ensures requests are still served (with per-worker
|
||||
# in-memory limiting) when Redis is unreachable, instead of hanging.
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
storage_uri=config.REDIS_APP_URL,
|
||||
default_limits=["1024/minute"],
|
||||
in_memory_fallback_enabled=True,
|
||||
in_memory_fallback=[MemoryStorage()],
|
||||
)
|
||||
# limiter is imported from app.rate_limiter (shared module to avoid circular imports)
|
||||
|
||||
|
||||
def _get_request_id(request: Request) -> str:
|
||||
|
|
@ -126,6 +115,39 @@ def _surfsense_error_handler(request: Request, exc: SurfSenseError) -> JSONRespo
|
|||
def _http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
"""Wrap FastAPI/Starlette HTTPExceptions into the standard envelope."""
|
||||
rid = _get_request_id(request)
|
||||
|
||||
# Structured dict details (e.g. {"code": "CAPTCHA_REQUIRED", "message": "..."})
|
||||
# are preserved so the frontend can parse them.
|
||||
if isinstance(exc.detail, dict):
|
||||
err_code = exc.detail.get("code", _status_to_code(exc.status_code))
|
||||
message = exc.detail.get("message", str(exc.detail))
|
||||
if exc.status_code >= 500:
|
||||
_error_logger.error(
|
||||
"[%s] %s - HTTPException %d: %s",
|
||||
rid,
|
||||
request.url.path,
|
||||
exc.status_code,
|
||||
message,
|
||||
)
|
||||
message = GENERIC_5XX_MESSAGE
|
||||
err_code = "INTERNAL_ERROR"
|
||||
body = {
|
||||
"error": {
|
||||
"code": err_code,
|
||||
"message": message,
|
||||
"status": exc.status_code,
|
||||
"request_id": rid,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"report_url": ISSUES_URL,
|
||||
},
|
||||
"detail": exc.detail,
|
||||
}
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=body,
|
||||
headers={"X-Request-ID": rid},
|
||||
)
|
||||
|
||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
if exc.status_code >= 500:
|
||||
_error_logger.error(
|
||||
|
|
@ -663,6 +685,13 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
return response
|
||||
|
||||
|
||||
# Anonymous (no-login) chat routes — mounted at /api/v1/public/anon-chat
|
||||
from app.routes.anonymous_chat_routes import ( # noqa: E402
|
||||
router as anonymous_chat_router,
|
||||
)
|
||||
|
||||
app.include_router(anonymous_chat_router)
|
||||
|
||||
app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -187,4 +187,11 @@ celery_app.conf.beat_schedule = {
|
|||
"expires": 60,
|
||||
},
|
||||
},
|
||||
"reconcile-pending-stripe-token-purchases": {
|
||||
"task": "reconcile_pending_stripe_token_purchases",
|
||||
"schedule": crontab(**stripe_reconciliation_schedule_params),
|
||||
"options": {
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,25 @@ def load_global_llm_configs():
|
|||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_llm_configs", [])
|
||||
configs = data.get("global_llm_configs", [])
|
||||
|
||||
seen_slugs: dict[str, int] = {}
|
||||
for cfg in configs:
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
cfg.setdefault("anonymous_enabled", False)
|
||||
cfg.setdefault("seo_enabled", False)
|
||||
|
||||
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
|
||||
slug = cfg["seo_slug"]
|
||||
if slug in seen_slugs:
|
||||
print(
|
||||
f"Warning: Duplicate seo_slug '{slug}' in global LLM configs "
|
||||
f"(ids {seen_slugs[slug]} and {cfg.get('id')})"
|
||||
)
|
||||
else:
|
||||
seen_slugs[slug] = cfg.get("id", 0)
|
||||
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global LLM configs: {e}")
|
||||
return []
|
||||
|
|
@ -307,6 +325,36 @@ class Config:
|
|||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||
)
|
||||
|
||||
# Premium token quota settings
|
||||
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "5000000"))
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
|
||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
|
||||
# Anonymous / no-login mode settings
|
||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "1000000"))
|
||||
ANON_TOKEN_WARNING_THRESHOLD = int(
|
||||
os.getenv("ANON_TOKEN_WARNING_THRESHOLD", "800000")
|
||||
)
|
||||
ANON_TOKEN_QUOTA_TTL_DAYS = int(os.getenv("ANON_TOKEN_QUOTA_TTL_DAYS", "30"))
|
||||
ANON_MAX_UPLOAD_SIZE_MB = int(os.getenv("ANON_MAX_UPLOAD_SIZE_MB", "5"))
|
||||
|
||||
# Default quota reserve tokens when not specified per-model
|
||||
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
|
||||
|
||||
# Abuse prevention: concurrent stream cap and CAPTCHA
|
||||
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
|
||||
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
|
||||
os.getenv("ANON_CAPTCHA_REQUEST_THRESHOLD", "5")
|
||||
)
|
||||
|
||||
# Cloudflare Turnstile CAPTCHA
|
||||
TURNSTILE_ENABLED = os.getenv("TURNSTILE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
TURNSTILE_SECRET_KEY = os.getenv("TURNSTILE_SECRET_KEY", "")
|
||||
|
||||
# Auth
|
||||
AUTH_TYPE = os.getenv("AUTH_TYPE")
|
||||
REGISTRATION_ENABLED = os.getenv("REGISTRATION_ENABLED", "TRUE").upper() == "TRUE"
|
||||
|
|
|
|||
|
|
@ -48,6 +48,11 @@ global_llm_configs:
|
|||
- id: -1
|
||||
name: "Global GPT-4 Turbo"
|
||||
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-4-turbo"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
|
|
@ -67,6 +72,11 @@ global_llm_configs:
|
|||
- id: -2
|
||||
name: "Global Claude 3 Opus"
|
||||
description: "Anthropic's most capable model with citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "claude-3-opus"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
|
|
@ -84,6 +94,11 @@ global_llm_configs:
|
|||
- id: -3
|
||||
name: "Global GPT-3.5 Turbo (Fast)"
|
||||
description: "Fast responses without citations for quick queries"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-3.5-turbo-fast"
|
||||
quota_reserve_tokens: 2000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
|
|
@ -101,6 +116,11 @@ global_llm_configs:
|
|||
- id: -4
|
||||
name: "Global DeepSeek Chat (Chinese)"
|
||||
description: "DeepSeek optimized for Chinese language responses"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "deepseek-chat-chinese"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "DEEPSEEK"
|
||||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
|
|
@ -128,6 +148,11 @@ global_llm_configs:
|
|||
- id: -5
|
||||
name: "Global Azure GPT-4o"
|
||||
description: "Azure OpenAI GPT-4o deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4o"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
# model_name format for Azure: azure/<your-deployment-name>
|
||||
model_name: "azure/gpt-4o-deployment"
|
||||
|
|
@ -151,6 +176,11 @@ global_llm_configs:
|
|||
- id: -6
|
||||
name: "Global Azure GPT-4 Turbo"
|
||||
description: "Azure OpenAI GPT-4 Turbo deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4-turbo"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
model_name: "azure/gpt-4-turbo-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
|
|
@ -170,6 +200,11 @@ global_llm_configs:
|
|||
- id: -7
|
||||
name: "Global Groq Llama 3"
|
||||
description: "Ultra-fast Llama 3 70B via Groq"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "groq-llama-3"
|
||||
quota_reserve_tokens: 8000
|
||||
provider: "GROQ"
|
||||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
|
|
@ -187,6 +222,11 @@ global_llm_configs:
|
|||
- id: -8
|
||||
name: "Global MiniMax M2.5"
|
||||
description: "MiniMax M2.5 with 204K context window and competitive pricing"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "minimax-m2.5"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "MINIMAX"
|
||||
model_name: "MiniMax-M2.5"
|
||||
api_key: "your-minimax-api-key-here"
|
||||
|
|
@ -365,3 +405,13 @@ global_vision_llm_configs:
|
|||
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
|
||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||
#
|
||||
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
||||
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
||||
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public.
|
||||
# - seo_title: Optional HTML title tag override for the model's /free/<slug> page.
|
||||
# - seo_description: Optional meta description override for the model's /free/<slug> page.
|
||||
# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement.
|
||||
# Independent of litellm_params.max_tokens. Used by the token quota service.
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from sqlalchemy import (
|
|||
ARRAY,
|
||||
JSON,
|
||||
TIMESTAMP,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
|
|
@ -318,6 +319,12 @@ class PagePurchaseStatus(StrEnum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
class PremiumTokenPurchaseStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
# Centralized configuration for incentive tasks
|
||||
# This makes it easy to add new tasks without changing code in multiple places
|
||||
INCENTIVE_TASKS_CONFIG = {
|
||||
|
|
@ -1739,6 +1746,38 @@ class PagePurchase(Base, TimestampMixin):
|
|||
user = relationship("User", back_populates="page_purchases")
|
||||
|
||||
|
||||
class PremiumTokenPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium token credits."""
|
||||
|
||||
__tablename__ = "premium_token_purchases"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
stripe_checkout_session_id = Column(
|
||||
String(255), nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
|
||||
quantity = Column(Integer, nullable=False)
|
||||
tokens_granted = Column(BigInteger, nullable=False)
|
||||
amount_total = Column(Integer, nullable=True)
|
||||
currency = Column(String(10), nullable=True)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(PremiumTokenPurchaseStatus),
|
||||
nullable=False,
|
||||
default=PremiumTokenPurchaseStatus.PENDING,
|
||||
index=True,
|
||||
)
|
||||
completed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
user = relationship("User", back_populates="premium_token_purchases")
|
||||
|
||||
|
||||
class SearchSpaceRole(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Custom roles that can be defined per search space.
|
||||
|
|
@ -2009,6 +2048,11 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
premium_token_purchases = relationship(
|
||||
"PremiumTokenPurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
|
|
@ -2019,6 +2063,19 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_tokens_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_TOKEN_LIMIT,
|
||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
||||
)
|
||||
premium_tokens_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_tokens_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
# User profile from OAuth
|
||||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
|
@ -2123,6 +2180,11 @@ else:
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
premium_token_purchases = relationship(
|
||||
"PremiumTokenPurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
|
|
@ -2133,6 +2195,19 @@ else:
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_tokens_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_TOKEN_LIMIT,
|
||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
||||
)
|
||||
premium_tokens_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_tokens_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
# User profile (can be set manually for non-OAuth users)
|
||||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
|
|
|||
15
surfsense_backend/app/rate_limiter.py
Normal file
15
surfsense_backend/app/rate_limiter.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Shared SlowAPI limiter instance used by app.py and route modules."""
|
||||
|
||||
from limits.storage import MemoryStorage
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.config import config
|
||||
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
storage_uri=config.REDIS_APP_URL,
|
||||
default_limits=["1024/minute"],
|
||||
in_memory_fallback_enabled=True,
|
||||
in_memory_fallback=[MemoryStorage()],
|
||||
)
|
||||
610
surfsense_backend/app/routes/anonymous_chat_routes.py
Normal file
610
surfsense_backend/app/routes/anonymous_chat_routes.py
Normal file
|
|
@ -0,0 +1,610 @@
|
|||
"""Public API endpoints for anonymous (no-login) chat."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import config
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
DIRECT_CONVERT_EXTENSIONS,
|
||||
PLAINTEXT_EXTENSIONS,
|
||||
)
|
||||
from app.rate_limiter import limiter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/public/anon-chat", tags=["anonymous-chat"])
|
||||
|
||||
ANON_COOKIE_NAME = "surfsense_anon_session"
|
||||
ANON_COOKIE_MAX_AGE = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_or_create_session_id(request: Request, response: Response) -> str:
|
||||
"""Read the signed session cookie or create a new one."""
|
||||
session_id = request.cookies.get(ANON_COOKIE_NAME)
|
||||
if session_id and len(session_id) == 43:
|
||||
return session_id
|
||||
session_id = secrets.token_urlsafe(32)
|
||||
response.set_cookie(
|
||||
key=ANON_COOKIE_NAME,
|
||||
value=session_id,
|
||||
max_age=ANON_COOKIE_MAX_AGE,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
secure=request.url.scheme == "https",
|
||||
path="/",
|
||||
)
|
||||
return session_id
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
return (
|
||||
forwarded.split(",")[0].strip()
|
||||
if forwarded
|
||||
else (request.client.host if request.client else "unknown")
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AnonChatRequest(BaseModel):
|
||||
model_slug: str = Field(..., max_length=100)
|
||||
messages: list[dict[str, Any]] = Field(..., min_length=1)
|
||||
disabled_tools: list[str] | None = None
|
||||
turnstile_token: str | None = None
|
||||
|
||||
|
||||
class AnonQuotaResponse(BaseModel):
|
||||
used: int
|
||||
limit: int
|
||||
remaining: int
|
||||
status: str
|
||||
warning_threshold: int
|
||||
captcha_required: bool = False
|
||||
|
||||
|
||||
class AnonModelResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
model_name: str
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
seo_slug: str | None = None
|
||||
seo_enabled: bool = False
|
||||
seo_title: str | None = None
|
||||
seo_description: str | None = None
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/models", response_model=list[AnonModelResponse])
|
||||
async def list_anonymous_models():
|
||||
"""Return all models enabled for anonymous access."""
|
||||
if not config.NOLOGIN_MODE_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No-login mode is not enabled.",
|
||||
)
|
||||
|
||||
models = []
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("anonymous_enabled", False):
|
||||
models.append(
|
||||
AnonModelResponse(
|
||||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
seo_slug=cfg.get("seo_slug"),
|
||||
seo_enabled=cfg.get("seo_enabled", False),
|
||||
seo_title=cfg.get("seo_title"),
|
||||
seo_description=cfg.get("seo_description"),
|
||||
quota_reserve_tokens=cfg.get("quota_reserve_tokens"),
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/models/{slug}", response_model=AnonModelResponse)
async def get_anonymous_model(slug: str):
    """Return a single model by its SEO slug.

    Raises:
        HTTPException: 503 when no-login mode is disabled; 404 when no
            anonymous-enabled config matches the slug.
    """
    if not config.NOLOGIN_MODE_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No-login mode is not enabled.",
        )

    # First anonymous-enabled config whose slug matches, else None.
    match = next(
        (
            entry
            for entry in config.GLOBAL_LLM_CONFIGS
            if entry.get("anonymous_enabled", False)
            and entry.get("seo_slug") == slug
        ),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail="Model not found")

    tier = match.get("billing_tier", "free")
    return AnonModelResponse(
        id=match.get("id", 0),
        name=match.get("name", ""),
        description=match.get("description"),
        provider=match.get("provider", ""),
        model_name=match.get("model_name", ""),
        billing_tier=tier,
        is_premium=tier == "premium",
        seo_slug=match.get("seo_slug"),
        seo_enabled=match.get("seo_enabled", False),
        seo_title=match.get("seo_title"),
        seo_description=match.get("seo_description"),
        quota_reserve_tokens=match.get("quota_reserve_tokens"),
    )
|
||||
|
||||
|
||||
@router.get("/quota", response_model=AnonQuotaResponse)
@limiter.limit("30/minute")
async def get_anonymous_quota(request: Request, response: Response):
    """Return current token usage for the anonymous session.

    Reports the *stricter* of session and IP buckets so that opening a
    new browser on the same IP doesn't show a misleadingly fresh quota.
    """
    if not config.NOLOGIN_MODE_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No-login mode is not enabled.",
        )

    from app.services.token_quota_service import (
        TokenQuotaService,
        compute_anon_identity_key,
        compute_ip_quota_key,
    )

    ip_address = _get_client_ip(request)
    sid = _get_or_create_session_id(request, response)

    # Look up both buckets: the per-session one and the per-IP one.
    session_usage = await TokenQuotaService.anon_get_usage(
        compute_anon_identity_key(sid)
    )
    ip_usage = await TokenQuotaService.anon_get_usage(compute_ip_quota_key(ip_address))

    # Use whichever bucket has higher usage — that's the real constraint.
    usage = ip_usage if ip_usage.used > session_usage.used else session_usage

    needs_captcha = False
    if config.TURNSTILE_ENABLED:
        # CAPTCHA kicks in once the per-IP completed-request count crosses
        # the configured threshold.
        completed = await TokenQuotaService.anon_get_request_count(ip_address)
        needs_captcha = completed >= config.ANON_CAPTCHA_REQUEST_THRESHOLD

    return AnonQuotaResponse(
        used=usage.used,
        limit=usage.limit,
        remaining=usage.remaining,
        status=usage.status.value,
        warning_threshold=config.ANON_TOKEN_WARNING_THRESHOLD,
        captcha_required=needs_captcha,
    )
|
||||
|
||||
|
||||
@router.post("/stream")
@limiter.limit("15/minute")
async def stream_anonymous_chat(
    body: AnonChatRequest,
    request: Request,
    response: Response,
):
    """Stream a chat response for an anonymous user with quota enforcement.

    Enforcement layers, in order:
      1. No-login mode must be enabled (503 otherwise).
      2. The requested model (by SEO slug) must be anonymous-enabled (404).
      3. Per-IP concurrent stream limit (429 ANON_TOO_MANY_STREAMS).
      4. Optional Turnstile CAPTCHA once the per-IP request count crosses
         the configured threshold (403 CAPTCHA_REQUIRED / CAPTCHA_INVALID).
      5. Token reservation against both session and IP buckets
         (429 ANON_QUOTA_EXCEEDED); the reservation is finalized with the
         actual token usage after streaming, or released on error.

    Returns a Server-Sent-Events ``StreamingResponse``.
    """
    if not config.NOLOGIN_MODE_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No-login mode is not enabled.",
        )

    from app.agents.new_chat.llm_config import (
        AgentConfig,
        create_chat_litellm_from_agent_config,
    )
    from app.services.token_quota_service import (
        TokenQuotaService,
        compute_anon_identity_key,
        compute_ip_quota_key,
    )
    from app.services.turnstile_service import verify_turnstile_token

    # Find the model config by slug
    model_cfg = None
    for cfg in config.GLOBAL_LLM_CONFIGS:
        if (
            cfg.get("anonymous_enabled", False)
            and cfg.get("seo_slug") == body.model_slug
        ):
            model_cfg = cfg
            break

    if model_cfg is None:
        raise HTTPException(
            status_code=404, detail="Model not found or not available for anonymous use"
        )

    client_ip = _get_client_ip(request)

    # --- Concurrent stream limit ---
    slot_acquired = await TokenQuotaService.anon_acquire_stream_slot(
        client_ip, max_concurrent=config.ANON_MAX_CONCURRENT_STREAMS
    )
    if not slot_acquired:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail={
                "code": "ANON_TOO_MANY_STREAMS",
                "message": f"Max {config.ANON_MAX_CONCURRENT_STREAMS} concurrent chats allowed. Please wait for a response to finish.",
            },
        )

    try:
        # --- CAPTCHA enforcement (check count without incrementing; count
        # is bumped only after a successful response in _generate) ---
        if config.TURNSTILE_ENABLED:
            req_count = await TokenQuotaService.anon_get_request_count(client_ip)
            if req_count >= config.ANON_CAPTCHA_REQUEST_THRESHOLD:
                if not body.turnstile_token:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail={
                            "code": "CAPTCHA_REQUIRED",
                            "message": "Please complete the CAPTCHA to continue chatting.",
                        },
                    )
                valid = await verify_turnstile_token(body.turnstile_token, client_ip)
                if not valid:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail={
                            "code": "CAPTCHA_INVALID",
                            "message": "CAPTCHA verification failed. Please try again.",
                        },
                    )
                # A solved CAPTCHA resets the counter so the next threshold
                # window starts fresh.
                await TokenQuotaService.anon_reset_request_count(client_ip)

        # Build identity keys
        session_id = _get_or_create_session_id(request, response)
        session_key = compute_anon_identity_key(session_id)
        ip_key = compute_ip_quota_key(client_ip)

        # Reserve tokens. NOTE: configs may carry an explicit ``None`` for
        # quota_reserve_tokens (the global-config defaults do), and
        # dict.get() returns that None even when a default argument is
        # given — so resolve None explicitly before calling min(), which
        # would otherwise raise TypeError.
        configured_reserve = model_cfg.get("quota_reserve_tokens")
        if configured_reserve is None:
            configured_reserve = config.QUOTA_MAX_RESERVE_PER_CALL
        reserve_amount = min(configured_reserve, config.QUOTA_MAX_RESERVE_PER_CALL)
        request_id = uuid.uuid4().hex[:16]

        quota_result = await TokenQuotaService.anon_reserve(
            session_key=session_key,
            ip_key=ip_key,
            request_id=request_id,
            reserve_tokens=reserve_amount,
        )

        if not quota_result.allowed:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail={
                    "code": "ANON_QUOTA_EXCEEDED",
                    "message": "You've used all your free tokens. Create an account for 5M more!",
                    "used": quota_result.used,
                    "limit": quota_result.limit,
                },
            )

        # Create agent config from YAML
        agent_config = AgentConfig.from_yaml_config(model_cfg)
        llm = create_chat_litellm_from_agent_config(agent_config)
        if not llm:
            # Give back the reservation; no stream will run.
            await TokenQuotaService.anon_release(session_key, ip_key, request_id)
            raise HTTPException(status_code=500, detail="Failed to create LLM instance")

        # Server-side tool allow-list enforcement: anonymous users may only
        # use web_search, minus anything the client disabled.
        anon_allowed_tools = {"web_search"}
        client_disabled = set(body.disabled_tools) if body.disabled_tools else set()
        enabled_for_agent = anon_allowed_tools - client_disabled

    except HTTPException:
        # Any pre-stream rejection must free the concurrency slot.
        await TokenQuotaService.anon_release_stream_slot(client_ip)
        raise

    async def _generate():
        """SSE generator: runs the agent, finalizes/releases quota, frees the slot."""
        from langchain_core.messages import HumanMessage

        from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
        from app.agents.new_chat.checkpointer import get_checkpointer
        from app.db import shielded_async_session
        from app.services.connector_service import ConnectorService
        from app.services.new_streaming_service import VercelStreamingService
        from app.services.token_tracking_service import start_turn
        from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events

        accumulator = start_turn()
        streaming_service = VercelStreamingService()

        try:
            async with shielded_async_session() as session:
                connector_service = ConnectorService(session, search_space_id=None)
                checkpointer = await get_checkpointer()

                anon_thread_id = f"anon-{session_id}-{request_id}"

                agent = await create_surfsense_deep_agent(
                    llm=llm,
                    search_space_id=0,
                    db_session=session,
                    connector_service=connector_service,
                    checkpointer=checkpointer,
                    user_id=None,
                    thread_id=None,
                    agent_config=agent_config,
                    enabled_tools=list(enabled_for_agent),
                    disabled_tools=None,
                    anon_session_id=session_id,
                )

                # Use the most recent user message as the query.
                user_query = ""
                for msg in reversed(body.messages):
                    if msg.get("role") == "user":
                        user_query = msg.get("content", "")
                        break

                langchain_messages = [HumanMessage(content=user_query)]
                input_state = {
                    "messages": langchain_messages,
                    "search_space_id": 0,
                }

                langgraph_config = {
                    "configurable": {"thread_id": anon_thread_id},
                    "recursion_limit": 40,
                }

                yield streaming_service.format_message_start()
                yield streaming_service.format_start_step()

                initial_step_id = "thinking-1"
                query_preview = user_query[:80] + (
                    "..." if len(user_query) > 80 else ""
                )
                initial_items = [f"Processing: {query_preview}"]

                yield streaming_service.format_thinking_step(
                    step_id=initial_step_id,
                    title="Understanding your request",
                    status="in_progress",
                    items=initial_items,
                )

                stream_result = StreamResult()

                async for sse in _stream_agent_events(
                    agent=agent,
                    config=langgraph_config,
                    input_data=input_state,
                    streaming_service=streaming_service,
                    result=stream_result,
                    step_prefix="thinking",
                    initial_step_id=initial_step_id,
                    initial_step_title="Understanding your request",
                    initial_step_items=initial_items,
                ):
                    yield sse

                # Finalize quota with actual tokens
                actual_tokens = accumulator.grand_total
                finalize_result = await TokenQuotaService.anon_finalize(
                    session_key=session_key,
                    ip_key=ip_key,
                    request_id=request_id,
                    actual_tokens=actual_tokens,
                )

                # Count this as 1 completed response for CAPTCHA threshold
                if config.TURNSTILE_ENABLED:
                    await TokenQuotaService.anon_increment_request_count(client_ip)

                yield streaming_service.format_data(
                    "anon-quota",
                    {
                        "used": finalize_result.used,
                        "limit": finalize_result.limit,
                        "remaining": finalize_result.remaining,
                        "status": finalize_result.status.value,
                    },
                )

                if accumulator.per_message_summary():
                    yield streaming_service.format_data(
                        "token-usage",
                        {
                            "usage": accumulator.per_message_summary(),
                            "prompt_tokens": accumulator.total_prompt_tokens,
                            "completion_tokens": accumulator.total_completion_tokens,
                            "total_tokens": accumulator.grand_total,
                        },
                    )

                yield streaming_service.format_finish_step()
                yield streaming_service.format_finish()
                yield streaming_service.format_done()

        except Exception as e:
            logger.exception("Anonymous chat stream error")
            # Best-effort: drop the reservation so failed streams don't eat
            # quota. NOTE(review): if the failure happened after
            # anon_finalize, this release is assumed to be a no-op for the
            # already-finalized request_id — confirm in TokenQuotaService.
            await TokenQuotaService.anon_release(session_key, ip_key, request_id)
            yield streaming_service.format_error(f"Error during chat: {e!s}")
            yield streaming_service.format_done()
        finally:
            await TokenQuotaService.anon_release_stream_slot(client_ip)

    return StreamingResponse(
        _generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anonymous Document Upload (1-doc limit, plaintext/direct-convert only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Anonymous uploads accept only plaintext-style files plus formats that can
# be converted in-process (no external converter pipeline).
ANON_ALLOWED_EXTENSIONS = PLAINTEXT_EXTENSIONS | DIRECT_CONVERT_EXTENSIONS
# Redis key prefix for the per-session uploaded document payload.
ANON_DOC_REDIS_PREFIX = "anon:doc:"
|
||||
|
||||
|
||||
class AnonDocResponse(BaseModel):
    """Acknowledgement returned after an anonymous document upload."""

    # Original filename as supplied by the client.
    filename: str
    # Size of the uploaded payload in bytes.
    size_bytes: int
    # Fixed status marker for the frontend.
    status: str = "uploaded"
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AnonDocResponse)
@limiter.limit("5/minute")
async def upload_anonymous_document(
    file: UploadFile,
    request: Request,
    response: Response,
):
    """Upload a single document for anonymous chat (1-doc limit per session).

    The file is validated (extension allow-list, size cap), converted to
    text, and stored as JSON in Redis under the session's key with the same
    TTL as the anonymous token quota.

    Raises:
        HTTPException: 503 (no-login disabled), 400 (no filename),
            415 (disallowed extension), 413 (too large),
            409 (a document already exists for this session).
    """
    if not config.NOLOGIN_MODE_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No-login mode is not enabled.",
        )

    session_id = _get_or_create_session_id(request, response)

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    # Extension check against the anonymous allow-list (plaintext +
    # direct-convert formats only).
    ext = PurePosixPath(file.filename).suffix.lower()
    if ext not in ANON_ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=(
                "File type not supported for anonymous upload. "
                "Create an account to upload PDFs, Word documents, images, audio, and 20+ more file types. "
                "Allowed extensions: text, code, CSV, HTML files."
            ),
        )

    # NOTE: the whole payload is read into memory before the size check;
    # the per-route rate limit bounds abuse.
    max_size = config.ANON_MAX_UPLOAD_SIZE_MB * 1024 * 1024
    content = await file.read()
    if len(content) > max_size:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max size is {config.ANON_MAX_UPLOAD_SIZE_MB} MB.",
        )

    import json as _json

    import redis.asyncio as aioredis

    redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
    redis_key = f"{ANON_DOC_REDIS_PREFIX}{session_id}"

    try:
        # One document per anonymous session: reject if a payload exists.
        existing = await redis_client.exists(redis_key)
        if existing:
            raise HTTPException(
                status_code=409,
                detail="Document limit reached. Create an account to upload more.",
            )

        # Convert the raw bytes to text. Plaintext and CSV/TSV are decoded
        # as-is; other direct-convert formats go through markdownify when
        # available, falling back to a plain decode.
        text_content: str
        if ext in PLAINTEXT_EXTENSIONS:
            text_content = content.decode("utf-8", errors="replace")
        elif ext in DIRECT_CONVERT_EXTENSIONS:
            if ext in {".csv", ".tsv"}:
                text_content = content.decode("utf-8", errors="replace")
            else:
                try:
                    from markdownify import markdownify

                    text_content = markdownify(
                        content.decode("utf-8", errors="replace")
                    )
                except ImportError:
                    text_content = content.decode("utf-8", errors="replace")
        else:
            # Defensive: unreachable given the allow-list check above, but
            # kept so a future allow-list change cannot leave text_content
            # unset.
            text_content = content.decode("utf-8", errors="replace")

        doc_data = _json.dumps(
            {
                "filename": file.filename,
                "size_bytes": len(content),
                "content": text_content,
            }
        )

        # Same lifetime as the anonymous quota bucket.
        ttl_seconds = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
        await redis_client.set(redis_key, doc_data, ex=ttl_seconds)

    finally:
        # Always close the ad-hoc Redis connection, success or not.
        await redis_client.aclose()

    return AnonDocResponse(
        filename=file.filename,
        size_bytes=len(content),
    )
|
||||
|
||||
|
||||
@router.get("/document")
async def get_anonymous_document(request: Request, response: Response):
    """Get metadata of the uploaded document for the anonymous session.

    Raises:
        HTTPException: 503 when no-login mode is disabled; 404 when the
            session has no stored document.
    """
    if not config.NOLOGIN_MODE_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="No-login mode is not enabled.",
        )

    sid = _get_or_create_session_id(request, response)

    import json as _json

    import redis.asyncio as aioredis

    client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
    try:
        raw = await client.get(f"{ANON_DOC_REDIS_PREFIX}{sid}")
        if not raw:
            raise HTTPException(status_code=404, detail="No document uploaded")

        # Only the metadata is returned; the document body stays server-side.
        meta = _json.loads(raw)
        return {
            "filename": meta["filename"],
            "size_bytes": meta["size_bytes"],
        }
    finally:
        await client.aclose()
|
||||
|
|
@ -76,6 +76,14 @@ async def get_global_new_llm_configs(
|
|||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
"anonymous_enabled": False,
|
||||
"seo_enabled": False,
|
||||
"seo_slug": None,
|
||||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -97,6 +105,14 @@ async def get_global_new_llm_configs(
|
|||
),
|
||||
"citations_enabled": cfg.get("citations_enabled", True),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"is_premium": cfg.get("billing_tier", "free") == "premium",
|
||||
"anonymous_enabled": cfg.get("anonymous_enabled", False),
|
||||
"seo_enabled": cfg.get("seo_enabled", False),
|
||||
"seo_slug": cfg.get("seo_slug"),
|
||||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,24 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from stripe import SignatureVerificationError, StripeClient, StripeError
|
||||
|
||||
from app.config import config
|
||||
from app.db import PagePurchase, PagePurchaseStatus, User, get_async_session
|
||||
from app.db import (
|
||||
PagePurchase,
|
||||
PagePurchaseStatus,
|
||||
PremiumTokenPurchase,
|
||||
PremiumTokenPurchaseStatus,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.stripe import (
|
||||
CreateCheckoutSessionRequest,
|
||||
CreateCheckoutSessionResponse,
|
||||
CreateTokenCheckoutSessionRequest,
|
||||
CreateTokenCheckoutSessionResponse,
|
||||
PagePurchaseHistoryResponse,
|
||||
StripeStatusResponse,
|
||||
StripeWebhookResponse,
|
||||
TokenPurchaseHistoryResponse,
|
||||
TokenStripeStatusResponse,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
|
|
@ -151,6 +162,26 @@ async def _mark_purchase_failed(
|
|||
return StripeWebhookResponse()
|
||||
|
||||
|
||||
async def _mark_token_purchase_failed(
    db_session: AsyncSession, checkout_session_id: str
) -> StripeWebhookResponse:
    """Mark a pending premium-token purchase as FAILED (idempotent).

    Rows that are already completed or failed are left untouched.
    """
    stmt = (
        select(PremiumTokenPurchase)
        .where(PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id)
        .with_for_update()
    )
    record = (await db_session.execute(stmt)).scalar_one_or_none()

    # Only a still-pending purchase transitions to FAILED.
    if record is not None and record.status == PremiumTokenPurchaseStatus.PENDING:
        record.status = PremiumTokenPurchaseStatus.FAILED
        await db_session.commit()

    return StripeWebhookResponse()
|
||||
|
||||
|
||||
async def _fulfill_completed_purchase(
|
||||
db_session: AsyncSession, checkout_session: Any
|
||||
) -> StripeWebhookResponse:
|
||||
|
|
@ -201,6 +232,86 @@ async def _fulfill_completed_purchase(
|
|||
return StripeWebhookResponse()
|
||||
|
||||
|
||||
async def _fulfill_completed_token_purchase(
    db_session: AsyncSession, checkout_session: Any
) -> StripeWebhookResponse:
    """Grant premium tokens to the user after a confirmed Stripe payment.

    Idempotent: the purchase row is selected FOR UPDATE, and a row already
    in COMPLETED state short-circuits, so duplicate webhook deliveries for
    the same checkout session cannot grant tokens twice.
    """
    checkout_session_id = str(checkout_session.id)
    # Row-lock the purchase so concurrent deliveries serialize here.
    purchase = (
        await db_session.execute(
            select(PremiumTokenPurchase)
            .where(
                PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id
            )
            .with_for_update()
        )
    ).scalar_one_or_none()

    if purchase is None:
        # No pending row exists (e.g. the checkout-time insert never
        # happened); reconstruct one from the metadata stamped on the
        # checkout session at creation time.
        metadata = _get_metadata(checkout_session)
        user_id = metadata.get("user_id")
        quantity = int(metadata.get("quantity", "0"))
        tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))

        if not user_id or quantity <= 0 or tokens_per_unit <= 0:
            logger.error(
                "Skipping token fulfillment for session %s: incomplete metadata %s",
                checkout_session_id,
                metadata,
            )
            return StripeWebhookResponse()

        purchase = PremiumTokenPurchase(
            user_id=uuid.UUID(user_id),
            stripe_checkout_session_id=checkout_session_id,
            stripe_payment_intent_id=_normalize_optional_string(
                getattr(checkout_session, "payment_intent", None)
            ),
            quantity=quantity,
            tokens_granted=quantity * tokens_per_unit,
            amount_total=getattr(checkout_session, "amount_total", None),
            currency=getattr(checkout_session, "currency", None),
            status=PremiumTokenPurchaseStatus.PENDING,
        )
        db_session.add(purchase)
        await db_session.flush()

    # Idempotency guard: a second delivery finds COMPLETED and stops here.
    if purchase.status == PremiumTokenPurchaseStatus.COMPLETED:
        return StripeWebhookResponse()

    # Lock the user row too, since the token limit below is read-modify-write.
    user = (
        (
            await db_session.execute(
                select(User).where(User.id == purchase.user_id).with_for_update(of=User)
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if user is None:
        logger.error(
            "Skipping token fulfillment for session %s: user %s not found",
            purchase.stripe_checkout_session_id,
            purchase.user_id,
        )
        return StripeWebhookResponse()

    purchase.status = PremiumTokenPurchaseStatus.COMPLETED
    purchase.completed_at = datetime.now(UTC)
    purchase.amount_total = getattr(checkout_session, "amount_total", None)
    purchase.currency = getattr(checkout_session, "currency", None)
    purchase.stripe_payment_intent_id = _normalize_optional_string(
        getattr(checkout_session, "payment_intent", None)
    )
    # New limit baseline is max(used, limit) so a user who somehow overshot
    # their limit still receives the full granted amount on top.
    user.premium_tokens_limit = (
        max(user.premium_tokens_used, user.premium_tokens_limit)
        + purchase.tokens_granted
    )

    await db_session.commit()
    return StripeWebhookResponse()
|
||||
|
||||
|
||||
@router.post("/create-checkout-session", response_model=CreateCheckoutSessionResponse)
|
||||
async def create_checkout_session(
|
||||
body: CreateCheckoutSessionRequest,
|
||||
|
|
@ -333,6 +444,10 @@ async def stripe_webhook(
|
|||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
metadata = _get_metadata(checkout_session)
|
||||
purchase_type = metadata.get("purchase_type", "page_packs")
|
||||
if purchase_type == "premium_tokens":
|
||||
return await _fulfill_completed_token_purchase(db_session, checkout_session)
|
||||
return await _fulfill_completed_purchase(db_session, checkout_session)
|
||||
|
||||
if event.type in {
|
||||
|
|
@ -340,6 +455,12 @@ async def stripe_webhook(
|
|||
"checkout.session.expired",
|
||||
}:
|
||||
checkout_session = event.data.object
|
||||
metadata = _get_metadata(checkout_session)
|
||||
purchase_type = metadata.get("purchase_type", "page_packs")
|
||||
if purchase_type == "premium_tokens":
|
||||
return await _mark_token_purchase_failed(
|
||||
db_session, str(checkout_session.id)
|
||||
)
|
||||
return await _mark_purchase_failed(db_session, str(checkout_session.id))
|
||||
|
||||
return StripeWebhookResponse()
|
||||
|
|
@ -369,3 +490,146 @@ async def get_page_purchases(
|
|||
)
|
||||
|
||||
return PagePurchaseHistoryResponse(purchases=purchases)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Premium Token Purchase Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _ensure_token_buying_enabled() -> None:
    """Raise 503 when premium token purchasing is switched off in config."""
    if config.STRIPE_TOKEN_BUYING_ENABLED:
        return
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Premium token purchases are temporarily unavailable.",
    )
|
||||
|
||||
|
||||
def _get_token_checkout_urls(search_space_id: int) -> tuple[str, str]:
    """Build the Stripe (success_url, cancel_url) pair for a search space.

    Raises:
        HTTPException: 503 when NEXT_FRONTEND_URL is not configured.
    """
    frontend = config.NEXT_FRONTEND_URL
    if not frontend:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="NEXT_FRONTEND_URL is not configured.",
        )
    dashboard = f"{frontend.rstrip('/')}/dashboard/{search_space_id}"
    return f"{dashboard}/purchase-success", f"{dashboard}/purchase-cancel"
|
||||
|
||||
|
||||
def _get_required_token_price_id() -> str:
    """Return the configured Stripe price ID for token packs, or raise 503."""
    price_id = config.STRIPE_PREMIUM_TOKEN_PRICE_ID
    if not price_id:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="STRIPE_PREMIUM_TOKEN_PRICE_ID is not configured.",
        )
    return price_id
|
||||
|
||||
|
||||
@router.post("/create-token-checkout-session")
async def create_token_checkout_session(
    body: CreateTokenCheckoutSessionRequest,
    user: User = Depends(current_active_user),
    db_session: AsyncSession = Depends(get_async_session),
):
    """Create a Stripe Checkout Session for buying premium token packs.

    A PENDING PremiumTokenPurchase row is committed locally after the Stripe
    session is created; the webhook handler later flips it to COMPLETED (or
    FAILED) and grants the tokens.

    Raises:
        HTTPException: 503 when token buying or required config is missing;
            502 when Stripe fails or returns no checkout URL.
    """
    _ensure_token_buying_enabled()
    stripe_client = get_stripe_client()
    price_id = _get_required_token_price_id()
    success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
    tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT

    try:
        checkout_session = stripe_client.v1.checkout.sessions.create(
            params={
                "mode": "payment",
                "success_url": success_url,
                "cancel_url": cancel_url,
                "line_items": [
                    {
                        "price": price_id,
                        "quantity": body.quantity,
                    }
                ],
                "client_reference_id": str(user.id),
                "customer_email": user.email,
                # Metadata lets the webhook reconstruct the purchase even if
                # the local PENDING row is missing, and routes the event to
                # the token (not page-pack) fulfillment path.
                "metadata": {
                    "user_id": str(user.id),
                    "quantity": str(body.quantity),
                    "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
                    "purchase_type": "premium_tokens",
                },
            }
        )
    except StripeError as exc:
        logger.exception("Failed to create token checkout session for user %s", user.id)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Unable to create Stripe checkout session.",
        ) from exc

    checkout_url = getattr(checkout_session, "url", None)
    if not checkout_url:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Stripe checkout session did not return a URL.",
        )

    # Record the pending purchase locally so history/webhooks can find it.
    db_session.add(
        PremiumTokenPurchase(
            user_id=user.id,
            stripe_checkout_session_id=str(checkout_session.id),
            stripe_payment_intent_id=_normalize_optional_string(
                getattr(checkout_session, "payment_intent", None)
            ),
            quantity=body.quantity,
            tokens_granted=tokens_granted,
            amount_total=getattr(checkout_session, "amount_total", None),
            currency=getattr(checkout_session, "currency", None),
            status=PremiumTokenPurchaseStatus.PENDING,
        )
    )
    await db_session.commit()

    return CreateTokenCheckoutSessionResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
@router.get("/token-status")
async def get_token_status(
    user: User = Depends(current_active_user),
):
    """Return token-buying availability and current premium quota for frontend."""
    consumed = user.premium_tokens_used
    cap = user.premium_tokens_limit
    return TokenStripeStatusResponse(
        token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
        premium_tokens_used=consumed,
        premium_tokens_limit=cap,
        # Never report a negative remainder, even if usage overshot the cap.
        premium_tokens_remaining=max(0, cap - consumed),
    )
|
||||
|
||||
|
||||
@router.get("/token-purchases")
async def get_token_purchases(
    user: User = Depends(current_active_user),
    db_session: AsyncSession = Depends(get_async_session),
    offset: int = 0,
    limit: int = 50,
):
    """Return the authenticated user's premium token purchase history.

    Args:
        offset: Number of records to skip; negative values are clamped to 0.
        limit: Page size; clamped into the 0..100 range.
    """
    # Clamp pagination so hostile/negative query params never reach SQL
    # OFFSET/LIMIT (Postgres rejects a negative LIMIT).
    offset = max(offset, 0)
    limit = min(max(limit, 0), 100)
    purchases = (
        (
            await db_session.execute(
                select(PremiumTokenPurchase)
                .where(PremiumTokenPurchase.user_id == user.id)
                .order_by(PremiumTokenPurchase.created_at.desc())
                .offset(offset)
                .limit(limit)
            )
        )
        .scalars()
        .all()
    )

    return TokenPurchaseHistoryResponse(purchases=purchases)
|
||||
|
|
|
|||
|
|
@ -164,6 +164,15 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
is_global: bool = True # Always true for global configs
|
||||
is_auto_mode: bool = False # True only for Auto mode (ID 0)
|
||||
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
anonymous_enabled: bool = False
|
||||
seo_enabled: bool = False
|
||||
seo_slug: str | None = None
|
||||
seo_title: str | None = None
|
||||
seo_description: str | None = None
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM Preferences Schemas (for role assignments)
|
||||
|
|
|
|||
|
|
@ -54,3 +54,48 @@ class StripeWebhookResponse(BaseModel):
|
|||
"""Generic acknowledgement for Stripe webhook delivery."""
|
||||
|
||||
received: bool = True
|
||||
|
||||
|
||||
class CreateTokenCheckoutSessionRequest(BaseModel):
    """Request body for creating a premium token purchase checkout session."""

    # Number of token packs to buy (1..100).
    quantity: int = Field(ge=1, le=100)
    # Search space used to build the success/cancel redirect URLs.
    search_space_id: int = Field(ge=1)
|
||||
|
||||
|
||||
class CreateTokenCheckoutSessionResponse(BaseModel):
    """Response containing the Stripe-hosted checkout URL."""

    # Client redirects the browser here to complete payment.
    checkout_url: str
|
||||
|
||||
|
||||
class TokenPurchaseRead(BaseModel):
    """Serialized premium token purchase record."""

    id: uuid.UUID
    stripe_checkout_session_id: str
    stripe_payment_intent_id: str | None = None
    # Number of token packs bought.
    quantity: int
    # Total tokens granted by this purchase (quantity * per-unit amount).
    tokens_granted: int
    # Amount/currency as reported by Stripe; may be absent on pending rows.
    amount_total: int | None = None
    currency: str | None = None
    # Purchase lifecycle status (pending/completed/failed).
    status: str
    completed_at: datetime | None = None
    created_at: datetime

    # Built directly from the ORM row.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TokenPurchaseHistoryResponse(BaseModel):
    """Response containing the user's premium token purchases."""

    # Newest-first page of purchase records.
    purchases: list[TokenPurchaseRead]
|
||||
|
||||
|
||||
class TokenStripeStatusResponse(BaseModel):
    """Response describing token-buying availability and current quota."""

    # Whether the Stripe token-purchase flow is enabled server-side.
    token_buying_enabled: bool
    premium_tokens_used: int = 0
    premium_tokens_limit: int = 0
    # limit - used, floored at zero by the route handler.
    premium_tokens_remaining: int = 0
|
||||
|
|
|
|||
|
|
@ -135,9 +135,12 @@ class LLMRouterService:
|
|||
logger.debug("LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
# Build model list from global configs
|
||||
auto_configs = [
|
||||
c for c in global_configs if c.get("billing_tier", "free") != "premium"
|
||||
]
|
||||
|
||||
model_list = []
|
||||
for config in global_configs:
|
||||
for config in auto_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
|
|
|||
621
surfsense_backend/app/services/token_quota_service.py
Normal file
621
surfsense_backend/app/services/token_quota_service.py
Normal file
|
|
@ -0,0 +1,621 @@
|
|||
"""
|
||||
Atomic token quota service for anonymous and registered users.
|
||||
|
||||
Provides reserve/finalize/release/get_usage operations with race-safe
|
||||
implementation using Redis Lua scripts (anonymous) and Postgres row locks
|
||||
(registered premium).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotaScope(StrEnum):
|
||||
ANONYMOUS = "anonymous"
|
||||
PREMIUM = "premium"
|
||||
|
||||
|
||||
class QuotaStatus(StrEnum):
|
||||
OK = "ok"
|
||||
WARNING = "warning"
|
||||
BLOCKED = "blocked"
|
||||
|
||||
|
||||
class QuotaResult:
|
||||
__slots__ = ("allowed", "limit", "remaining", "reserved", "status", "used")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed: bool,
|
||||
status: QuotaStatus,
|
||||
used: int,
|
||||
limit: int,
|
||||
reserved: int = 0,
|
||||
remaining: int = 0,
|
||||
):
|
||||
self.allowed = allowed
|
||||
self.status = status
|
||||
self.used = used
|
||||
self.limit = limit
|
||||
self.reserved = reserved
|
||||
self.remaining = remaining
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"allowed": self.allowed,
|
||||
"status": self.status.value,
|
||||
"used": self.used,
|
||||
"limit": self.limit,
|
||||
"reserved": self.reserved,
|
||||
"remaining": self.remaining,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Redis Lua scripts for atomic anonymous quota operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# KEYS[1] = quota key (e.g. "anon_quota:session:<session_id>")
|
||||
# ARGV[1] = reserve_tokens
|
||||
# ARGV[2] = limit
|
||||
# ARGV[3] = warning_threshold
|
||||
# ARGV[4] = request_id
|
||||
# ARGV[5] = ttl_seconds
|
||||
# Returns: [allowed(0/1), status("ok"/"warning"/"blocked"), used, reserved]
|
||||
_RESERVE_LUA = """
|
||||
local key = KEYS[1]
|
||||
local reserve = tonumber(ARGV[1])
|
||||
local limit = tonumber(ARGV[2])
|
||||
local warning = tonumber(ARGV[3])
|
||||
local req_id = ARGV[4]
|
||||
local ttl = tonumber(ARGV[5])
|
||||
|
||||
local used = tonumber(redis.call('HGET', key, 'used') or '0')
|
||||
local reserved = tonumber(redis.call('HGET', key, 'reserved') or '0')
|
||||
|
||||
local effective = used + reserved + reserve
|
||||
if effective > limit then
|
||||
return {0, 'blocked', used, reserved}
|
||||
end
|
||||
|
||||
redis.call('HINCRBY', key, 'reserved', reserve)
|
||||
redis.call('HSET', key, 'req:' .. req_id, reserve)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
|
||||
local new_reserved = reserved + reserve
|
||||
local status = 'ok'
|
||||
if (used + new_reserved) >= warning then
|
||||
status = 'warning'
|
||||
end
|
||||
|
||||
return {1, status, used, new_reserved}
|
||||
"""
|
||||
|
||||
# KEYS[1] = quota key
|
||||
# ARGV[1] = request_id
|
||||
# ARGV[2] = actual_tokens
|
||||
# ARGV[3] = warning_threshold
|
||||
# Returns: [used, reserved, status]
|
||||
_FINALIZE_LUA = """
|
||||
local key = KEYS[1]
|
||||
local req_id = ARGV[1]
|
||||
local actual = tonumber(ARGV[2])
|
||||
local warning = tonumber(ARGV[3])
|
||||
|
||||
local orig_reserve = tonumber(redis.call('HGET', key, 'req:' .. req_id) or '0')
|
||||
if orig_reserve == 0 then
|
||||
return {tonumber(redis.call('HGET', key, 'used') or '0'), tonumber(redis.call('HGET', key, 'reserved') or '0'), 'ok'}
|
||||
end
|
||||
|
||||
redis.call('HDEL', key, 'req:' .. req_id)
|
||||
redis.call('HINCRBY', key, 'reserved', -orig_reserve)
|
||||
redis.call('HINCRBY', key, 'used', actual)
|
||||
|
||||
local used = tonumber(redis.call('HGET', key, 'used') or '0')
|
||||
local reserved = tonumber(redis.call('HGET', key, 'reserved') or '0')
|
||||
local status = 'ok'
|
||||
if used >= warning then
|
||||
status = 'warning'
|
||||
end
|
||||
return {used, reserved, status}
|
||||
"""
|
||||
|
||||
# KEYS[1] = quota key
|
||||
# ARGV[1] = request_id
|
||||
# Returns: 1 if released, 0 if not found
|
||||
_RELEASE_LUA = """
|
||||
local key = KEYS[1]
|
||||
local req_id = ARGV[1]
|
||||
|
||||
local orig_reserve = tonumber(redis.call('HGET', key, 'req:' .. req_id) or '0')
|
||||
if orig_reserve == 0 then
|
||||
return 0
|
||||
end
|
||||
|
||||
redis.call('HDEL', key, 'req:' .. req_id)
|
||||
redis.call('HINCRBY', key, 'reserved', -orig_reserve)
|
||||
return 1
|
||||
"""
|
||||
|
||||
|
||||
def _get_anon_redis() -> aioredis.Redis:
|
||||
return aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
|
||||
|
||||
def compute_anon_identity_key(
|
||||
session_id: str,
|
||||
ip_hash: str | None = None,
|
||||
) -> str:
|
||||
"""Build the Redis hash key for anonymous quota tracking.
|
||||
|
||||
Uses the signed session cookie as primary identity. The IP hash
|
||||
is tracked separately so cookie-reset evasion is caught.
|
||||
"""
|
||||
return f"anon_quota:session:{session_id}"
|
||||
|
||||
|
||||
def compute_ip_quota_key(ip_address: str) -> str:
|
||||
"""Build IP-only quota key. UA is excluded so rotating User-Agent cannot bypass limits."""
|
||||
h = hashlib.sha256(ip_address.encode()).hexdigest()[:16]
|
||||
return f"anon_quota:ip:{h}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrent stream limiter (per-IP)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Atomic acquire: returns 1 if slot acquired, 0 if at capacity.
|
||||
# KEYS[1] = stream counter key ARGV[1] = max_concurrent ARGV[2] = safety_ttl
|
||||
_ACQUIRE_STREAM_LUA = """
|
||||
local key = KEYS[1]
|
||||
local max_c = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local cur = tonumber(redis.call('GET', key) or '0')
|
||||
if cur >= max_c then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', key)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
"""
|
||||
|
||||
# Atomic release: DECR with floor at 0
|
||||
_RELEASE_STREAM_LUA = """
|
||||
local key = KEYS[1]
|
||||
local cur = tonumber(redis.call('GET', key) or '0')
|
||||
if cur <= 0 then
|
||||
return 0
|
||||
end
|
||||
redis.call('DECR', key)
|
||||
return 1
|
||||
"""
|
||||
|
||||
|
||||
def compute_stream_slot_key(ip_address: str) -> str:
|
||||
h = hashlib.sha256(ip_address.encode()).hexdigest()[:16]
|
||||
return f"anon:streams:{h}"
|
||||
|
||||
|
||||
def compute_request_count_key(ip_address: str) -> str:
|
||||
h = hashlib.sha256(ip_address.encode()).hexdigest()[:16]
|
||||
return f"anon:reqcount:{h}"
|
||||
|
||||
|
||||
class TokenQuotaService:
|
||||
"""Unified quota service for anonymous (Redis) and premium (Postgres) scopes."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Concurrent stream limiter
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
async def anon_acquire_stream_slot(
|
||||
ip_address: str,
|
||||
max_concurrent: int = 2,
|
||||
safety_ttl: int = 300,
|
||||
) -> bool:
|
||||
key = compute_stream_slot_key(ip_address)
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
result = await r.eval(
|
||||
_ACQUIRE_STREAM_LUA, 1, key, str(max_concurrent), str(safety_ttl)
|
||||
)
|
||||
return bool(result)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
@staticmethod
|
||||
async def anon_release_stream_slot(ip_address: str) -> None:
|
||||
key = compute_stream_slot_key(ip_address)
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
await r.eval(_RELEASE_STREAM_LUA, 1, key)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-IP request counter (for CAPTCHA triggering)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
async def anon_increment_request_count(ip_address: str, ttl: int = 86400) -> int:
|
||||
"""Increment and return current request count for this IP. TTL resets daily."""
|
||||
key = compute_request_count_key(ip_address)
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
pipe = r.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, ttl)
|
||||
results = await pipe.execute()
|
||||
return int(results[0])
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
@staticmethod
|
||||
async def anon_get_request_count(ip_address: str) -> int:
|
||||
key = compute_request_count_key(ip_address)
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
val = await r.get(key)
|
||||
return int(val) if val else 0
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
@staticmethod
|
||||
async def anon_reset_request_count(ip_address: str) -> None:
|
||||
key = compute_request_count_key(ip_address)
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
await r.delete(key)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Anonymous (Redis-backed)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
async def anon_reserve(
|
||||
session_key: str,
|
||||
ip_key: str | None,
|
||||
request_id: str,
|
||||
reserve_tokens: int,
|
||||
) -> QuotaResult:
|
||||
limit = config.ANON_TOKEN_LIMIT
|
||||
warning = config.ANON_TOKEN_WARNING_THRESHOLD
|
||||
ttl = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
|
||||
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
result = await r.eval(
|
||||
_RESERVE_LUA,
|
||||
1,
|
||||
session_key,
|
||||
str(reserve_tokens),
|
||||
str(limit),
|
||||
str(warning),
|
||||
request_id,
|
||||
str(ttl),
|
||||
)
|
||||
allowed = bool(result[0])
|
||||
status_str = result[1] if isinstance(result[1], str) else result[1].decode()
|
||||
used = int(result[2])
|
||||
reserved = int(result[3])
|
||||
|
||||
if ip_key:
|
||||
ip_result = await r.eval(
|
||||
_RESERVE_LUA,
|
||||
1,
|
||||
ip_key,
|
||||
str(reserve_tokens),
|
||||
str(limit),
|
||||
str(warning),
|
||||
request_id,
|
||||
str(ttl),
|
||||
)
|
||||
ip_allowed = bool(ip_result[0])
|
||||
ip_used = int(ip_result[2])
|
||||
if not ip_allowed and allowed:
|
||||
await r.eval(_RELEASE_LUA, 1, session_key, request_id)
|
||||
allowed = False
|
||||
status_str = "blocked"
|
||||
used = max(used, ip_used)
|
||||
|
||||
status = QuotaStatus(status_str)
|
||||
remaining = max(0, limit - used - reserved)
|
||||
return QuotaResult(
|
||||
allowed=allowed,
|
||||
status=status,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
@staticmethod
|
||||
async def anon_finalize(
|
||||
session_key: str,
|
||||
ip_key: str | None,
|
||||
request_id: str,
|
||||
actual_tokens: int,
|
||||
) -> QuotaResult:
|
||||
warning = config.ANON_TOKEN_WARNING_THRESHOLD
|
||||
limit = config.ANON_TOKEN_LIMIT
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
result = await r.eval(
|
||||
_FINALIZE_LUA,
|
||||
1,
|
||||
session_key,
|
||||
request_id,
|
||||
str(actual_tokens),
|
||||
str(warning),
|
||||
)
|
||||
used = int(result[0])
|
||||
reserved = int(result[1])
|
||||
status_str = result[2] if isinstance(result[2], str) else result[2].decode()
|
||||
|
||||
if ip_key:
|
||||
await r.eval(
|
||||
_FINALIZE_LUA,
|
||||
1,
|
||||
ip_key,
|
||||
request_id,
|
||||
str(actual_tokens),
|
||||
str(warning),
|
||||
)
|
||||
|
||||
status = QuotaStatus(status_str)
|
||||
remaining = max(0, limit - used - reserved)
|
||||
return QuotaResult(
|
||||
allowed=True,
|
||||
status=status,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
@staticmethod
|
||||
async def anon_release(
|
||||
session_key: str,
|
||||
ip_key: str | None,
|
||||
request_id: str,
|
||||
) -> None:
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
await r.eval(_RELEASE_LUA, 1, session_key, request_id)
|
||||
if ip_key:
|
||||
await r.eval(_RELEASE_LUA, 1, ip_key, request_id)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
@staticmethod
|
||||
async def anon_get_usage(session_key: str) -> QuotaResult:
|
||||
limit = config.ANON_TOKEN_LIMIT
|
||||
warning = config.ANON_TOKEN_WARNING_THRESHOLD
|
||||
r = _get_anon_redis()
|
||||
try:
|
||||
data = await r.hgetall(session_key)
|
||||
used = int(data.get("used", 0))
|
||||
reserved = int(data.get("reserved", 0))
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
if used >= limit:
|
||||
status = QuotaStatus.BLOCKED
|
||||
elif used >= warning:
|
||||
status = QuotaStatus.WARNING
|
||||
else:
|
||||
status = QuotaStatus.OK
|
||||
|
||||
return QuotaResult(
|
||||
allowed=used < limit,
|
||||
status=status,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
finally:
|
||||
await r.aclose()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Premium (Postgres-backed)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
async def premium_reserve(
|
||||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
request_id: str,
|
||||
reserve_tokens: int,
|
||||
) -> QuotaResult:
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(User).where(User.id == user_id).with_for_update(of=User)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if user is None:
|
||||
return QuotaResult(
|
||||
allowed=False,
|
||||
status=QuotaStatus.BLOCKED,
|
||||
used=0,
|
||||
limit=0,
|
||||
)
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
|
||||
effective = used + reserved + reserve_tokens
|
||||
if effective > limit:
|
||||
remaining = max(0, limit - used - reserved)
|
||||
await db_session.rollback()
|
||||
return QuotaResult(
|
||||
allowed=False,
|
||||
status=QuotaStatus.BLOCKED,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
|
||||
user.premium_tokens_reserved = reserved + reserve_tokens
|
||||
await db_session.commit()
|
||||
|
||||
new_reserved = reserved + reserve_tokens
|
||||
remaining = max(0, limit - used - new_reserved)
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
||||
if (used + new_reserved) >= limit:
|
||||
status = QuotaStatus.BLOCKED
|
||||
elif (used + new_reserved) >= warning_threshold:
|
||||
status = QuotaStatus.WARNING
|
||||
else:
|
||||
status = QuotaStatus.OK
|
||||
|
||||
return QuotaResult(
|
||||
allowed=True,
|
||||
status=status,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=new_reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def premium_finalize(
|
||||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
request_id: str,
|
||||
actual_tokens: int,
|
||||
reserved_tokens: int,
|
||||
) -> QuotaResult:
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(User).where(User.id == user_id).with_for_update(of=User)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if user is None:
|
||||
return QuotaResult(
|
||||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||
)
|
||||
|
||||
user.premium_tokens_reserved = max(
|
||||
0, user.premium_tokens_reserved - reserved_tokens
|
||||
)
|
||||
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
warning_threshold = int(limit * 0.8)
|
||||
if used >= limit:
|
||||
status = QuotaStatus.BLOCKED
|
||||
elif used >= warning_threshold:
|
||||
status = QuotaStatus.WARNING
|
||||
else:
|
||||
status = QuotaStatus.OK
|
||||
|
||||
return QuotaResult(
|
||||
allowed=True,
|
||||
status=status,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def premium_release(
|
||||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
reserved_tokens: int,
|
||||
) -> None:
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(User).where(User.id == user_id).with_for_update(of=User)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if user is not None:
|
||||
user.premium_tokens_reserved = max(
|
||||
0, user.premium_tokens_reserved - reserved_tokens
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def premium_get_usage(
|
||||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
) -> QuotaResult:
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
(await db_session.execute(select(User).where(User.id == user_id)))
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if user is None:
|
||||
return QuotaResult(
|
||||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||
)
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
warning_threshold = int(limit * 0.8)
|
||||
if used >= limit:
|
||||
status = QuotaStatus.BLOCKED
|
||||
elif used >= warning_threshold:
|
||||
status = QuotaStatus.WARNING
|
||||
else:
|
||||
status = QuotaStatus.OK
|
||||
|
||||
return QuotaResult(
|
||||
allowed=used < limit,
|
||||
status=status,
|
||||
used=used,
|
||||
limit=limit,
|
||||
reserved=reserved,
|
||||
remaining=remaining,
|
||||
)
|
||||
52
surfsense_backend/app/services/turnstile_service.py
Normal file
52
surfsense_backend/app/services/turnstile_service.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""Cloudflare Turnstile CAPTCHA verification service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TURNSTILE_VERIFY_URL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
|
||||
async def verify_turnstile_token(token: str, remote_ip: str | None = None) -> bool:
|
||||
"""Verify a Turnstile response token with Cloudflare.
|
||||
|
||||
Returns True when the token is valid and the challenge was solved by a
|
||||
real user. Returns False (never raises) on network errors or invalid
|
||||
tokens so callers can treat it as a simple boolean gate.
|
||||
"""
|
||||
if not config.TURNSTILE_ENABLED:
|
||||
return True
|
||||
|
||||
secret = config.TURNSTILE_SECRET_KEY
|
||||
if not secret:
|
||||
logger.warning("TURNSTILE_SECRET_KEY is not set; skipping verification")
|
||||
return True
|
||||
|
||||
payload: dict[str, str] = {
|
||||
"secret": secret,
|
||||
"response": token,
|
||||
}
|
||||
if remote_ip:
|
||||
payload["remoteip"] = remote_ip
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.post(TURNSTILE_VERIFY_URL, data=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
success = data.get("success", False)
|
||||
if not success:
|
||||
logger.info(
|
||||
"Turnstile verification failed: %s",
|
||||
data.get("error-codes", []),
|
||||
)
|
||||
return bool(success)
|
||||
except Exception:
|
||||
logger.exception("Turnstile verification request failed")
|
||||
return False
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Reconcile pending Stripe page purchases that might miss webhook fulfillment."""
|
||||
"""Reconcile pending Stripe purchases that might miss webhook fulfillment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -11,7 +11,12 @@ from stripe import StripeClient, StripeError
|
|||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import PagePurchase, PagePurchaseStatus
|
||||
from app.db import (
|
||||
PagePurchase,
|
||||
PagePurchaseStatus,
|
||||
PremiumTokenPurchase,
|
||||
PremiumTokenPurchaseStatus,
|
||||
)
|
||||
from app.routes import stripe_routes
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
|
|
@ -126,7 +131,108 @@ async def _reconcile_pending_page_purchases() -> None:
|
|||
await db_session.rollback()
|
||||
|
||||
logger.info(
|
||||
"Stripe reconciliation completed. fulfilled=%s failed=%s checked=%s",
|
||||
"Stripe page reconciliation completed. fulfilled=%s failed=%s checked=%s",
|
||||
fulfilled_count,
|
||||
failed_count,
|
||||
len(pending_purchases),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
|
||||
def reconcile_pending_stripe_token_purchases_task():
|
||||
"""Recover paid token purchases that were left pending due to missed webhook handling."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reconcile_pending_token_purchases())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _reconcile_pending_token_purchases() -> None:
|
||||
"""Reconcile stale pending token purchases against Stripe source of truth."""
|
||||
stripe_client = get_stripe_client()
|
||||
if stripe_client is None:
|
||||
return
|
||||
|
||||
lookback_minutes = max(config.STRIPE_RECONCILIATION_LOOKBACK_MINUTES, 0)
|
||||
batch_size = max(config.STRIPE_RECONCILIATION_BATCH_SIZE, 1)
|
||||
cutoff = datetime.now(UTC) - timedelta(minutes=lookback_minutes)
|
||||
|
||||
async with get_celery_session_maker()() as db_session:
|
||||
pending_purchases = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(PremiumTokenPurchase)
|
||||
.where(
|
||||
PremiumTokenPurchase.status
|
||||
== PremiumTokenPurchaseStatus.PENDING,
|
||||
PremiumTokenPurchase.created_at <= cutoff,
|
||||
)
|
||||
.order_by(PremiumTokenPurchase.created_at.asc())
|
||||
.limit(batch_size)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
if not pending_purchases:
|
||||
logger.debug(
|
||||
"Stripe token reconciliation found no pending purchases older than %s minutes.",
|
||||
lookback_minutes,
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Stripe token reconciliation checking %s pending purchases (cutoff=%s, batch=%s).",
|
||||
len(pending_purchases),
|
||||
lookback_minutes,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
fulfilled_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for purchase in pending_purchases:
|
||||
checkout_session_id = purchase.stripe_checkout_session_id
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.retrieve(
|
||||
checkout_session_id
|
||||
)
|
||||
except StripeError:
|
||||
logger.exception(
|
||||
"Stripe token reconciliation failed to retrieve checkout session %s",
|
||||
checkout_session_id,
|
||||
)
|
||||
await db_session.rollback()
|
||||
continue
|
||||
|
||||
payment_status = getattr(checkout_session, "payment_status", None)
|
||||
session_status = getattr(checkout_session, "status", None)
|
||||
|
||||
try:
|
||||
if payment_status in {"paid", "no_payment_required"}:
|
||||
await stripe_routes._fulfill_completed_token_purchase(
|
||||
db_session, checkout_session
|
||||
)
|
||||
fulfilled_count += 1
|
||||
elif session_status == "expired":
|
||||
await stripe_routes._mark_token_purchase_failed(
|
||||
db_session, str(checkout_session.id)
|
||||
)
|
||||
failed_count += 1
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Stripe token reconciliation failed while processing checkout session %s",
|
||||
checkout_session_id,
|
||||
)
|
||||
await db_session.rollback()
|
||||
|
||||
logger.info(
|
||||
"Stripe token reconciliation completed. fulfilled=%s failed=%s checked=%s",
|
||||
fulfilled_count,
|
||||
failed_count,
|
||||
len(pending_purchases),
|
||||
|
|
|
|||
|
|
@ -1175,6 +1175,10 @@ async def stream_new_chat(
|
|||
|
||||
accumulator = start_turn()
|
||||
|
||||
# Premium quota tracking state
|
||||
_premium_reserved = 0
|
||||
_premium_request_id: str | None = None
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
|
|
@ -1220,6 +1224,34 @@ async def stream_new_chat(
|
|||
llm_config_id,
|
||||
)
|
||||
|
||||
# Premium quota reservation
|
||||
if agent_config and agent_config.is_premium and user_id:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
_premium_request_id = _uuid.uuid4().hex[:16]
|
||||
reserve_amount = min(
|
||||
agent_config.quota_reserve_tokens
|
||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
reserve_tokens=reserve_amount,
|
||||
)
|
||||
_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -1626,6 +1658,26 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Finalize premium quota with actual tokens
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_premium_reserved,
|
||||
)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||
|
|
@ -1700,6 +1752,23 @@ async def stream_new_chat(
|
|||
# (CancelledError is a BaseException), and the rest of the
|
||||
# finally block — including session.close() — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Release premium reservation if not finalized
|
||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_tokens=_premium_reserved,
|
||||
)
|
||||
_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s", user_id
|
||||
)
|
||||
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue