fix(auth):enforce session auth cutover

This commit is contained in:
Anish Sarkar 2026-06-24 03:55:39 +05:30
parent fbecbb98b5
commit 62c7efb216
4 changed files with 214 additions and 87 deletions

View file

@ -3,6 +3,7 @@ import uuid
from datetime import UTC, datetime
import httpx
import jwt
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
@ -12,12 +13,12 @@ from fastapi_users.authentication import (
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from fastapi_users.jwt import generate_jwt
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.auth.session_cookies import write_session
from app.auth.session_cookies import access_expires_at, write_session
from app.config import config
from app.db import (
Prompt,
@ -37,13 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
class BearerResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str
access_expires_at: int
SECRET = config.SECRET_KEY
@ -232,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
yield UserManager(user_db)
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
async def write_token(self, user: models.UP) -> str:
data = {
"sub": str(user.id),
"aud": self.token_audience,
"iat": int(datetime.now(UTC).timestamp()),
}
return generate_jwt(
data,
self.encode_key,
self.lifetime_seconds,
algorithm=self.algorithm,
)
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(
return IatJWTStrategy(
secret=SECRET,
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
)
@ -262,48 +271,34 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
import jwt
# Decode JWT to get user_id for refresh token creation
access_expires_at = 0
try:
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
)
access_expires_at = int(payload["exp"])
user_id = uuid.UUID(payload.get("sub"))
refresh_token = await create_refresh_token(user_id)
except Exception as e:
logger.error(f"Failed to create refresh token: {e}")
# Fall back to response without refresh token
refresh_token = ""
bearer_response = BearerResponse(
access_token=token,
refresh_token=refresh_token,
token_type="bearer",
access_expires_at=access_expires_at,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create session",
) from e
if config.AUTH_TYPE == "GOOGLE":
response = RedirectResponse(
f"{config.NEXT_FRONTEND_URL}/auth/callback",
f"{config.NEXT_FRONTEND_URL}/dashboard",
status_code=302,
)
write_session(
response,
bearer_response.access_token,
bearer_response.refresh_token,
)
return response
else:
response = JSONResponse(bearer_response.model_dump())
write_session(
response,
bearer_response.access_token,
bearer_response.refresh_token,
response = JSONResponse(
{
"authenticated": True,
"access_expires_at": access_expires_at(token),
}
)
return response
write_session(response, token, refresh_token)
return response
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
@ -318,6 +313,22 @@ auth_backend = AuthenticationBackend(
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
def _token_meets_epoch(token: str) -> bool:
min_issued_at = config.MIN_ISSUED_AT
if min_issued_at <= 0:
return True
try:
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
)
except jwt.PyJWTError:
return False
issued_at = payload.get("iat")
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
async def get_auth_context(
request: Request,
session: AsyncSession = Depends(get_async_session),
@ -341,7 +352,7 @@ async def get_auth_context(
maybe_touch_last_used(pat)
return AuthContext.pat_auth(pat.user, pat)
if is_bearer:
if is_bearer and _token_meets_epoch(token):
try:
user = await get_jwt_strategy().read_token(token, user_manager)
except Exception:
@ -352,7 +363,7 @@ async def get_auth_context(
return AuthContext.session(user)
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
if cookie_token:
if cookie_token and _token_meets_epoch(cookie_token):
try:
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
except Exception: