mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
fix(auth):add csrf session cookie helpers
This commit is contained in:
parent
b05f30e2f9
commit
d395d4dc1c
2 changed files with 147 additions and 0 deletions
58
surfsense_backend/app/auth/csrf.py
Normal file
58
surfsense_backend/app/auth/csrf.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""CSRF protection for ambient cookie-authenticated requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from app.config import config
|
||||
|
||||
UNSAFE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
|
||||
|
||||
|
||||
def _origin_from_url(url: str | None) -> str | None:
|
||||
if not url:
|
||||
return None
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return None
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
|
||||
def _allowed_origins() -> set[str]:
|
||||
origins = set(config.CSRF_ALLOWED_ORIGINS)
|
||||
for url in (config.NEXT_FRONTEND_URL, config.SURFSENSE_PUBLIC_URL):
|
||||
origin = _origin_from_url(url)
|
||||
if origin:
|
||||
origins.add(origin)
|
||||
return origins
|
||||
|
||||
|
||||
class CsrfOriginMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: RequestResponseEndpoint,
|
||||
) -> Response:
|
||||
if request.method not in UNSAFE_METHODS:
|
||||
return await call_next(request)
|
||||
|
||||
# PAT/Bearer credentials are not ambient browser credentials and are not
|
||||
# CSRF-able. Enforce only when the web session cookie is the credential.
|
||||
if request.headers.get("Authorization") or config.SESSION_COOKIE_NAME not in request.cookies:
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("Origin") or _origin_from_url(
|
||||
request.headers.get("Referer")
|
||||
)
|
||||
if origin not in _allowed_origins():
|
||||
return JSONResponse(
|
||||
{"detail": "CSRF origin check failed"},
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
89
surfsense_backend/app/auth/session_cookies.py
Normal file
89
surfsense_backend/app/auth/session_cookies.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Centralized session-cookie I/O for web authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from app.config import config
|
||||
|
||||
|
||||
def _cookie_secure(request: Request | None = None) -> bool:
|
||||
policy = config.SESSION_COOKIE_SECURE_POLICY
|
||||
if policy == "always":
|
||||
return True
|
||||
if policy == "never":
|
||||
return False
|
||||
if request is not None:
|
||||
proto = request.headers.get("x-forwarded-proto")
|
||||
if proto:
|
||||
return proto.split(",", 1)[0].strip().lower() == "https"
|
||||
return request.url.scheme == "https"
|
||||
return bool(config.BACKEND_URL and config.BACKEND_URL.startswith("https://"))
|
||||
|
||||
|
||||
def _set_persistent_cookie(
|
||||
response: Response,
|
||||
*,
|
||||
key: str,
|
||||
value: str,
|
||||
max_age: int,
|
||||
request: Request | None,
|
||||
) -> None:
|
||||
expires = datetime.now(UTC) + timedelta(seconds=max_age)
|
||||
response.set_cookie(
|
||||
key=key,
|
||||
value=value,
|
||||
max_age=max_age,
|
||||
expires=expires,
|
||||
httponly=True,
|
||||
secure=_cookie_secure(request),
|
||||
samesite=config.SESSION_COOKIE_SAMESITE,
|
||||
domain=config.COOKIE_DOMAIN,
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def write_session(
|
||||
response: Response,
|
||||
access: str,
|
||||
refresh: str,
|
||||
request: Request | None = None,
|
||||
) -> None:
|
||||
_set_persistent_cookie(
|
||||
response,
|
||||
key=config.SESSION_COOKIE_NAME,
|
||||
value=access,
|
||||
max_age=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
request=request,
|
||||
)
|
||||
_set_persistent_cookie(
|
||||
response,
|
||||
key=config.REFRESH_COOKIE_NAME,
|
||||
value=refresh,
|
||||
max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS,
|
||||
request=request,
|
||||
)
|
||||
|
||||
|
||||
def clear_session(response: Response, request: Request | None = None) -> None:
|
||||
for key in (config.SESSION_COOKIE_NAME, config.REFRESH_COOKIE_NAME):
|
||||
response.delete_cookie(
|
||||
key=key,
|
||||
path="/",
|
||||
domain=config.COOKIE_DOMAIN,
|
||||
secure=_cookie_secure(request),
|
||||
samesite=config.SESSION_COOKIE_SAMESITE,
|
||||
httponly=True,
|
||||
)
|
||||
|
||||
|
||||
def read_refresh(request: Request, body: Any | None = None) -> str | None:
|
||||
cookie = request.cookies.get(config.REFRESH_COOKIE_NAME)
|
||||
if cookie:
|
||||
return cookie
|
||||
if body is None:
|
||||
return None
|
||||
return getattr(body, "refresh_token", None)
|
||||
Loading…
Add table
Add a link
Reference in a new issue