mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 20:03:30 +02:00
Merge pull request #798 from AnishSarkar22/fix/auth
feat(auth): improve error handling and add rate limiting
This commit is contained in:
commit
74b053f707
6 changed files with 3423 additions and 3231 deletions
|
|
@ -1,7 +1,17 @@
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import redis
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.middleware import SlowAPIASGIMiddleware
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||||
|
|
||||||
|
|
@ -17,6 +27,147 @@ from app.schemas import UserCreate, UserRead, UserUpdate
|
||||||
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
||||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||||
|
|
||||||
|
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Rate Limiting Configuration (SlowAPI + Redis)
|
||||||
|
# ============================================================================
|
||||||
|
# Uses the same Redis instance as Celery for zero additional infrastructure.
|
||||||
|
# Protects auth endpoints from brute force and user enumeration attacks.
|
||||||
|
|
||||||
|
# SlowAPI limiter — provides default rate limits (60/min) for ALL routes
|
||||||
|
# via the ASGI middleware. This is the general safety net.
|
||||||
|
limiter = Limiter(
|
||||||
|
key_func=get_remote_address,
|
||||||
|
storage_uri=config.REDIS_APP_URL,
|
||||||
|
default_limits=["60/minute"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
|
||||||
|
"""Custom 429 handler that returns JSON matching our frontend error format."""
|
||||||
|
retry_after = exc.detail.split("per")[-1].strip() if exc.detail else "60"
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={"detail": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
headers={"Retry-After": retry_after},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Auth-Specific Rate Limits (Redis-backed with in-memory fallback)
|
||||||
|
# ============================================================================
|
||||||
|
# Stricter per-IP limits on auth endpoints to prevent:
|
||||||
|
# - Brute force password attacks
|
||||||
|
# - User enumeration via REGISTER_USER_ALREADY_EXISTS
|
||||||
|
# - Email spam via forgot-password
|
||||||
|
#
|
||||||
|
# Primary: Redis INCR+EXPIRE (shared across all workers).
|
||||||
|
# Fallback: In-memory sliding window (per-worker) when Redis is unavailable.
|
||||||
|
# Same Redis instance as SlowAPI / Celery.
|
||||||
|
_rate_limit_redis: redis.Redis | None = None
|
||||||
|
|
||||||
|
# In-memory fallback rate limiter (per-worker, used only when Redis is down)
|
||||||
|
_memory_rate_limits: dict[str, list[float]] = defaultdict(list)
|
||||||
|
_memory_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rate_limit_redis() -> redis.Redis:
|
||||||
|
"""Get or create Redis client for auth rate limiting."""
|
||||||
|
global _rate_limit_redis
|
||||||
|
if _rate_limit_redis is None:
|
||||||
|
_rate_limit_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||||
|
return _rate_limit_redis
|
||||||
|
|
||||||
|
|
||||||
|
def _check_rate_limit_memory(
|
||||||
|
client_ip: str, max_requests: int, window_seconds: int, scope: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
In-memory fallback rate limiter using a sliding window.
|
||||||
|
Used only when Redis is unavailable. Per-worker only (not shared),
|
||||||
|
so effective limit = max_requests x num_workers.
|
||||||
|
"""
|
||||||
|
key = f"{scope}:{client_ip}"
|
||||||
|
now = time.monotonic()
|
||||||
|
|
||||||
|
with _memory_lock:
|
||||||
|
# Evict timestamps outside the current window
|
||||||
|
_memory_rate_limits[key] = [
|
||||||
|
t for t in _memory_rate_limits[key] if now - t < window_seconds
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(_memory_rate_limits[key]) >= max_requests:
|
||||||
|
rate_limit_logger.warning(
|
||||||
|
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
|
||||||
|
f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="RATE_LIMIT_EXCEEDED",
|
||||||
|
)
|
||||||
|
|
||||||
|
_memory_rate_limits[key].append(now)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_rate_limit(
|
||||||
|
request: Request, max_requests: int, window_seconds: int, scope: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check per-IP rate limit using Redis. Raises 429 if exceeded.
|
||||||
|
Uses atomic INCR + EXPIRE to avoid race conditions.
|
||||||
|
Falls back to in-memory sliding window if Redis is unavailable.
|
||||||
|
"""
|
||||||
|
client_ip = get_remote_address(request)
|
||||||
|
key = f"surfsense:auth_rate_limit:{scope}:{client_ip}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
r = _get_rate_limit_redis()
|
||||||
|
|
||||||
|
# Atomic: increment first, then set TTL if this is a new key
|
||||||
|
pipe = r.pipeline()
|
||||||
|
pipe.incr(key)
|
||||||
|
pipe.expire(key, window_seconds)
|
||||||
|
result = pipe.execute()
|
||||||
|
except (redis.exceptions.RedisError, OSError) as exc:
|
||||||
|
# Redis unavailable — fall back to in-memory rate limiting
|
||||||
|
rate_limit_logger.warning(
|
||||||
|
f"Redis unavailable for rate limiting ({scope}), "
|
||||||
|
f"falling back to in-memory limiter for {client_ip}: {exc}"
|
||||||
|
)
|
||||||
|
_check_rate_limit_memory(client_ip, max_requests, window_seconds, scope)
|
||||||
|
return
|
||||||
|
|
||||||
|
current_count = result[0] # INCR returns the new value
|
||||||
|
|
||||||
|
if current_count > max_requests:
|
||||||
|
rate_limit_logger.warning(
|
||||||
|
f"Rate limit exceeded on {scope} for IP {client_ip} "
|
||||||
|
f"({current_count}/{max_requests} in {window_seconds}s)"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="RATE_LIMIT_EXCEEDED",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit_login(request: Request):
|
||||||
|
"""5 login attempts per minute per IP."""
|
||||||
|
_check_rate_limit(request, max_requests=5, window_seconds=60, scope="login")
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit_register(request: Request):
|
||||||
|
"""3 registration attempts per minute per IP."""
|
||||||
|
_check_rate_limit(request, max_requests=3, window_seconds=60, scope="register")
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit_password_reset(request: Request):
|
||||||
|
"""2 password reset attempts per minute per IP."""
|
||||||
|
_check_rate_limit(
|
||||||
|
request, max_requests=2, window_seconds=60, scope="password_reset"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|
@ -45,6 +196,14 @@ def registration_allowed():
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
# Register rate limiter and custom 429 handler
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
# Add SlowAPI ASGI middleware for automatic rate limiting
|
||||||
|
# This applies default_limits to all routes and enables per-route overrides
|
||||||
|
app.add_middleware(SlowAPIASGIMiddleware)
|
||||||
|
|
||||||
# Add ProxyHeaders middleware FIRST to trust proxy headers (e.g., from Cloudflare)
|
# Add ProxyHeaders middleware FIRST to trust proxy headers (e.g., from Cloudflare)
|
||||||
# This ensures FastAPI uses HTTPS in redirects when behind a proxy
|
# This ensures FastAPI uses HTTPS in redirects when behind a proxy
|
||||||
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
|
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
|
||||||
|
|
@ -90,18 +249,25 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
|
fastapi_users.get_auth_router(auth_backend),
|
||||||
|
prefix="/auth/jwt",
|
||||||
|
tags=["auth"],
|
||||||
|
dependencies=[Depends(rate_limit_login)],
|
||||||
)
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||||
prefix="/auth",
|
prefix="/auth",
|
||||||
tags=["auth"],
|
tags=["auth"],
|
||||||
dependencies=[Depends(registration_allowed)], # blocks registration when disabled
|
dependencies=[
|
||||||
|
Depends(rate_limit_register),
|
||||||
|
Depends(registration_allowed), # blocks registration when disabled
|
||||||
|
],
|
||||||
)
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
fastapi_users.get_reset_password_router(),
|
fastapi_users.get_reset_password_router(),
|
||||||
prefix="/auth",
|
prefix="/auth",
|
||||||
tags=["auth"],
|
tags=["auth"],
|
||||||
|
dependencies=[Depends(rate_limit_password_reset)],
|
||||||
)
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
fastapi_users.get_verify_router(UserRead),
|
fastapi_users.get_verify_router(UserRead),
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,14 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
|
||||||
|
|
||||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
|
"""
|
||||||
|
Custom user manager extending fastapi-users BaseUserManager.
|
||||||
|
|
||||||
|
Authentication returns a generic error for both non-existent accounts
|
||||||
|
and incorrect passwords to comply with OWASP WSTG-IDNT-04 and
|
||||||
|
prevent user enumeration attacks.
|
||||||
|
"""
|
||||||
|
|
||||||
reset_password_token_secret = SECRET
|
reset_password_token_secret = SECRET
|
||||||
verification_token_secret = SECRET
|
verification_token_secret = SECRET
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,7 @@ dependencies = [
|
||||||
"unstructured[all-docs]>=0.18.31",
|
"unstructured[all-docs]>=0.18.31",
|
||||||
"unstructured-client>=0.42.3",
|
"unstructured-client>=0.42.3",
|
||||||
"langchain-unstructured>=1.0.1",
|
"langchain-unstructured>=1.0.1",
|
||||||
|
"slowapi>=0.1.9",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|
|
||||||
6441
surfsense_backend/uv.lock
generated
6441
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -9,7 +9,7 @@ import { useEffect, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
|
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import { getAuthErrorDetails, isNetworkError, shouldRetry } from "@/lib/auth-errors";
|
import { getAuthErrorDetails, isNetworkError } from "@/lib/auth-errors";
|
||||||
import { AUTH_TYPE } from "@/lib/env-config";
|
import { AUTH_TYPE } from "@/lib/env-config";
|
||||||
import { ValidationError } from "@/lib/error";
|
import { ValidationError } from "@/lib/error";
|
||||||
import { trackLoginAttempt, trackLoginFailure, trackLoginSuccess } from "@/lib/posthog/events";
|
import { trackLoginAttempt, trackLoginFailure, trackLoginSuccess } from "@/lib/posthog/events";
|
||||||
|
|
@ -65,10 +65,6 @@ export function LocalLoginForm() {
|
||||||
if (err instanceof ValidationError) {
|
if (err instanceof ValidationError) {
|
||||||
trackLoginFailure("local", err.message);
|
trackLoginFailure("local", err.message);
|
||||||
setError({ title: err.name, message: err.message });
|
setError({ title: err.name, message: err.message });
|
||||||
toast.error(err.name, {
|
|
||||||
description: err.message,
|
|
||||||
duration: 6000,
|
|
||||||
});
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -92,22 +88,6 @@ export function LocalLoginForm() {
|
||||||
title: errorDetails.title,
|
title: errorDetails.title,
|
||||||
message: errorDetails.description,
|
message: errorDetails.description,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Show error toast with conditional retry action
|
|
||||||
const toastOptions: any = {
|
|
||||||
description: errorDetails.description,
|
|
||||||
duration: 6000,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Add retry action if the error is retryable
|
|
||||||
if (shouldRetry(errorCode)) {
|
|
||||||
toastOptions.action = {
|
|
||||||
label: "Retry",
|
|
||||||
onClick: () => handleSubmit(e),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
toast.error(errorDetails.title, toastOptions);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,8 @@ const AUTH_ERROR_MESSAGES: AuthErrorMapping = {
|
||||||
description: "Your account may be suspended or restricted",
|
description: "Your account may be suspended or restricted",
|
||||||
},
|
},
|
||||||
"404": {
|
"404": {
|
||||||
title: "Account not found",
|
title: "Not found",
|
||||||
description: "No account exists with this email address",
|
description: "The requested resource was not found",
|
||||||
},
|
},
|
||||||
"409": {
|
"409": {
|
||||||
title: "Account conflict",
|
title: "Account conflict",
|
||||||
|
|
@ -31,6 +31,10 @@ const AUTH_ERROR_MESSAGES: AuthErrorMapping = {
|
||||||
title: "Too many attempts",
|
title: "Too many attempts",
|
||||||
description: "Please wait before trying again",
|
description: "Please wait before trying again",
|
||||||
},
|
},
|
||||||
|
RATE_LIMIT_EXCEEDED: {
|
||||||
|
title: "Too many attempts",
|
||||||
|
description: "You've made too many requests. Please wait a minute and try again.",
|
||||||
|
},
|
||||||
"500": {
|
"500": {
|
||||||
title: "Server error",
|
title: "Server error",
|
||||||
description: "Something went wrong on our end. Please try again",
|
description: "Something went wrong on our end. Please try again",
|
||||||
|
|
@ -42,8 +46,8 @@ const AUTH_ERROR_MESSAGES: AuthErrorMapping = {
|
||||||
|
|
||||||
// FastAPI specific errors
|
// FastAPI specific errors
|
||||||
LOGIN_BAD_CREDENTIALS: {
|
LOGIN_BAD_CREDENTIALS: {
|
||||||
title: "Invalid credentials",
|
title: "Login failed",
|
||||||
description: "The email or password you entered is incorrect",
|
description: "Invalid email or password. If you don't have an account, please sign up.",
|
||||||
},
|
},
|
||||||
LOGIN_USER_NOT_VERIFIED: {
|
LOGIN_USER_NOT_VERIFIED: {
|
||||||
title: "Account not verified",
|
title: "Account not verified",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue