fix(auth):harden session cookie transport

This commit is contained in:
Anish Sarkar 2026-06-24 03:55:39 +05:30
parent 9b127a8533
commit fbecbb98b5
3 changed files with 53 additions and 13 deletions

View file

@ -807,6 +807,7 @@ allowed_origins.extend(
] ]
) )
app.add_middleware(CsrfOriginMiddleware)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=allowed_origins, allow_origins=allowed_origins,
@ -821,7 +822,6 @@ app.add_middleware(
# FRONTEND_URL to BACKEND_URL. # FRONTEND_URL to BACKEND_URL.
max_age=86400, max_age=86400,
) )
app.add_middleware(CsrfOriginMiddleware)
# Password / email-based auth routers are only mounted when not running in # Password / email-based auth routers are only mounted when not running in
# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left # Google-OAuth-only mode. Mounting them in OAuth-only prod previously left

View file

@ -3,13 +3,20 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from enum import Enum
from typing import Any from typing import Any
import jwt
from fastapi import Request, Response from fastapi import Request, Response
from app.config import config from app.config import config
class TransportMode(Enum):
COOKIE = "cookie"
HEADER = "header"
def _cookie_secure(request: Request | None = None) -> bool: def _cookie_secure(request: Request | None = None) -> bool:
policy = config.SESSION_COOKIE_SECURE_POLICY policy = config.SESSION_COOKIE_SECURE_POLICY
if policy == "always": if policy == "always":
@ -49,7 +56,7 @@ def _set_persistent_cookie(
def write_session( def write_session(
response: Response, response: Response,
access: str, access: str,
refresh: str, refresh: str | None = None,
request: Request | None = None, request: Request | None = None,
) -> None: ) -> None:
_set_persistent_cookie( _set_persistent_cookie(
@ -59,13 +66,14 @@ def write_session(
max_age=config.ACCESS_TOKEN_LIFETIME_SECONDS, max_age=config.ACCESS_TOKEN_LIFETIME_SECONDS,
request=request, request=request,
) )
_set_persistent_cookie( if refresh is not None:
response, _set_persistent_cookie(
key=config.REFRESH_COOKIE_NAME, response,
value=refresh, key=config.REFRESH_COOKIE_NAME,
max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS, value=refresh,
request=request, max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS,
) request=request,
)
def clear_session(response: Response, request: Request | None = None) -> None: def clear_session(response: Response, request: Request | None = None) -> None:
@ -80,10 +88,41 @@ def clear_session(response: Response, request: Request | None = None) -> None:
) )
def read_refresh(request: Request, body: Any | None = None) -> str | None: def read_refresh(request: Request, body: Any | None = None) -> tuple[str | None, TransportMode]:
cookie = request.cookies.get(config.REFRESH_COOKIE_NAME) cookie = request.cookies.get(config.REFRESH_COOKIE_NAME)
if cookie: if cookie:
return cookie return cookie, TransportMode.COOKIE
if body is None: if body is None:
return None return None, TransportMode.HEADER
return getattr(body, "refresh_token", None) return getattr(body, "refresh_token", None), TransportMode.HEADER
def access_expires_at(access_token: str) -> int:
payload = jwt.decode(
access_token,
config.SECRET_KEY,
algorithms=["HS256"],
options={"verify_aud": False},
)
return int(payload["exp"])
def issue(
response: Response,
mode: TransportMode,
*,
access: str,
refresh: str | None,
access_expires_at: int,
request: Request | None = None,
) -> dict:
if mode is TransportMode.COOKIE:
write_session(response, access, refresh, request)
return {"authenticated": True, "access_expires_at": access_expires_at}
return {
"access_token": access,
"refresh_token": refresh,
"token_type": "bearer",
"access_expires_at": access_expires_at,
}

View file

@ -918,6 +918,7 @@ class Config:
ACCESS_TOKEN_LIFETIME_SECONDS = int( ACCESS_TOKEN_LIFETIME_SECONDS = int(
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(30 * 60)) # 30 minutes os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(30 * 60)) # 30 minutes
) )
MIN_ISSUED_AT = int(os.getenv("MIN_ISSUED_AT", "0"))
REFRESH_TOKEN_LIFETIME_SECONDS = int( REFRESH_TOKEN_LIFETIME_SECONDS = int(
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
) )