SurfSense/surfsense_backend/app/auth/csrf.py
2026-06-25 04:31:22 +05:30

61 lines
1.9 KiB
Python

"""CSRF protection for ambient cookie-authenticated requests."""
from __future__ import annotations
from urllib.parse import urlparse
from fastapi import status
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from app.config import config
# HTTP methods subject to the origin check; CsrfOriginMiddleware passes
# requests using any other method straight through without inspection.
UNSAFE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
def _origin_from_url(url: str | None) -> str | None:
if not url:
return None
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return None
return f"{parsed.scheme}://{parsed.netloc}"
def _allowed_origins() -> set[str]:
    """Build the set of origins accepted by the CSRF origin check.

    The set starts from ``config.CSRF_ALLOWED_ORIGINS`` and is extended with
    the origins derived from ``config.NEXT_FRONTEND_URL`` and
    ``config.SURFSENSE_PUBLIC_URL``; URLs that yield no origin are skipped.

    Returns:
        A new set of origin strings, rebuilt from config on every call.
    """
    allowed = set(config.CSRF_ALLOWED_ORIGINS)
    site_urls = (config.NEXT_FRONTEND_URL, config.SURFSENSE_PUBLIC_URL)
    # filter(None, ...) discards the falsy results of unusable URLs.
    allowed.update(filter(None, map(_origin_from_url, site_urls)))
    return allowed
class CsrfOriginMiddleware(BaseHTTPMiddleware):
    """Reject unsafe cookie-bearing requests whose origin is not allowed."""

    @staticmethod
    def _relies_on_session_cookie(request: Request) -> bool:
        """Tell whether *request* must pass the origin check.

        True only for an unsafe-method request that carries the session
        cookie and has no ``Authorization`` header.
        """
        if request.method not in UNSAFE_METHODS:
            return False
        # PAT/Bearer credentials are sent explicitly rather than attached by
        # the browser, so they are not CSRF-able; only the ambient session
        # cookie needs protecting.
        if request.headers.get("Authorization"):
            return False
        return config.SESSION_COOKIE_NAME in request.cookies

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Forward the request, or answer 403 when the origin check fails.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the application.

        Returns:
            The downstream response, or a 403 JSON response when the request
            must pass the origin check and its origin is not allowed.
        """
        if self._relies_on_session_cookie(request):
            # Prefer the Origin header; fall back to the Referer's origin.
            source_origin = request.headers.get("Origin")
            if not source_origin:
                source_origin = _origin_from_url(request.headers.get("Referer"))
            # A missing origin (None) is never in the set, so it is rejected.
            if source_origin not in _allowed_origins():
                return JSONResponse(
                    {"detail": "CSRF origin check failed"},
                    status_code=status.HTTP_403_FORBIDDEN,
                )
        return await call_next(request)