mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
"""CSRF protection for ambient cookie-authenticated requests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
from fastapi import status
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, Response
|
|
|
|
from app.config import config
|
|
|
|
# HTTP methods treated as unsafe: only requests using one of these are
# subject to the origin check in the middleware's dispatch().
UNSAFE_METHODS = {
    "POST",
    "PUT",
    "PATCH",
    "DELETE",
}
|
|
|
|
|
|
def _origin_from_url(url: str | None) -> str | None:
|
|
if not url:
|
|
return None
|
|
parsed = urlparse(url)
|
|
if not parsed.scheme or not parsed.netloc:
|
|
return None
|
|
return f"{parsed.scheme}://{parsed.netloc}"
|
|
|
|
|
|
def _allowed_origins() -> set[str]:
    """Build the set of origins accepted by the CSRF origin check.

    Starts from ``config.CSRF_ALLOWED_ORIGINS`` and adds the origins derived
    from ``config.NEXT_FRONTEND_URL`` and ``config.SURFSENSE_PUBLIC_URL``
    whenever those URLs yield one. The set is rebuilt on every call.
    """
    allowed = set(config.CSRF_ALLOWED_ORIGINS)
    site_urls = (config.NEXT_FRONTEND_URL, config.SURFSENSE_PUBLIC_URL)
    # filter(None, ...) drops URLs for which no origin could be derived.
    allowed.update(filter(None, map(_origin_from_url, site_urls)))
    return allowed
|
|
|
|
|
|
class CsrfOriginMiddleware(BaseHTTPMiddleware):
    """Origin check for unsafe requests authenticated by the session cookie."""

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Forward the request, or answer 403 when the origin check fails.

        The check applies only when all of the following hold: the method is
        in ``UNSAFE_METHODS``, no ``Authorization`` header value is present,
        and the session cookie is present. The request's origin is taken from
        the ``Origin`` header, falling back to the origin of the ``Referer``
        header, and must be one of ``_allowed_origins()``.
        """
        # PAT/Bearer credentials are not ambient browser credentials and are
        # not CSRF-able, so the check is enforced only when the web session
        # cookie is the credential in play.
        cookie_authenticated_write = (
            request.method in UNSAFE_METHODS
            and not request.headers.get("Authorization")
            and config.SESSION_COOKIE_NAME in request.cookies
        )

        if cookie_authenticated_write:
            claimed_origin = request.headers.get("Origin")
            if not claimed_origin:
                claimed_origin = _origin_from_url(request.headers.get("Referer"))

            # A request carrying neither header yields None, which is never
            # in the allowed set and is therefore rejected as well.
            if claimed_origin not in _allowed_origins():
                return JSONResponse(
                    {"detail": "CSRF origin check failed"},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

        return await call_next(request)
|