SurfSense/surfsense_backend/app/gateway/ratelimit.py

136 lines
3.6 KiB
Python

"""Redis token-bucket rate limiter for gateway outbound traffic."""
from __future__ import annotations

import asyncio
import logging
import math
import time
from dataclasses import dataclass

import redis.asyncio as aioredis

from app.config import config
from app.observability.metrics import record_gateway_redis_fallback
logger = logging.getLogger(__name__)

# Atomic token-bucket step executed inside Redis via EVAL.
#   KEYS[1]: hash holding the bucket state ('tokens', 'last_refill').
#   ARGV[1..4]: capacity, refill rate (tokens/sec), current time (seconds),
#               tokens to consume.
# Returns 0 when the tokens were taken, otherwise the wait in milliseconds
# (always >= 1). The state hash expires after an hour without writes.
_TOKEN_BUCKET_LUA = """
local capacity = tonumber(ARGV[1])
local refill_rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local consume = tonumber(ARGV[4])

local bucket = redis.call('HMGET', KEYS[1], 'tokens', 'last_refill')
local tokens = tonumber(bucket[1]) or capacity
local last_refill = tonumber(bucket[2]) or now

local elapsed = math.max(0, now - last_refill)
tokens = math.min(capacity, tokens + (elapsed * refill_rate))

local allowed = tokens >= consume
if allowed then
  tokens = tokens - consume
end

-- Persist the refilled (and possibly decremented) state on both paths.
redis.call('HSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
redis.call('EXPIRE', KEYS[1], 3600)

if allowed then
  return 0
end
if refill_rate <= 0 then
  -- Dividing by a zero rate would yield inf, which Redis cannot turn into a
  -- meaningful integer reply; mirror the in-memory fallback instead.
  return 1000
end
return math.ceil(((consume - tokens) / refill_rate) * 1000)
"""
# Process-wide Redis client, created lazily on first use by ``_redis()``.
_redis_client: aioredis.Redis | None = None
@dataclass
class _MemoryBucket:
    """Mutable state of one per-process token bucket (in-memory fallback)."""

    # Tokens currently available; may be fractional.
    tokens: float
    # time.time() value at which ``tokens`` was last refilled.
    last_refill: float


# In-process fallback buckets, keyed by rate-limit scope.
_memory_buckets: dict[str, _MemoryBucket] = {}
# Guards read-modify-write updates of ``_memory_buckets`` entries.
_memory_lock = asyncio.Lock()
def _redis() -> aioredis.Redis:
    """Return the shared Redis client, building it from config on first call."""
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _redis_client
async def _memory_fallback_acquire(
    scope: str,
    capacity: int,
    refill_per_sec: float,
    consume: float,
) -> int:
    """Take tokens from a per-process bucket when Redis is unavailable.

    Mirrors the Redis Lua script. First the bucket for ``scope`` is refilled
    for the time elapsed since the last call, capped at ``capacity``. Then the
    function tries to deduct ``consume`` tokens.

    Returns:
        0 if the tokens were taken. Otherwise, the number of milliseconds
        (always >= 1) until enough tokens should be available. When
        ``refill_per_sec`` is not positive no finite wait exists, and 1000 is
        returned instead.
    """
    async with _memory_lock:
        # Read the clock only once the lock is held, so ``last_refill`` can
        # never be moved backwards if acquiring the lock had to wait.
        now = time.time()
        bucket = _memory_buckets.get(scope)
        if bucket is None:
            bucket = _MemoryBucket(tokens=float(capacity), last_refill=now)
            _memory_buckets[scope] = bucket
        elapsed = max(0.0, now - bucket.last_refill)
        bucket.tokens = min(float(capacity), bucket.tokens + elapsed * refill_per_sec)
        bucket.last_refill = now
        if bucket.tokens >= consume:
            bucket.tokens -= consume
            return 0
        if refill_per_sec <= 0:
            return 1000
        needed = consume - bucket.tokens
        # Round up, like the Lua script's math.ceil: truncating can produce 0,
        # which callers read as "allowed" although nothing was deducted.
        return max(1, math.ceil((needed / refill_per_sec) * 1000))
async def acquire_token(
    scope: str,
    *,
    capacity: int,
    refill_per_sec: float,
    consume: float = 1.0,
) -> int:
    """Try to take ``consume`` tokens from the bucket named ``scope``.

    Buckets are coordinated through Redis by running ``_TOKEN_BUCKET_LUA``.
    If the Redis call raises a Redis or OS-level error, the failure is logged
    and counted via ``record_gateway_redis_fallback()``. The request is then
    served from a per-process in-memory bucket, so a short Redis outage
    degrades the limiter instead of making it fail closed.

    Returns:
        0 if the tokens were taken, otherwise the number of milliseconds to
        wait before trying again.
    """
    bucket_key = f"gateway:bucket:{scope}"
    try:
        client = _redis()
        raw_wait = await client.eval(
            _TOKEN_BUCKET_LUA,
            1,  # number of KEYS arguments that follow
            bucket_key,
            capacity,
            refill_per_sec,
            time.time(),
            consume,
        )
    except (aioredis.RedisError, OSError) as exc:
        logger.warning("Redis rate limiter unavailable; using memory fallback: %s", exc)
        record_gateway_redis_fallback()
        return await _memory_fallback_acquire(scope, capacity, refill_per_sec, consume)
    return int(raw_wait)
async def wait_for_token(
    scope: str,
    *,
    capacity: int,
    refill_per_sec: float,
    consume: float = 1.0,
) -> int:
    """Wait until ``consume`` tokens have actually been taken from ``scope``.

    A denied ``acquire_token`` call reserves nothing. After sleeping for the
    suggested wait, the acquisition is therefore retried. Returning right
    after the sleep would let every concurrent waiter proceed without
    consuming a token, so the limit would be exceeded under contention.

    Returns:
        Total milliseconds spent sleeping (0 if the first attempt succeeded).

    Raises:
        ValueError: if ``consume`` exceeds ``capacity``, or if the bucket is
            exhausted while ``refill_per_sec`` is not positive. In both cases
            the request can never be satisfied and would otherwise wait
            forever.
    """
    if consume > capacity:
        raise ValueError(
            f"consume ({consume}) exceeds bucket capacity ({capacity}) "
            f"for scope {scope!r}"
        )
    total_waited_ms = 0
    while True:
        wait_ms = await acquire_token(
            scope,
            capacity=capacity,
            refill_per_sec=refill_per_sec,
            consume=consume,
        )
        if wait_ms <= 0:
            return total_waited_ms
        if refill_per_sec <= 0:
            # An exhausted bucket that never refills cannot be waited on.
            raise ValueError(
                f"bucket for scope {scope!r} is exhausted and refill_per_sec "
                f"is {refill_per_sec}"
            )
        total_waited_ms += wait_ms
        await asyncio.sleep(wait_ms / 1000)