mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 08:46:22 +02:00
- Deleted the migration script for adding web search columns to the database. - Removed web search related fields from the SearchSpace model and schemas. - Eliminated web search health check endpoint and associated service logic. - Updated frontend components to remove web search settings management. - Cleaned up references to SearXNG API in various modules, transitioning to a platform service approach.
521 lines
18 KiB
Python
521 lines
18 KiB
Python
import asyncio
|
|
import gc
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from contextlib import asynccontextmanager
|
|
from threading import Lock
|
|
|
|
import redis
|
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from limits.storage import MemoryStorage
|
|
from slowapi import Limiter
|
|
from slowapi.errors import RateLimitExceeded
|
|
from slowapi.middleware import SlowAPIMiddleware
|
|
from slowapi.util import get_remote_address
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from starlette.requests import Request as StarletteRequest
|
|
from starlette.responses import Response as StarletteResponse
|
|
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
|
|
|
from app.agents.new_chat.checkpointer import (
|
|
close_checkpointer,
|
|
setup_checkpointer_tables,
|
|
)
|
|
from app.config import config, initialize_image_gen_router, initialize_llm_router
|
|
from app.db import User, create_db_and_tables, get_async_session
|
|
from app.routes import router as crud_router
|
|
from app.routes.auth_routes import router as auth_router
|
|
from app.schemas import UserCreate, UserRead, UserUpdate
|
|
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
|
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
|
from app.utils.perf import get_perf_logger, log_system_snapshot
|
|
|
|
# Dedicated logger so rate-limit events can be filtered on their own.
rate_limit_logger = logging.getLogger("surfsense.rate_limit")


# ============================================================================
# Rate Limiting Configuration (SlowAPI + Redis)
# ============================================================================
# Uses the same Redis instance as Celery for zero additional infrastructure.
# Protects auth endpoints from brute force and user enumeration attacks.

# SlowAPI limiter — provides default rate limits (1024/min) for ALL routes
# via the SlowAPI middleware. This is the general safety net.
#
# in_memory_fallback_enabled=True lets SlowAPI switch to its own per-worker
# in-memory storage when Redis is unreachable, re-using default_limits, so
# requests are still served instead of hanging.
#
# The separate `in_memory_fallback` argument is intentionally NOT passed:
# per the SlowAPI documentation it takes a list of limit *strings* to apply
# while the storage is down, not storage objects. Passing a MemoryStorage
# instance there makes the fallback limits unparsable once Redis fails.
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri=config.REDIS_APP_URL,
    default_limits=["1024/minute"],
    in_memory_fallback_enabled=True,
)
|
|
|
|
|
|
def _retry_after_seconds(detail) -> str:
    """Convert a rate-limit description into a ``Retry-After`` value.

    ``detail`` is expected to look like ``"1024 per 1 minute"``. The HTTP
    ``Retry-After`` header must carry a number of seconds (or an HTTP date),
    so the trailing ``"<multiple> <unit>"`` window is converted to seconds.

    Returns:
        The window length in seconds as a string, or ``"60"`` when the
        description is missing or cannot be parsed.
    """
    default = "60"
    if not detail or not isinstance(detail, str):
        return default

    # Month and year are approximated as 30 and 365 days respectively.
    unit_seconds = {
        "second": 1,
        "minute": 60,
        "hour": 60 * 60,
        "day": 24 * 60 * 60,
        "month": 30 * 24 * 60 * 60,
        "year": 365 * 24 * 60 * 60,
    }

    # Keep only what follows the last " per " (the window description).
    tokens = detail.rpartition(" per ")[2].split()
    if len(tokens) == 2:
        amount, unit = tokens
    elif len(tokens) == 1:
        amount, unit = "1", tokens[0]
    else:
        return default

    try:
        multiple = int(amount)
    except ValueError:
        return default

    seconds = unit_seconds.get(unit.lower().rstrip("s"))
    if seconds is None or multiple <= 0:
        return default
    return str(multiple * seconds)


def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Custom 429 handler that returns JSON matching our frontend error format.

    Args:
        request: The request that tripped the limit (unused).
        exc: The SlowAPI exception; its ``detail`` describes the limit.

    Returns:
        A 429 ``JSONResponse`` with body ``{"detail": "RATE_LIMIT_EXCEEDED"}``
        and a ``Retry-After`` header expressed in whole seconds.
    """
    return JSONResponse(
        status_code=429,
        content={"detail": "RATE_LIMIT_EXCEEDED"},
        headers={"Retry-After": _retry_after_seconds(exc.detail)},
    )
|
|
|
|
|
|
# ============================================================================
|
|
# Auth-Specific Rate Limits (Redis-backed with in-memory fallback)
|
|
# ============================================================================
|
|
# Stricter per-IP limits on auth endpoints to prevent:
|
|
# - Brute force password attacks
|
|
# - User enumeration via REGISTER_USER_ALREADY_EXISTS
|
|
# - Email spam via forgot-password
|
|
#
|
|
# Primary: Redis INCR+EXPIRE (shared across all workers).
|
|
# Fallback: In-memory sliding window (per-worker) when Redis is unavailable.
|
|
# Same Redis instance as SlowAPI / Celery.
|
|
# Redis client for the auth rate limiters; stays None until
# _get_rate_limit_redis() creates it on first use.
_rate_limit_redis: redis.Redis | None = None

# In-memory fallback state (per-worker, used only when Redis is down):
# maps "<scope>:<client_ip>" to the monotonic timestamps of recent requests.
_memory_rate_limits: dict[str, list[float]] = defaultdict(list)
# Serialises access to _memory_rate_limits.
_memory_lock = Lock()
|
|
|
|
|
|
def _get_rate_limit_redis() -> redis.Redis:
    """Return the shared Redis client used for auth rate limiting.

    The client is created from ``config.REDIS_APP_URL`` (with
    ``decode_responses=True``) the first time this is called and is cached
    in the module-level ``_rate_limit_redis`` afterwards.
    """
    global _rate_limit_redis
    client = _rate_limit_redis
    if client is None:
        client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
        _rate_limit_redis = client
    return client
|
|
|
|
|
|
def _check_rate_limit_memory(
    client_ip: str, max_requests: int, window_seconds: int, scope: str
):
    """
    Sliding-window rate limiter kept in process memory.

    Only used as a fallback while Redis is unavailable. The state is local
    to this worker, so the effective limit is max_requests x num_workers.

    Raises HTTPException(429) when the caller already has ``max_requests``
    hits younger than ``window_seconds``; otherwise records the current hit.
    """
    bucket = f"{scope}:{client_ip}"
    current = time.monotonic()

    with _memory_lock:
        # Keep only the hits that are still inside the window.
        recent = []
        for stamp in _memory_rate_limits[bucket]:
            if current - stamp < window_seconds:
                recent.append(stamp)

        # Drop empty buckets so idle clients do not linger in the dict.
        if recent:
            _memory_rate_limits[bucket] = recent
        else:
            _memory_rate_limits.pop(bucket, None)

        if len(recent) < max_requests:
            _memory_rate_limits[bucket] = [*recent, current]
            return

        rate_limit_logger.warning(
            f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
            f"({len(recent)}/{max_requests} in {window_seconds}s)"
        )
        raise HTTPException(
            status_code=429,
            detail="RATE_LIMIT_EXCEEDED",
        )
|
|
|
|
|
|
def _check_rate_limit(
    request: Request, max_requests: int, window_seconds: int, scope: str
):
    """
    Check the per-IP rate limit for ``scope`` using Redis.

    A counter keyed by scope and client IP is incremented for every call.
    The key's TTL is set only when the key has no expiry yet, so the window
    is fixed: it starts with the first request and is NOT extended by later
    (or rejected) requests. Falls back to the in-memory sliding window if
    Redis is unavailable.

    Args:
        request: Incoming request; the client IP is taken from it.
        max_requests: Number of requests allowed per window.
        window_seconds: Window length in seconds.
        scope: Short label for the protected endpoint group (part of the key).

    Raises:
        HTTPException: 429 with detail "RATE_LIMIT_EXCEEDED" when the limit
            is exceeded.
    """
    client_ip = get_remote_address(request)
    key = f"surfsense:auth_rate_limit:{scope}:{client_ip}"

    try:
        r = _get_rate_limit_redis()

        # Increment and read the remaining TTL in one round trip.
        pipe = r.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        current_count, ttl = pipe.execute()

        # TTL is negative when the key has no expiry (new key, or a key left
        # without one by an earlier failure). Start the window only then;
        # refreshing it on every hit would lock out a client that keeps
        # retrying, because the counter would never expire.
        if ttl < 0:
            r.expire(key, window_seconds)
    except (redis.exceptions.RedisError, OSError) as exc:
        # Redis unavailable — fall back to in-memory rate limiting
        rate_limit_logger.warning(
            f"Redis unavailable for rate limiting ({scope}), "
            f"falling back to in-memory limiter for {client_ip}: {exc}"
        )
        _check_rate_limit_memory(client_ip, max_requests, window_seconds, scope)
        return

    if current_count > max_requests:
        rate_limit_logger.warning(
            f"Rate limit exceeded on {scope} for IP {client_ip} "
            f"({current_count}/{max_requests} in {window_seconds}s)"
        )
        raise HTTPException(
            status_code=429,
            detail="RATE_LIMIT_EXCEEDED",
        )
|
|
|
|
|
|
def rate_limit_login(request: Request):
    """Dependency: limit login attempts to 5 per 60-second window per IP."""
    _check_rate_limit(request, 5, 60, "login")
|
|
|
|
|
|
def rate_limit_register(request: Request):
    """Dependency: limit registrations to 3 per 60-second window per IP."""
    _check_rate_limit(request, 3, 60, "register")
|
|
|
|
|
|
def rate_limit_password_reset(request: Request):
    """Dependency: limit password resets to 2 per 60-second window per IP."""
    _check_rate_limit(request, 2, 60, "password_reset")
|
|
|
|
|
|
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
    """Put the running event loop in debug mode with a slow-callback threshold.

    Sets ``slow_callback_duration`` to *threshold_sec* and enables asyncio
    debug mode on the running loop, which helps pinpoint synchronous code
    that blocks the FastAPI server. A dedicated, non-propagating stream
    handler is attached to the ``surfsense.perf.slow`` logger the first time.

    Does nothing unless the PERF_DEBUG env var is set to a non-empty value
    (to avoid overhead in production). Must be called from within a running
    event loop when enabled.
    """
    import os

    if not os.environ.get("PERF_DEBUG"):
        return

    slow_logger = logging.getLogger("surfsense.perf.slow")
    slow_logger.setLevel(logging.WARNING)
    if not slow_logger.handlers:
        # First call: install our own handler and stop propagation.
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s")
        )
        slow_logger.addHandler(handler)
        slow_logger.propagate = False

    running_loop = asyncio.get_running_loop()
    running_loop.slow_callback_duration = threshold_sec  # type: ignore[attr-defined]
    running_loop.set_debug(True)
    slow_logger.warning(
        "Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
        "Set PERF_DEBUG='' to disable.",
        threshold_sec,
    )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: run startup tasks, yield, then clean up.

    Startup: tune the GC, optionally enable slow-callback logging, create
    the DB and checkpointer tables, initialise the LLM and image-generation
    routers, and seed the Surfsense docs (bounded to 120 seconds).

    Shutdown: close the checkpointer. This runs in a ``finally`` so it also
    happens when the serving phase ends with an exception or cancellation.
    """
    # Tune GC: lower gen-2 threshold so long-lived garbage is collected
    # sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
    # with minimal CPU overhead.
    gc.set_threshold(700, 10, 5)

    _enable_slow_callback_logging(threshold_sec=0.5)
    await create_db_and_tables()
    await setup_checkpointer_tables()
    initialize_llm_router()
    initialize_image_gen_router()
    try:
        await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
    except (TimeoutError, asyncio.TimeoutError):
        # asyncio.TimeoutError is only an alias of the builtin TimeoutError
        # from Python 3.11 on; catching both keeps older interpreters working.
        logging.getLogger(__name__).warning(
            "Surfsense docs seeding timed out after 120s — skipping. "
            "Docs will be indexed on the next restart."
        )

    log_system_snapshot("startup_complete")

    try:
        yield
    finally:
        await close_checkpointer()
|
|
|
|
|
|
def registration_allowed():
    """Dependency that rejects requests while registration is disabled.

    Returns True when ``config.REGISTRATION_ENABLED`` is truthy; otherwise
    raises HTTPException(403, "Registration is disabled").
    """
    if config.REGISTRATION_ENABLED:
        return True
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
    )
|
|
|
|
|
|
# The ASGI application; startup/shutdown work lives in `lifespan`.
app = FastAPI(lifespan=lifespan)

# SlowAPI looks the limiter up on app.state; limit violations are rendered
# by our JSON 429 handler.
app.state.limiter = limiter
app.add_exception_handler(
    RateLimitExceeded,
    _rate_limit_exceeded_handler,
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request-level performance middleware
|
|
# ---------------------------------------------------------------------------
|
|
# Logs wall-clock time, method, path, and status for every request so we can
|
|
# spot slow endpoints in production logs.
|
|
|
|
# Slow-request threshold in milliseconds, read once at import time from the
# PERF_SLOW_REQUEST_MS env var (default "2000").
_PERF_SLOW_REQUEST_THRESHOLD = float(
    __import__("os").getenv("PERF_SLOW_REQUEST_MS", "2000")
)
|
|
|
|
|
|
class RequestPerfMiddleware(BaseHTTPMiddleware):
    """Per-request wall-clock timing logger.

    - Every request is logged at DEBUG level with method, path, status and
      elapsed milliseconds.
    - Requests slower than PERF_SLOW_REQUEST_MS (default 2000ms) are also
      logged at WARNING level, followed by a system snapshot so slow
      responses can be correlated with CPU/memory usage at that moment.
    """

    async def dispatch(
        self, request: StarletteRequest, call_next: RequestResponseEndpoint
    ) -> StarletteResponse:
        """Time the downstream handler, log the result, return its response."""
        perf_logger = get_perf_logger()
        started = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = (time.perf_counter() - started) * 1000

        # Shared positional args for both log lines.
        summary = (
            request.method,
            request.url.path,
            response.status_code,
            elapsed_ms,
        )

        perf_logger.debug("[request] %s %s -> %d in %.1fms", *summary)

        if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
            perf_logger.warning(
                "[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
                *summary,
                _PERF_SLOW_REQUEST_THRESHOLD,
            )
            log_system_snapshot("slow_request")

        return response
|
|
|
|
|
|
# Per-request timing/logging.
app.add_middleware(RequestPerfMiddleware)

# SlowAPI middleware for automatic rate limiting.
# The Starlette BaseHTTPMiddleware variant is used (not the raw ASGI one) to
# avoid corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends
# http.response.start on every body chunk, breaking SSE/streaming endpoints.
app.add_middleware(SlowAPIMiddleware)

# ProxyHeaders middleware: trust proxy headers (e.g., from Cloudflare) so
# FastAPI uses HTTPS in redirects when behind a proxy.
# NOTE(review): middleware registered later presumably wraps the ones
# registered earlier, so this runs before the two above — confirm the
# intended order. trusted_hosts="*" accepts forwarded headers from any peer;
# confirm the app is only reachable through the proxy, since the per-IP rate
# limits rely on the client address.
app.add_middleware(
    ProxyHeadersMiddleware,
    trusted_hosts="*",
)
|
|
|
|
# CORS middleware
# Credentialed requests require exact origins (not "*"), so the allow-list
# is derived from NEXT_FRONTEND_URL.
allowed_origins = []
if config.NEXT_FRONTEND_URL:
    allowed_origins.append(config.NEXT_FRONTEND_URL)

    # Same URL without a trailing slash.
    frontend_url = config.NEXT_FRONTEND_URL.rstrip("/")
    if frontend_url not in allowed_origins:
        allowed_origins.append(frontend_url)

    # Add the counterpart host: strip "www." if present, otherwise add it.
    if "://www." in frontend_url:
        non_www = frontend_url.replace("://www.", "://")
        if non_www not in allowed_origins:
            allowed_origins.append(non_www)
    elif "://" in frontend_url:
        www_url = frontend_url.replace("://", "://www.")
        if www_url not in allowed_origins:
            allowed_origins.append(www_url)

# Local development: also allow the common localhost origins when no
# BACKEND_URL is configured or the frontend itself runs on localhost.
if not config.BACKEND_URL or (
    config.NEXT_FRONTEND_URL and "localhost" in config.NEXT_FRONTEND_URL
):
    allowed_origins.extend(["http://localhost:3000", "http://127.0.0.1:3000"])

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],  # every HTTP method
    allow_headers=["*"],  # every request header
)
|
|
|
|
# JWT auth routes; every route on this router goes through the login limiter.
app.include_router(
    fastapi_users.get_auth_router(auth_backend),
    prefix="/auth/jwt",
    tags=["auth"],
    dependencies=[Depends(rate_limit_login)],
)

# Registration: rate limited, and rejected outright when registration is
# disabled.
app.include_router(
    fastapi_users.get_register_router(UserRead, UserCreate),
    prefix="/auth",
    tags=["auth"],
    dependencies=[
        Depends(rate_limit_register),
        Depends(registration_allowed),
    ],
)

# Password reset routes, rate limited per IP.
app.include_router(
    fastapi_users.get_reset_password_router(),
    prefix="/auth",
    tags=["auth"],
    dependencies=[Depends(rate_limit_password_reset)],
)

# Verification routes (no extra dependencies attached here).
app.include_router(
    fastapi_users.get_verify_router(UserRead),
    prefix="/auth",
    tags=["auth"],
)

# User management routes.
app.include_router(
    fastapi_users.get_users_router(UserRead, UserUpdate),
    prefix="/users",
    tags=["users"],
)

# Custom auth routes (refresh token, logout)
app.include_router(auth_router)
|
|
|
|
if config.AUTH_TYPE == "GOOGLE":
    from fastapi.responses import RedirectResponse

    from app.users import google_oauth_client

    # Secure context (HTTPS) vs. local development (HTTP): the CSRF cookie
    # must have secure=False when the backend is served over plain HTTP.
    is_secure_context = config.BACKEND_URL and config.BACKEND_URL.startswith("https://")

    # Cross-origin OAuth (frontend and backend on different domains) needs
    # SameSite=None, which in turn requires Secure=True. Same-origin or
    # local development keeps SameSite=Lax.
    csrf_cookie_samesite = "none" if is_secure_context else "lax"

    # Options common to both deployment modes.
    _google_oauth_router_kwargs = {
        "is_verified_by_default": True,
        "csrf_token_cookie_secure": is_secure_context,
        "csrf_token_cookie_samesite": csrf_cookie_samesite,
        # Required for cross-site OAuth in Firefox/Safari
        "csrf_token_cookie_httponly": False,
    }

    # Cookie domain is taken from BACKEND_URL; this helps with cross-site
    # cookie issues in Firefox/Safari.
    csrf_cookie_domain = None
    if config.BACKEND_URL:
        from urllib.parse import urlparse

        parsed_url = urlparse(config.BACKEND_URL)
        csrf_cookie_domain = parsed_url.hostname

        # An explicit callback URL and cookie domain are only passed when
        # BACKEND_URL is configured.
        _google_oauth_router_kwargs["redirect_url"] = (
            f"{config.BACKEND_URL}/auth/google/callback"
        )
        _google_oauth_router_kwargs["csrf_token_cookie_domain"] = csrf_cookie_domain

    app.include_router(
        fastapi_users.get_oauth_router(
            google_oauth_client,
            auth_backend,
            SECRET,
            **_google_oauth_router_kwargs,
        ),
        prefix="/auth/google",
        tags=["auth"],
        # blocks OAuth registration when disabled
        dependencies=[Depends(registration_allowed)],
    )

    # Redirect-based authorize endpoint for Firefox/Safari compatibility.
    # It answers with a server-side redirect instead of JSON, which fixes
    # cross-site cookie issues where browsers don't send cookies set via
    # cross-origin fetch requests on subsequent redirects.
    @app.get("/auth/google/authorize-redirect", tags=["auth"])
    async def google_authorize_redirect(
        request: Request,
    ):
        """
        Redirect-based OAuth authorization endpoint.

        Unlike the standard /auth/google/authorize endpoint that returns JSON,
        this endpoint directly redirects the browser to Google's OAuth page.
        This fixes CSRF cookie issues in Firefox and Safari where cookies set
        via cross-origin fetch requests are not sent on subsequent redirects.
        """
        import secrets

        from fastapi_users.router.oauth import generate_state_token

        # Fresh CSRF token, embedded in the signed state and mirrored in a
        # cookie below.
        csrf_token = secrets.token_urlsafe(32)
        state = generate_state_token(
            {"csrftoken": csrf_token}, SECRET, lifetime_seconds=3600
        )

        # Callback URL: explicit when BACKEND_URL is set, otherwise resolved
        # from the registered OAuth callback route.
        callback_url = (
            f"{config.BACKEND_URL}/auth/google/callback"
            if config.BACKEND_URL
            else str(request.url_for("oauth:google.jwt.callback"))
        )

        authorization_url = await google_oauth_client.get_authorization_url(
            callback_url,
            state,
            scope=["openid", "email", "profile"],
        )

        # Redirect to Google and set the CSRF cookie on the same response.
        redirect = RedirectResponse(url=authorization_url, status_code=302)
        redirect.set_cookie(
            key="fastapiusersoauthcsrf",
            value=csrf_token,
            max_age=3600,
            path="/",
            domain=csrf_cookie_domain,
            secure=is_secure_context,
            httponly=False,  # Required for cross-site OAuth in Firefox/Safari
            samesite=csrf_cookie_samesite,
        )
        return redirect
|
|
|
|
|
|
# Application CRUD API, mounted under /api/v1.
app.include_router(
    crud_router,
    prefix="/api/v1",
    tags=["crud"],
)
|
|
|
|
|
|
# limiter.exempt is applied to the function before app.get registers it.
@app.get("/health", tags=["health"])
@limiter.exempt
async def health_check():
    """Lightweight liveness probe exempt from rate limiting."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
# Token check endpoint: the current_active_user dependency has to resolve
# for the handler to run. The injected session is not used in the body.
@app.get("/verify-token")
async def authenticated_route(
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    result = {"message": "Token is valid"}
    return result
|