mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 01:06:23 +02:00
747 lines · 26 KiB · Python
import asyncio
import gc
import logging
import os
import time
import uuid
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from threading import Lock

import redis
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address  # noqa: F401 — kept for reference
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from app.agents.new_chat.checkpointer import (
    close_checkpointer,
    setup_checkpointer_tables,
)
from app.config import (
    config,
    initialize_image_gen_router,
    initialize_llm_router,
    initialize_openrouter_integration,
    initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
from app.rate_limiter import get_real_client_ip, limiter
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.utils.perf import get_perf_logger, log_system_snapshot
|
|
|
|
# Logger for 5xx / unexpected errors surfaced through the global exception handlers.
_error_logger = logging.getLogger("surfsense.errors")

# Logger for rate-limit violations (both the Redis-backed and in-memory limiters).
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
|
|
|
|
|
|
# ============================================================================
|
|
# Rate Limiting Configuration (SlowAPI + Redis)
|
|
# ============================================================================
|
|
# Uses the same Redis instance as Celery for zero additional infrastructure.
|
|
# Protects auth endpoints from brute force and user enumeration attacks.
|
|
|
|
# limiter is imported from app.rate_limiter (shared module to avoid circular imports)
|
|
|
|
|
|
def _get_request_id(request: Request) -> str:
    """Resolve the request's ID: prefer ``request.state``, then the inbound header, else mint one."""
    # Sentinel keeps the original semantics: any value set on state wins,
    # even a falsy one, exactly like the hasattr() check would.
    missing = object()
    state_rid = getattr(request.state, "request_id", missing)
    if state_rid is not missing:
        return state_rid
    fallback = f"req_{uuid.uuid4().hex[:12]}"
    return request.headers.get("X-Request-ID", fallback)
|
|
|
|
|
|
def _build_error_response(
    status_code: int,
    message: str,
    *,
    code: str = "INTERNAL_ERROR",
    request_id: str = "",
    extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Assemble the standardized error payload (structured ``error`` plus legacy ``detail``)."""
    error_object = {
        "code": code,
        "message": message,
        "status": status_code,
        "request_id": request_id,
        "timestamp": datetime.now(UTC).isoformat(),
        "report_url": ISSUES_URL,
    }
    # Legacy clients read the flat "detail" field; new clients read "error".
    payload = {"error": error_object, "detail": message}
    response_headers = {"X-Request-ID": request_id, **(extra_headers or {})}
    return JSONResponse(
        status_code=status_code, content=payload, headers=response_headers
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Global exception handlers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _surfsense_error_handler(request: Request, exc: SurfSenseError) -> JSONResponse:
    """Translate a structured SurfSenseError into the standard error envelope."""
    request_id = _get_request_id(request)
    # Server-side failures get a full traceback in the error log.
    if exc.status_code >= 500:
        _error_logger.error(
            "[%s] %s - %s: %s",
            request_id,
            request.url.path,
            exc.code,
            exc,
            exc_info=True,
        )
    # Only echo the original message when the exception marks it client-safe;
    # everything else is replaced by the generic 5xx text.
    client_message = exc.message if exc.safe_for_client else GENERIC_5XX_MESSAGE
    return _build_error_response(
        exc.status_code, client_message, code=exc.code, request_id=request_id
    )
|
|
|
|
|
|
def _http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Wrap FastAPI/Starlette HTTPExceptions into the standard envelope.

    5xx sanitization policy:
    - 500 responses are sanitized (replaced with ``GENERIC_5XX_MESSAGE``) because
      they usually wrap raw internal errors and may leak sensitive info.
    - Other 5xx statuses (501, 502, 503, 504, ...) are raised explicitly by
      route code to communicate a specific, user-safe operational state
      (e.g. 503 "Page purchases are temporarily unavailable."). Those details
      are preserved so the frontend can render them, but the error is still
      logged server-side.
    """
    rid = _get_request_id(request)
    # Only plain 500s are sanitized; see docstring for the rationale.
    should_sanitize = exc.status_code == 500

    # Structured dict details (e.g. {"code": "CAPTCHA_REQUIRED", "message": "..."})
    # are preserved so the frontend can parse them.
    if isinstance(exc.detail, dict):
        err_code = exc.detail.get("code", _status_to_code(exc.status_code))
        message = exc.detail.get("message", str(exc.detail))
        # Log all 5xx (including non-sanitized ones) before any sanitization.
        if exc.status_code >= 500:
            _error_logger.error(
                "[%s] %s - HTTPException %d: %s",
                rid,
                request.url.path,
                exc.status_code,
                message,
            )
        if should_sanitize:
            message = GENERIC_5XX_MESSAGE
            err_code = "INTERNAL_ERROR"
        # Built inline (not via _build_error_response) so the legacy "detail"
        # field can carry the original dict rather than the flattened message.
        # NOTE(review): when sanitizing a 500, "detail" below still exposes the
        # original dict — confirm nothing sensitive can reach this branch.
        body = {
            "error": {
                "code": err_code,
                "message": message,
                "status": exc.status_code,
                "request_id": rid,
                "timestamp": datetime.now(UTC).isoformat(),
                "report_url": ISSUES_URL,
            },
            "detail": exc.detail,
        }
        return JSONResponse(
            status_code=exc.status_code,
            content=body,
            headers={"X-Request-ID": rid},
        )

    # Plain-string (or stringifiable) detail path.
    detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
    if exc.status_code >= 500:
        _error_logger.error(
            "[%s] %s - HTTPException %d: %s",
            rid,
            request.url.path,
            exc.status_code,
            detail,
        )
    if should_sanitize:
        detail = GENERIC_5XX_MESSAGE
    # detail doubles as a sentinel carrier (e.g. "RATE_LIMIT_EXCEEDED").
    code = _status_to_code(exc.status_code, detail)
    return _build_error_response(exc.status_code, detail, code=code, request_id=rid)
|
|
|
|
|
|
def _validation_error_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Return 422 with a flattened, field-by-field summary of validation errors."""
    rid = _get_request_id(request)
    summaries = []
    for entry in exc.errors():
        location = " -> ".join(str(segment) for segment in entry.get("loc", []))
        summaries.append(f"{location}: {entry.get('msg', 'invalid')}")
    if summaries:
        message = f"Validation failed: {'; '.join(summaries)}"
    else:
        message = "Validation failed."
    return _build_error_response(422, message, code="VALIDATION_ERROR", request_id=rid)
|
|
|
|
|
|
def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the full traceback, answer with a sanitized 500."""
    request_id = _get_request_id(request)
    # exc_info=True attaches the active traceback; exc itself is deliberately
    # not echoed to the client.
    _error_logger.error(
        "[%s] Unhandled exception on %s %s",
        request_id,
        request.method,
        request.url.path,
        exc_info=True,
    )
    return _build_error_response(
        500, GENERIC_5XX_MESSAGE, code="INTERNAL_ERROR", request_id=request_id
    )
|
|
|
|
|
|
def _status_to_code(status_code: int, detail: str = "") -> str:
    """Map an HTTP status (and optional sentinel detail) to a machine-readable code."""
    # The rate limiters raise with this exact detail string; it wins over status.
    if detail == "RATE_LIMIT_EXCEEDED":
        return "RATE_LIMIT_EXCEEDED"
    known_codes = {
        400: "BAD_REQUEST",
        401: "UNAUTHORIZED",
        403: "FORBIDDEN",
        404: "NOT_FOUND",
        405: "METHOD_NOT_ALLOWED",
        409: "CONFLICT",
        422: "VALIDATION_ERROR",
        429: "RATE_LIMIT_EXCEEDED",
    }
    try:
        return known_codes[status_code]
    except KeyError:
        # Unmapped statuses collapse into a coarse server/client bucket.
        return "INTERNAL_ERROR" if status_code >= 500 else "CLIENT_ERROR"
|
|
|
|
|
|
def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Custom 429 handler that returns JSON matching our error envelope.

    Fix: the ``Retry-After`` header must carry delta-seconds (RFC 9110
    §10.2.3), but SlowAPI's ``exc.detail`` is a human-readable limit string
    such as ``"5 per 1 minute"``. Previously the raw window phrase
    (``"1 minute"``) was sent verbatim, producing an invalid header; it is
    now parsed and converted to seconds, defaulting to 60 when unparsable.
    """
    rid = _get_request_id(request)
    retry_after_seconds = 60  # default when detail is missing or unparsable
    if exc.detail:
        # The window phrase is everything after "per", e.g. "1 minute" / "hour".
        tokens = exc.detail.split("per")[-1].strip().split()
        unit_table = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}
        count = 1
        unit = 60  # assume minutes when no recognizable unit appears
        for token in tokens:
            if token.isdigit():
                count = int(token)
            else:
                unit = unit_table.get(token.rstrip("s").lower(), unit)
        retry_after_seconds = count * unit
    return _build_error_response(
        429,
        "Too many requests. Please slow down and try again.",
        code="RATE_LIMIT_EXCEEDED",
        request_id=rid,
        extra_headers={"Retry-After": str(retry_after_seconds)},
    )
|
|
|
|
|
|
# ============================================================================
|
|
# Auth-Specific Rate Limits (Redis-backed with in-memory fallback)
|
|
# ============================================================================
|
|
# Stricter per-IP limits on auth endpoints to prevent:
|
|
# - Brute force password attacks
|
|
# - User enumeration via REGISTER_USER_ALREADY_EXISTS
|
|
# - Email spam via forgot-password
|
|
#
|
|
# Primary: Redis INCR+EXPIRE (shared across all workers).
|
|
# Fallback: In-memory sliding window (per-worker) when Redis is unavailable.
|
|
# Same Redis instance as SlowAPI / Celery.
|
|
# Lazily-created Redis client shared by the auth rate-limit helpers below.
_rate_limit_redis: redis.Redis | None = None

# In-memory fallback rate limiter (per-worker, used only when Redis is down).
# Maps "<scope>:<client_ip>" -> list of monotonic timestamps inside the window.
_memory_rate_limits: dict[str, list[float]] = defaultdict(list)
# Guards _memory_rate_limits against concurrent access from worker threads.
_memory_lock = Lock()
|
|
|
|
|
|
def _get_rate_limit_redis() -> redis.Redis:
    """Return the module-level Redis client, creating it on first use."""
    global _rate_limit_redis
    if _rate_limit_redis is not None:
        return _rate_limit_redis
    # decode_responses=True so INCR/TTL results come back as native Python values.
    _rate_limit_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _rate_limit_redis
|
|
|
|
|
|
def _check_rate_limit_memory(
    client_ip: str, max_requests: int, window_seconds: int, scope: str
):
    """
    In-memory fallback rate limiter using a sliding window.
    Used only when Redis is unavailable. Per-worker only (not shared),
    so effective limit = max_requests x num_workers.
    """
    bucket = f"{scope}:{client_ip}"
    current = time.monotonic()

    with _memory_lock:
        # Drop timestamps that have aged out of the window.
        cutoff = current - window_seconds
        recent = [stamp for stamp in _memory_rate_limits[bucket] if stamp > cutoff]

        if recent:
            _memory_rate_limits[bucket] = recent
        else:
            # Fully expired: remove the key so the dict doesn't grow unbounded.
            _memory_rate_limits.pop(bucket, None)

        if len(recent) >= max_requests:
            rate_limit_logger.warning(
                f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
                f"({len(recent)}/{max_requests} in {window_seconds}s)"
            )
            raise HTTPException(
                status_code=429,
                detail="RATE_LIMIT_EXCEEDED",
            )

        # Record this request inside the window.
        recent.append(current)
        _memory_rate_limits[bucket] = recent
|
|
|
|
|
|
def _check_rate_limit(
    request: Request, max_requests: int, window_seconds: int, scope: str
):
    """
    Check per-IP rate limit using Redis. Raises 429 if exceeded.

    Fixed-window counter: INCR the key, then arm its TTL only when the key is
    new (count == 1) or has somehow lost its TTL. The previous implementation
    ran EXPIRE on *every* request — despite its comment claiming otherwise —
    which kept extending the window, so a client issuing requests at any
    steady pace accumulated a count that never reset and was eventually
    blocked even while staying under the limit.

    Falls back to an in-memory sliding window if Redis is unavailable.
    """
    client_ip = get_real_client_ip(request)
    key = f"surfsense:auth_rate_limit:{scope}:{client_ip}"

    try:
        r = _get_rate_limit_redis()

        # INCR and TTL in one round trip; INCR creates the key at 1 if absent.
        pipe = r.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        current_count, ttl = pipe.execute()

        # Arm the window for a brand-new key, or re-arm a key orphaned without
        # a TTL (e.g. a crash between INCR and EXPIRE) so it can't live forever.
        if current_count == 1 or ttl < 0:
            r.expire(key, window_seconds)
    except (redis.exceptions.RedisError, OSError) as exc:
        # Redis unavailable — fall back to in-memory rate limiting
        rate_limit_logger.warning(
            f"Redis unavailable for rate limiting ({scope}), "
            f"falling back to in-memory limiter for {client_ip}: {exc}"
        )
        _check_rate_limit_memory(client_ip, max_requests, window_seconds, scope)
        return

    if current_count > max_requests:
        rate_limit_logger.warning(
            f"Rate limit exceeded on {scope} for IP {client_ip} "
            f"({current_count}/{max_requests} in {window_seconds}s)"
        )
        raise HTTPException(
            status_code=429,
            detail="RATE_LIMIT_EXCEEDED",
        )
|
|
|
|
|
|
def rate_limit_login(request: Request):
    """Throttle login: at most 5 attempts per 60-second window per IP."""
    _check_rate_limit(request, 5, 60, "login")
|
|
|
|
|
|
def rate_limit_register(request: Request):
    """Throttle registration: at most 3 attempts per 60-second window per IP."""
    _check_rate_limit(request, 3, 60, "register")
|
|
|
|
|
|
def rate_limit_password_reset(request: Request):
    """Throttle password reset: at most 2 attempts per 60-second window per IP."""
    _check_rate_limit(request, 2, 60, "password_reset")
|
|
|
|
|
|
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
    """Patch the running event loop to warn when a callback blocks longer than *threshold_sec*.

    Helps pinpoint synchronous code that freezes the entire FastAPI server.
    No-op unless the PERF_DEBUG env var is set, so production pays no overhead.
    """
    import os

    if not os.environ.get("PERF_DEBUG"):
        return

    slow_log = logging.getLogger("surfsense.perf.slow")
    slow_log.setLevel(logging.WARNING)
    # Install a dedicated handler exactly once per process.
    if not slow_log.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s")
        )
        slow_log.addHandler(handler)
        slow_log.propagate = False

    running_loop = asyncio.get_running_loop()
    # asyncio debug mode logs callbacks exceeding slow_callback_duration.
    running_loop.slow_callback_duration = threshold_sec  # type: ignore[attr-defined]
    running_loop.set_debug(True)
    slow_log.warning(
        "Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
        "Set PERF_DEBUG='' to disable.",
        threshold_sec,
    )
|
|
|
|
|
|
def _start_openrouter_background_refresh() -> None:
    """Kick off the periodic OpenRouter model refresh when the integration is enabled."""
    from app.services.openrouter_integration_service import OpenRouterIntegrationService

    # Guard clauses: integration off, or no settings configured → nothing to do.
    if not OpenRouterIntegrationService.is_initialized():
        return
    settings = config.OPENROUTER_INTEGRATION_SETTINGS
    if not settings:
        return
    refresh_hours = settings.get("refresh_interval_hours", 24)
    OpenRouterIntegrationService.get_instance().start_background_refresh(refresh_hours)
|
|
|
|
|
|
def _stop_openrouter_background_refresh() -> None:
    """Cancel the periodic OpenRouter refresh task on shutdown (no-op when disabled)."""
    from app.services.openrouter_integration_service import OpenRouterIntegrationService

    if not OpenRouterIntegrationService.is_initialized():
        return
    OpenRouterIntegrationService.get_instance().stop_background_refresh()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: DB/checkpointer setup, LLM routers, docs seeding; teardown on exit."""
    # Tune GC: lower gen-2 threshold so long-lived garbage is collected
    # sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
    # with minimal CPU overhead.
    gc.set_threshold(700, 10, 5)

    _enable_slow_callback_logging(threshold_sec=0.5)
    await create_db_and_tables()
    await setup_checkpointer_tables()
    # OpenRouter must be initialized before the routers that may rely on it.
    initialize_openrouter_integration()
    _start_openrouter_background_refresh()
    initialize_llm_router()
    initialize_image_gen_router()
    initialize_vision_llm_router()
    # Docs seeding is best-effort: a hung indexer must not block server startup.
    # NOTE(review): catching builtin TimeoutError assumes Python >= 3.11 where
    # asyncio.TimeoutError is an alias of it — confirm the deployment runtime.
    try:
        await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
    except TimeoutError:
        logging.getLogger(__name__).warning(
            "Surfsense docs seeding timed out after 120s — skipping. "
            "Docs will be indexed on the next restart."
        )

    log_system_snapshot("startup_complete")

    # Application serves requests while suspended here.
    yield

    # Shutdown: stop the background refresh before closing the checkpointer.
    _stop_openrouter_background_refresh()
    await close_checkpointer()
|
|
|
|
|
|
def registration_allowed():
    """Dependency that rejects the request with 403 while registration is turned off."""
    if config.REGISTRATION_ENABLED:
        return True
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
    )
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Register rate limiter and custom 429 handler
# (SlowAPIMiddleware, added below, reads the limiter from app.state.)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Register structured global exception handlers (order matters: most specific first)
app.add_exception_handler(SurfSenseError, _surfsense_error_handler)
app.add_exception_handler(RequestValidationError, _validation_error_handler)
app.add_exception_handler(HTTPException, _http_exception_handler)
# Bare Exception last: catch-all that logs the traceback and returns a sanitized 500.
app.add_exception_handler(Exception, _unhandled_exception_handler)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request-ID middleware
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a unique request ID to every request and echo it in the response."""

    async def dispatch(
        self, request: StarletteRequest, call_next: RequestResponseEndpoint
    ) -> StarletteResponse:
        # Honour an inbound X-Request-ID (e.g. set by a proxy); otherwise mint one.
        inbound = request.headers.get("X-Request-ID")
        rid = inbound if inbound is not None else f"req_{uuid.uuid4().hex[:12]}"
        request.state.request_id = rid
        response = await call_next(request)
        response.headers["X-Request-ID"] = rid
        return response
|
|
|
|
|
|
# Install the request-ID middleware so handlers and logs can correlate requests.
app.add_middleware(RequestIDMiddleware)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request-level performance middleware
|
|
# ---------------------------------------------------------------------------
|
|
# Logs wall-clock time, method, path, and status for every request so we can
|
|
# spot slow endpoints in production logs.
|
|
|
|
# Requests slower than this many milliseconds are logged at WARNING level with
# a system snapshot. Tunable via the PERF_SLOW_REQUEST_MS environment variable.
# (Previously used an inline __import__("os") hack; now uses the top-level import.)
_PERF_SLOW_REQUEST_THRESHOLD = float(os.environ.get("PERF_SLOW_REQUEST_MS", "2000"))
|
|
|
|
|
|
class RequestPerfMiddleware(BaseHTTPMiddleware):
    """Middleware that logs per-request wall-clock time.

    - ALL requests are logged at DEBUG level.
    - Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at
      WARNING level with a system snapshot so we can correlate slow responses
      with CPU/memory usage at that moment.
    """

    async def dispatch(
        self, request: StarletteRequest, call_next: RequestResponseEndpoint
    ) -> StarletteResponse:
        perf_log = get_perf_logger()
        started = time.perf_counter()
        response = await call_next(request)
        duration_ms = (time.perf_counter() - started) * 1000

        perf_log.debug(
            "[request] %s %s -> %d in %.1fms",
            request.method,
            request.url.path,
            response.status_code,
            duration_ms,
        )

        if duration_ms > _PERF_SLOW_REQUEST_THRESHOLD:
            perf_log.warning(
                "[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
                request.method,
                request.url.path,
                response.status_code,
                duration_ms,
                _PERF_SLOW_REQUEST_THRESHOLD,
            )
            # Snapshot CPU/memory at the moment of the slow response.
            log_system_snapshot("slow_request")

        return response
|
|
|
|
|
|
# Per-request timing/logging middleware (see RequestPerfMiddleware).
app.add_middleware(RequestPerfMiddleware)

# Add SlowAPI middleware for automatic rate limiting
# Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid
# corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends
# http.response.start on every body chunk, breaking SSE/streaming endpoints.
app.add_middleware(SlowAPIMiddleware)

# Add ProxyHeaders middleware FIRST to trust proxy headers (e.g., from Cloudflare)
# This ensures FastAPI uses HTTPS in redirects when behind a proxy
# NOTE(review): trusted_hosts="*" trusts X-Forwarded-* from ANY peer — safe only
# if the app is reachable exclusively through the proxy; verify the deployment.
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
|
|
|
|
# Add CORS middleware
|
|
# When using credentials, we must specify exact origins (not "*")
|
|
# Build allowed origins list from NEXT_FRONTEND_URL
|
|
# Exact-origin allow list (credentials mode forbids the "*" wildcard).
allowed_origins = []
if config.NEXT_FRONTEND_URL:
    allowed_origins.append(config.NEXT_FRONTEND_URL)
    # Also allow without trailing slash and with www/without www variants
    frontend_url = config.NEXT_FRONTEND_URL.rstrip("/")
    if frontend_url not in allowed_origins:
        allowed_origins.append(frontend_url)
    # Handle www variants
    if "://www." in frontend_url:
        non_www = frontend_url.replace("://www.", "://")
        if non_www not in allowed_origins:
            allowed_origins.append(non_www)
    elif "://" in frontend_url and "://www." not in frontend_url:
        # Add www variant
        www_url = frontend_url.replace("://", "://www.")
        if www_url not in allowed_origins:
            allowed_origins.append(www_url)

allowed_origins.extend(
    [  # For local development and desktop app
        "http://localhost:3000",
        "http://127.0.0.1:3000",
    ]
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    # Additionally accept any localhost/127.0.0.1 port (dev servers on other ports).
    allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
|
|
|
|
# fastapi-users routers; each auth-sensitive one carries a per-IP rate limit.
app.include_router(
    fastapi_users.get_auth_router(auth_backend),
    prefix="/auth/jwt",
    tags=["auth"],
    dependencies=[Depends(rate_limit_login)],
)
app.include_router(
    fastapi_users.get_register_router(UserRead, UserCreate),
    prefix="/auth",
    tags=["auth"],
    dependencies=[
        Depends(rate_limit_register),
        Depends(registration_allowed),  # blocks registration when disabled
    ],
)
app.include_router(
    fastapi_users.get_reset_password_router(),
    prefix="/auth",
    tags=["auth"],
    dependencies=[Depends(rate_limit_password_reset)],
)
app.include_router(
    fastapi_users.get_verify_router(UserRead),
    prefix="/auth",
    tags=["auth"],
)
app.include_router(
    fastapi_users.get_users_router(UserRead, UserUpdate),
    prefix="/users",
    tags=["users"],
)

# Include custom auth routes (refresh token, logout)
app.include_router(auth_router)
|
|
|
|
if config.AUTH_TYPE == "GOOGLE":
    from fastapi.responses import RedirectResponse

    from app.users import google_oauth_client

    # Determine if we're in a secure context (HTTPS) or local development (HTTP)
    # The CSRF cookie must have secure=False for HTTP (localhost development)
    is_secure_context = config.BACKEND_URL and config.BACKEND_URL.startswith("https://")

    # For cross-origin OAuth (frontend and backend on different domains):
    # - SameSite=None is required to allow cross-origin cookie setting
    # - Secure=True is required when SameSite=None
    # For same-origin or local development, use SameSite=Lax (default)
    csrf_cookie_samesite = "none" if is_secure_context else "lax"

    # Extract the domain from BACKEND_URL for cookie domain setting
    # This helps with cross-site cookie issues in Firefox/Safari
    csrf_cookie_domain = None
    if config.BACKEND_URL:
        from urllib.parse import urlparse

        parsed_url = urlparse(config.BACKEND_URL)
        csrf_cookie_domain = parsed_url.hostname

    # Mount the OAuth router: without BACKEND_URL, fastapi-users derives the
    # callback URL from the request; with it, the redirect URL and cookie
    # domain are pinned to the configured backend host.
    app.include_router(
        fastapi_users.get_oauth_router(
            google_oauth_client,
            auth_backend,
            SECRET,
            is_verified_by_default=True,
            csrf_token_cookie_secure=is_secure_context,
            csrf_token_cookie_samesite=csrf_cookie_samesite,
            csrf_token_cookie_httponly=False,  # Required for cross-site OAuth in Firefox/Safari
        )
        if not config.BACKEND_URL
        else fastapi_users.get_oauth_router(
            google_oauth_client,
            auth_backend,
            SECRET,
            is_verified_by_default=True,
            redirect_url=f"{config.BACKEND_URL}/auth/google/callback",
            csrf_token_cookie_secure=is_secure_context,
            csrf_token_cookie_samesite=csrf_cookie_samesite,
            csrf_token_cookie_httponly=False,  # Required for cross-site OAuth in Firefox/Safari
            csrf_token_cookie_domain=csrf_cookie_domain,  # Explicitly set cookie domain
        ),
        prefix="/auth/google",
        tags=["auth"],
        dependencies=[
            Depends(registration_allowed)
        ],  # blocks OAuth registration when disabled
    )

    # Add a redirect-based authorize endpoint for Firefox/Safari compatibility
    # This endpoint performs a server-side redirect instead of returning JSON
    # which fixes cross-site cookie issues where browsers don't send cookies
    # set via cross-origin fetch requests on subsequent redirects
    @app.get("/auth/google/authorize-redirect", tags=["auth"])
    async def google_authorize_redirect(
        request: Request,
    ):
        """
        Redirect-based OAuth authorization endpoint.

        Unlike the standard /auth/google/authorize endpoint that returns JSON,
        this endpoint directly redirects the browser to Google's OAuth page.
        This fixes CSRF cookie issues in Firefox and Safari where cookies set
        via cross-origin fetch requests are not sent on subsequent redirects.
        """
        import secrets

        from fastapi_users.router.oauth import generate_state_token

        # Generate CSRF token
        csrf_token = secrets.token_urlsafe(32)

        # Build state token
        state_data = {"csrftoken": csrf_token}
        state = generate_state_token(state_data, SECRET, lifetime_seconds=3600)

        # Get the callback URL
        if config.BACKEND_URL:
            redirect_url = f"{config.BACKEND_URL}/auth/google/callback"
        else:
            redirect_url = str(request.url_for("oauth:google.jwt.callback"))

        # Get authorization URL from Google
        authorization_url = await google_oauth_client.get_authorization_url(
            redirect_url,
            state,
            scope=["openid", "email", "profile"],
        )

        # Create redirect response and set CSRF cookie
        # NOTE(review): the cookie name/attributes must mirror what the
        # fastapi-users callback validates, or CSRF checks fail — confirm
        # against the installed fastapi-users version.
        response = RedirectResponse(url=authorization_url, status_code=302)
        response.set_cookie(
            key="fastapiusersoauthcsrf",
            value=csrf_token,
            max_age=3600,
            path="/",
            domain=csrf_cookie_domain,
            secure=is_secure_context,
            httponly=False,  # Required for cross-site OAuth in Firefox/Safari
            samesite=csrf_cookie_samesite,
        )

        return response
|
|
|
|
|
|
# Anonymous (no-login) chat routes — mounted at /api/v1/public/anon-chat
from app.routes.anonymous_chat_routes import ( # noqa: E402
    router as anonymous_chat_router,
)

app.include_router(anonymous_chat_router)

# Main CRUD API surface.
app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
|
|
|
|
|
@app.get("/health", tags=["health"])
@limiter.exempt
async def health_check():
    """Liveness probe: always answers OK; exempt from rate limiting."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@app.get("/verify-token")
async def authenticated_route(
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Auth smoke test: responds 200 only when the bearer token resolves to an active user."""
    # `user` and `session` are resolved solely for their dependency side effects
    # (token validation and DB session acquisition); their values are unused.
    return {"message": "Token is valid"}
|