SurfSense/surfsense_backend/app/services/token_quota_service.py
DESKTOP-RTLN3BA\$punk ff4e0f9b62
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
feat: no login experience and prem tokens
2026-04-15 17:02:00 -07:00

621 lines
18 KiB
Python

"""
Atomic token quota service for anonymous and registered users.
Provides reserve/finalize/release/get_usage operations with race-safe
implementation using Redis Lua scripts (anonymous) and Postgres row locks
(registered premium).
"""
from __future__ import annotations
import hashlib
import logging
from enum import StrEnum
from typing import Any
import redis.asyncio as aioredis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
logger = logging.getLogger(__name__)
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
class QuotaStatus(StrEnum):
OK = "ok"
WARNING = "warning"
BLOCKED = "blocked"
class QuotaResult:
    """Snapshot of a quota decision returned by every service operation.

    Attributes:
        allowed: Whether the caller may proceed.
        status: ``QuotaStatus`` label for the current usage level.
        used: Tokens already consumed.
        limit: Total tokens the identity may consume.
        reserved: Tokens currently held by in-flight reservations.
        remaining: Tokens still available, as computed by the caller.
    """

    __slots__ = ("allowed", "limit", "remaining", "reserved", "status", "used")

    def __init__(
        self,
        allowed: bool,
        status: QuotaStatus,
        used: int,
        limit: int,
        reserved: int = 0,
        remaining: int = 0,
    ) -> None:
        self.allowed = allowed
        self.status = status
        self.used = used
        self.limit = limit
        self.reserved = reserved
        self.remaining = remaining

    def __repr__(self) -> str:
        """Return a debugging representation listing every slot."""
        fields = ", ".join(
            f"{name}={getattr(self, name)!r}" for name in self.__slots__
        )
        return f"{type(self).__name__}({fields})"

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict of the result, with ``status`` as its string value."""
        return {
            "allowed": self.allowed,
            "status": self.status.value,
            "used": self.used,
            "limit": self.limit,
            "reserved": self.reserved,
            "remaining": self.remaining,
        }
# ---------------------------------------------------------------------------
# Redis Lua scripts for atomic anonymous quota operations
# ---------------------------------------------------------------------------
# Reserve script: tries to set aside `reserve_tokens` for one request on a
# quota hash holding the fields 'used', 'reserved' and one 'req:<request_id>'
# field per outstanding reservation.
#
# KEYS[1] = quota key (e.g. "anon_quota:session:<session_id>")
# ARGV[1] = reserve_tokens
# ARGV[2] = limit
# ARGV[3] = warning_threshold
# ARGV[4] = request_id
# ARGV[5] = ttl_seconds
# Returns: [allowed(0/1), status("ok"/"warning"/"blocked"), used, reserved]
#
# Behaviour:
# - If used + reserved + reserve_tokens exceeds the limit, nothing is written
#   and the reply is {0, 'blocked', used, reserved} with the current counters.
# - Otherwise 'reserved' is incremented, 'req:<request_id>' is set to the
#   reserved amount, and the key's TTL is reset to ttl_seconds.
# - Status is 'warning' once used + new reserved reaches the warning threshold.
# NOTE(review): reserving twice with the same request_id overwrites the
# 'req:<request_id>' field but increments 'reserved' both times.
_RESERVE_LUA = """
local key = KEYS[1]
local reserve = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local warning = tonumber(ARGV[3])
local req_id = ARGV[4]
local ttl = tonumber(ARGV[5])
local used = tonumber(redis.call('HGET', key, 'used') or '0')
local reserved = tonumber(redis.call('HGET', key, 'reserved') or '0')
local effective = used + reserved + reserve
if effective > limit then
return {0, 'blocked', used, reserved}
end
redis.call('HINCRBY', key, 'reserved', reserve)
redis.call('HSET', key, 'req:' .. req_id, reserve)
redis.call('EXPIRE', key, ttl)
local new_reserved = reserved + reserve
local status = 'ok'
if (used + new_reserved) >= warning then
status = 'warning'
end
return {1, status, used, new_reserved}
"""
# Finalize script: converts one outstanding reservation into actual usage.
#
# KEYS[1] = quota key
# ARGV[1] = request_id
# ARGV[2] = actual_tokens
# ARGV[3] = warning_threshold
# Returns: [used, reserved, status]
#
# Behaviour:
# - If 'req:<request_id>' is missing (or stored as 0), nothing is written: the
#   current counters are returned with status 'ok' and actual_tokens is NOT
#   added to 'used'.
# - Otherwise the per-request field is deleted, 'reserved' is decremented by
#   the originally reserved amount and 'used' is incremented by actual_tokens.
# - Status is 'warning' once 'used' reaches the warning threshold; this script
#   never returns 'blocked'.
# - The key's TTL is not touched here.
_FINALIZE_LUA = """
local key = KEYS[1]
local req_id = ARGV[1]
local actual = tonumber(ARGV[2])
local warning = tonumber(ARGV[3])
local orig_reserve = tonumber(redis.call('HGET', key, 'req:' .. req_id) or '0')
if orig_reserve == 0 then
return {tonumber(redis.call('HGET', key, 'used') or '0'), tonumber(redis.call('HGET', key, 'reserved') or '0'), 'ok'}
end
redis.call('HDEL', key, 'req:' .. req_id)
redis.call('HINCRBY', key, 'reserved', -orig_reserve)
redis.call('HINCRBY', key, 'used', actual)
local used = tonumber(redis.call('HGET', key, 'used') or '0')
local reserved = tonumber(redis.call('HGET', key, 'reserved') or '0')
local status = 'ok'
if used >= warning then
status = 'warning'
end
return {used, reserved, status}
"""
# Release script: cancels one outstanding reservation without charging usage.
#
# KEYS[1] = quota key
# ARGV[1] = request_id
# Returns: 1 if released, 0 if not found
#
# When 'req:<request_id>' exists with a non-zero amount, the field is deleted
# and 'reserved' is decremented by that amount; 'used' is left unchanged.
# A missing field (or one stored as 0) is reported as "not found" and nothing
# is written, so releasing the same request twice is harmless.
_RELEASE_LUA = """
local key = KEYS[1]
local req_id = ARGV[1]
local orig_reserve = tonumber(redis.call('HGET', key, 'req:' .. req_id) or '0')
if orig_reserve == 0 then
return 0
end
redis.call('HDEL', key, 'req:' .. req_id)
redis.call('HINCRBY', key, 'reserved', -orig_reserve)
return 1
"""
def _get_anon_redis() -> aioredis.Redis:
    """Build an async Redis client for the anonymous-quota keys.

    The client is created from ``config.REDIS_APP_URL`` with
    ``decode_responses=True``, so replies are returned as ``str`` rather
    than ``bytes``. Every caller in this module closes the returned client
    with ``aclose()`` once its operation is done.
    """
    redis_url = config.REDIS_APP_URL
    client = aioredis.from_url(redis_url, decode_responses=True)
    return client
def compute_anon_identity_key(
    session_id: str,
    ip_hash: str | None = None,
) -> str:
    """Return the Redis hash key that tracks quota for an anonymous session.

    The key is derived from the session identifier only (the signed session
    cookie is the primary identity). ``ip_hash`` is accepted for interface
    compatibility but does not influence the key: per-IP usage lives under
    its own key so that resetting the cookie does not evade the limit.
    """
    quota_key = f"anon_quota:session:{session_id}"
    return quota_key
def compute_ip_quota_key(ip_address: str) -> str:
    """Return the per-IP quota key for an anonymous client.

    Only the IP address feeds the key; the User-Agent is deliberately left
    out so rotating it cannot bypass the limit. The address is hashed with
    SHA-256 and the first 16 hex characters are used as the key suffix.
    """
    digest = hashlib.sha256(ip_address.encode()).hexdigest()
    return f"anon_quota:ip:{digest[:16]}"
# ---------------------------------------------------------------------------
# Concurrent stream limiter (per-IP)
# ---------------------------------------------------------------------------
# Atomic acquire: returns 1 if slot acquired, 0 if at capacity.
# KEYS[1] = stream counter key
# ARGV[1] = max_concurrent
# ARGV[2] = safety_ttl (seconds)
#
# The counter is a plain integer key. When it is below max_concurrent it is
# incremented and its TTL is reset to safety_ttl; at or above capacity nothing
# is written. The TTL means the counter disappears on its own if no further
# slot is acquired within safety_ttl seconds, so slots that were never
# released do not block the key forever.
_ACQUIRE_STREAM_LUA = """
local key = KEYS[1]
local max_c = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local cur = tonumber(redis.call('GET', key) or '0')
if cur >= max_c then
return 0
end
redis.call('INCR', key)
redis.call('EXPIRE', key, ttl)
return 1
"""
# Atomic release: DECR with floor at 0
# KEYS[1] = stream counter key
# Returns: 1 if the counter was decremented, 0 if it was already missing or
# not positive (in which case nothing is written, so the counter never goes
# negative). The key's TTL is not modified here.
_RELEASE_STREAM_LUA = """
local key = KEYS[1]
local cur = tonumber(redis.call('GET', key) or '0')
if cur <= 0 then
return 0
end
redis.call('DECR', key)
return 1
"""
def compute_stream_slot_key(ip_address: str) -> str:
    """Return the Redis key of the concurrent-stream counter for an IP.

    The address is hashed with SHA-256 and truncated to 16 hex characters,
    so the raw IP never appears in the key name.
    """
    digest = hashlib.sha256(ip_address.encode()).hexdigest()
    return f"anon:streams:{digest[:16]}"
def compute_request_count_key(ip_address: str) -> str:
    """Return the Redis key of the per-IP request counter.

    The address is hashed with SHA-256 and truncated to 16 hex characters,
    so the raw IP never appears in the key name.
    """
    digest = hashlib.sha256(ip_address.encode()).hexdigest()
    return f"anon:reqcount:{digest[:16]}"
class TokenQuotaService:
    """Unified quota service for anonymous (Redis) and premium (Postgres) scopes.

    All methods are static. The ``anon_*`` methods open a Redis client via
    ``_get_anon_redis()`` for the duration of one call and always close it
    in a ``finally`` block. The ``premium_*`` methods operate on a
    caller-supplied ``AsyncSession`` and commit (or roll back) on it.
    """

    # Fraction of a premium user's limit at which the status becomes WARNING.
    _PREMIUM_WARNING_RATIO = 0.8

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_text(value: Any) -> str:
        """Return a Redis reply element as ``str``, decoding ``bytes`` if needed."""
        return value if isinstance(value, str) else value.decode()

    @staticmethod
    def _premium_status(consumed: int, limit: int) -> QuotaStatus:
        """Map a consumed-token count to a status for a premium limit."""
        warning_threshold = int(limit * TokenQuotaService._PREMIUM_WARNING_RATIO)
        if consumed >= limit:
            return QuotaStatus.BLOCKED
        if consumed >= warning_threshold:
            return QuotaStatus.WARNING
        return QuotaStatus.OK

    @staticmethod
    async def _fetch_premium_user(
        db_session: AsyncSession,
        user_id: Any,
        *,
        lock: bool,
    ) -> Any:
        """Load the user row, optionally with ``SELECT ... FOR UPDATE``.

        Returns the ``User`` instance, or ``None`` when no row matches.
        """
        # Function-scope import (presumably avoids an import cycle with
        # app.db -- TODO confirm).
        from app.db import User

        stmt = select(User).where(User.id == user_id)
        if lock:
            stmt = stmt.with_for_update(of=User)
        result = await db_session.execute(stmt)
        return result.unique().scalar_one_or_none()

    # ------------------------------------------------------------------
    # Concurrent stream limiter
    # ------------------------------------------------------------------
    @staticmethod
    async def anon_acquire_stream_slot(
        ip_address: str,
        max_concurrent: int = 2,
        safety_ttl: int = 300,
    ) -> bool:
        """Try to take one concurrent-stream slot for ``ip_address``.

        Args:
            ip_address: Client IP; hashed into the counter key.
            max_concurrent: Maximum simultaneous streams per IP.
            safety_ttl: Seconds after which the counter key expires if no
                further slot is acquired.

        Returns:
            ``True`` if a slot was acquired, ``False`` if the IP is at capacity.
        """
        key = compute_stream_slot_key(ip_address)
        r = _get_anon_redis()
        try:
            result = await r.eval(
                _ACQUIRE_STREAM_LUA, 1, key, str(max_concurrent), str(safety_ttl)
            )
            return bool(result)
        finally:
            await r.aclose()

    @staticmethod
    async def anon_release_stream_slot(ip_address: str) -> None:
        """Give back one concurrent-stream slot for ``ip_address``.

        The counter is never decremented below zero.
        """
        key = compute_stream_slot_key(ip_address)
        r = _get_anon_redis()
        try:
            await r.eval(_RELEASE_STREAM_LUA, 1, key)
        finally:
            await r.aclose()

    # ------------------------------------------------------------------
    # Per-IP request counter (for CAPTCHA triggering)
    # ------------------------------------------------------------------
    @staticmethod
    async def anon_increment_request_count(ip_address: str, ttl: int = 86400) -> int:
        """Increment and return the request count for this IP.

        The key's TTL is refreshed to ``ttl`` seconds on every call, so the
        counter expires ``ttl`` seconds after the most recent request rather
        than at a fixed daily boundary.
        """
        key = compute_request_count_key(ip_address)
        r = _get_anon_redis()
        try:
            pipe = r.pipeline()
            pipe.incr(key)
            pipe.expire(key, ttl)
            results = await pipe.execute()
            return int(results[0])
        finally:
            await r.aclose()

    @staticmethod
    async def anon_get_request_count(ip_address: str) -> int:
        """Return the current request count for this IP (0 when no key exists)."""
        key = compute_request_count_key(ip_address)
        r = _get_anon_redis()
        try:
            val = await r.get(key)
            return int(val) if val else 0
        finally:
            await r.aclose()

    @staticmethod
    async def anon_reset_request_count(ip_address: str) -> None:
        """Delete the request counter for this IP."""
        key = compute_request_count_key(ip_address)
        r = _get_anon_redis()
        try:
            await r.delete(key)
        finally:
            await r.aclose()

    # ------------------------------------------------------------------
    # Anonymous (Redis-backed)
    # ------------------------------------------------------------------
    @staticmethod
    async def anon_reserve(
        session_key: str,
        ip_key: str | None,
        request_id: str,
        reserve_tokens: int,
    ) -> QuotaResult:
        """Reserve tokens for one anonymous request on the session and IP keys.

        The reservation is made on ``session_key`` first. Only if that
        succeeds is the same amount reserved on ``ip_key`` (when given); if
        the IP key refuses, the session reservation is rolled back so that
        neither key is left holding tokens for a request that will not run.

        Args:
            session_key: Redis key of the session quota hash.
            ip_key: Redis key of the IP quota hash, or ``None`` to skip it.
            request_id: Identifier used to find the reservation again in
                ``anon_finalize`` / ``anon_release``.
            reserve_tokens: Number of tokens to set aside.

        Returns:
            A ``QuotaResult``; ``allowed`` is ``False`` with status
            ``BLOCKED`` when either key is over its limit.
        """
        limit = config.ANON_TOKEN_LIMIT
        warning = config.ANON_TOKEN_WARNING_THRESHOLD
        ttl = config.ANON_TOKEN_QUOTA_TTL_DAYS * 86400
        script_args = (
            str(reserve_tokens),
            str(limit),
            str(warning),
            request_id,
            str(ttl),
        )
        r = _get_anon_redis()
        try:
            result = await r.eval(_RESERVE_LUA, 1, session_key, *script_args)
            allowed = bool(result[0])
            status_str = TokenQuotaService._as_text(result[1])
            used = int(result[2])
            reserved = int(result[3])

            # Reserve on the IP key only when the session reservation
            # succeeded; otherwise an IP reservation would be created that
            # nothing ever finalizes or releases.
            if ip_key and allowed:
                ip_result = await r.eval(_RESERVE_LUA, 1, ip_key, *script_args)
                if not bool(ip_result[0]):
                    await r.eval(_RELEASE_LUA, 1, session_key, request_id)
                    allowed = False
                    status_str = "blocked"
                    # The session reservation was just rolled back, so it
                    # must not be counted in the reported figures.
                    reserved = max(0, reserved - reserve_tokens)
                # Report the larger of the two usage figures.
                used = max(used, int(ip_result[2]))

            status = QuotaStatus(status_str)
            remaining = max(0, limit - used - reserved)
            return QuotaResult(
                allowed=allowed,
                status=status,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        finally:
            await r.aclose()

    @staticmethod
    async def anon_finalize(
        session_key: str,
        ip_key: str | None,
        request_id: str,
        actual_tokens: int,
    ) -> QuotaResult:
        """Convert a reservation into actual usage on the session and IP keys.

        The returned figures come from the session key; the IP key is
        updated with the same ``actual_tokens`` but its reply is not used.

        Returns:
            A ``QuotaResult`` with ``allowed=True`` and the session's
            post-finalize counters.
        """
        warning = config.ANON_TOKEN_WARNING_THRESHOLD
        limit = config.ANON_TOKEN_LIMIT
        script_args = (request_id, str(actual_tokens), str(warning))
        r = _get_anon_redis()
        try:
            result = await r.eval(_FINALIZE_LUA, 1, session_key, *script_args)
            used = int(result[0])
            reserved = int(result[1])
            status_str = TokenQuotaService._as_text(result[2])

            if ip_key:
                await r.eval(_FINALIZE_LUA, 1, ip_key, *script_args)

            status = QuotaStatus(status_str)
            remaining = max(0, limit - used - reserved)
            return QuotaResult(
                allowed=True,
                status=status,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        finally:
            await r.aclose()

    @staticmethod
    async def anon_release(
        session_key: str,
        ip_key: str | None,
        request_id: str,
    ) -> None:
        """Cancel a reservation on the session key and, if given, the IP key."""
        r = _get_anon_redis()
        try:
            await r.eval(_RELEASE_LUA, 1, session_key, request_id)
            if ip_key:
                await r.eval(_RELEASE_LUA, 1, ip_key, request_id)
        finally:
            await r.aclose()

    @staticmethod
    async def anon_get_usage(session_key: str) -> QuotaResult:
        """Return the current usage snapshot for an anonymous session key.

        Missing hash fields are treated as 0. ``allowed`` and ``status`` are
        based on ``used`` alone; outstanding reservations only reduce
        ``remaining``.
        """
        limit = config.ANON_TOKEN_LIMIT
        warning = config.ANON_TOKEN_WARNING_THRESHOLD
        r = _get_anon_redis()
        try:
            data = await r.hgetall(session_key)
            used = int(data.get("used", 0))
            reserved = int(data.get("reserved", 0))
            remaining = max(0, limit - used - reserved)

            if used >= limit:
                status = QuotaStatus.BLOCKED
            elif used >= warning:
                status = QuotaStatus.WARNING
            else:
                status = QuotaStatus.OK

            return QuotaResult(
                allowed=used < limit,
                status=status,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )
        finally:
            await r.aclose()

    # ------------------------------------------------------------------
    # Premium (Postgres-backed)
    # ------------------------------------------------------------------
    @staticmethod
    async def premium_reserve(
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        reserve_tokens: int,
    ) -> QuotaResult:
        """Reserve premium tokens for a user under a row lock.

        Args:
            db_session: Session to query and commit on.
            user_id: Primary key of the user.
            request_id: Accepted for interface symmetry; not used here.
            reserve_tokens: Number of tokens to set aside.

        Returns:
            A blocked ``QuotaResult`` when the user does not exist or the
            reservation would exceed the limit (the transaction is rolled
            back in the latter case); otherwise an allowed result whose
            status reflects ``used + reserved`` after the reservation.
        """
        user = await TokenQuotaService._fetch_premium_user(
            db_session, user_id, lock=True
        )
        if user is None:
            return QuotaResult(
                allowed=False,
                status=QuotaStatus.BLOCKED,
                used=0,
                limit=0,
            )

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved

        if used + reserved + reserve_tokens > limit:
            remaining = max(0, limit - used - reserved)
            # Nothing was changed; roll back to release the row lock.
            await db_session.rollback()
            return QuotaResult(
                allowed=False,
                status=QuotaStatus.BLOCKED,
                used=used,
                limit=limit,
                reserved=reserved,
                remaining=remaining,
            )

        new_reserved = reserved + reserve_tokens
        user.premium_tokens_reserved = new_reserved
        await db_session.commit()

        return QuotaResult(
            allowed=True,
            status=TokenQuotaService._premium_status(used + new_reserved, limit),
            used=used,
            limit=limit,
            reserved=new_reserved,
            remaining=max(0, limit - used - new_reserved),
        )

    @staticmethod
    async def premium_finalize(
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        actual_tokens: int,
        reserved_tokens: int,
    ) -> QuotaResult:
        """Convert a premium reservation into actual usage under a row lock.

        Args:
            db_session: Session to query and commit on.
            user_id: Primary key of the user.
            request_id: Accepted for interface symmetry; not used here.
            actual_tokens: Tokens actually consumed; added to the used count.
            reserved_tokens: Amount previously reserved; subtracted from the
                reserved count (floored at 0).

        Returns:
            A blocked ``QuotaResult`` when the user does not exist;
            otherwise an allowed result with the updated counters.
        """
        user = await TokenQuotaService._fetch_premium_user(
            db_session, user_id, lock=True
        )
        if user is None:
            return QuotaResult(
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )

        # Compute the new figures before committing: reading ORM attributes
        # after commit can trigger an implicit refresh when the session
        # expires objects on commit, which is not allowed under AsyncSession.
        limit = user.premium_tokens_limit
        reserved = max(0, user.premium_tokens_reserved - reserved_tokens)
        used = user.premium_tokens_used + actual_tokens

        user.premium_tokens_reserved = reserved
        user.premium_tokens_used = used
        await db_session.commit()

        return QuotaResult(
            allowed=True,
            status=TokenQuotaService._premium_status(used, limit),
            used=used,
            limit=limit,
            reserved=reserved,
            remaining=max(0, limit - used - reserved),
        )

    @staticmethod
    async def premium_release(
        db_session: AsyncSession,
        user_id: Any,
        reserved_tokens: int,
    ) -> None:
        """Cancel a premium reservation without charging usage.

        Subtracts ``reserved_tokens`` from the user's reserved count
        (floored at 0) and commits. Does nothing when the user is missing.
        """
        user = await TokenQuotaService._fetch_premium_user(
            db_session, user_id, lock=True
        )
        if user is not None:
            user.premium_tokens_reserved = max(
                0, user.premium_tokens_reserved - reserved_tokens
            )
            await db_session.commit()

    @staticmethod
    async def premium_get_usage(
        db_session: AsyncSession,
        user_id: Any,
    ) -> QuotaResult:
        """Return the current premium usage snapshot for a user (no row lock).

        ``allowed`` and ``status`` are based on the used count alone;
        outstanding reservations only reduce ``remaining``. A missing user
        yields a blocked result with zero limit.
        """
        user = await TokenQuotaService._fetch_premium_user(
            db_session, user_id, lock=False
        )
        if user is None:
            return QuotaResult(
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved

        return QuotaResult(
            allowed=used < limit,
            status=TokenQuotaService._premium_status(used, limit),
            used=used,
            limit=limit,
            reserved=reserved,
            remaining=max(0, limit - used - reserved),
        )