diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index e35b310e0..70269e723 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -1,6 +1,9 @@ import logging import os +import time +from collections import defaultdict from contextlib import asynccontextmanager +from threading import Lock import redis from fastapi import Depends, FastAPI, HTTPException, Request, status @@ -58,17 +61,22 @@ def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded): # ============================================================================ -# Auth-Specific Rate Limits (Redis-backed FastAPI dependencies) +# Auth-Specific Rate Limits (Redis-backed with in-memory fallback) # ============================================================================ # Stricter per-IP limits on auth endpoints to prevent: # - Brute force password attacks # - User enumeration via REGISTER_USER_ALREADY_EXISTS # - Email spam via forgot-password # -# These use direct Redis INCR+EXPIRE for simplicity and reliability. +# Primary: Redis INCR+EXPIRE (shared across all workers). +# Fallback: In-memory sliding window (per-worker) when Redis is unavailable. # Same Redis instance as SlowAPI / Celery. _rate_limit_redis: redis.Redis | None = None +# In-memory fallback rate limiter (per-worker, used only when Redis is down) +_memory_rate_limits: dict[str, list[float]] = defaultdict(list) +_memory_lock = Lock() + def _get_rate_limit_redis() -> redis.Redis: """Get or create Redis client for auth rate limiting.""" @@ -78,13 +86,43 @@ def _get_rate_limit_redis() -> redis.Redis: return _rate_limit_redis +def _check_rate_limit_memory( + client_ip: str, max_requests: int, window_seconds: int, scope: str +): + """ + In-memory fallback rate limiter using a sliding window. + Used only when Redis is unavailable. Per-worker only (not shared), + so effective limit = max_requests x num_workers. + """ + key = f"{scope}:{client_ip}" + now = time.monotonic() + + with _memory_lock: + # Evict timestamps outside the current window + _memory_rate_limits[key] = [ + t for t in _memory_rate_limits[key] if now - t < window_seconds + ] + + if len(_memory_rate_limits[key]) >= max_requests: + rate_limit_logger.warning( + f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} " + f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)" + ) + raise HTTPException( + status_code=429, + detail="RATE_LIMIT_EXCEEDED", + ) + + _memory_rate_limits[key].append(now) + + def _check_rate_limit( request: Request, max_requests: int, window_seconds: int, scope: str ): """ Check per-IP rate limit using Redis. Raises 429 if exceeded. Uses atomic INCR + EXPIRE to avoid race conditions. - Fails open (allows request) if Redis is unavailable. + Falls back to in-memory sliding window if Redis is unavailable. """ client_ip = get_remote_address(request) key = f"surfsense:auth_rate_limit:{scope}:{client_ip}" @@ -98,12 +136,12 @@ def _check_rate_limit( pipe.expire(key, window_seconds) result = pipe.execute() except (redis.exceptions.RedisError, OSError) as exc: - # Redis unavailable — fail open to preserve auth availability. - # SlowAPI middleware provides a secondary rate-limiting layer. + # Redis unavailable — fall back to in-memory rate limiting rate_limit_logger.warning( f"Redis unavailable for rate limiting ({scope}), " - f"allowing request from {client_ip}: {exc}" + f"falling back to in-memory limiter for {client_ip}: {exc}" ) + _check_rate_limit_memory(client_ip, max_requests, window_seconds, scope) return current_count = result[0] # INCR returns the new value