fix(auth):enforce session auth cutover

This commit is contained in:
Anish Sarkar 2026-06-24 03:55:39 +05:30
parent fbecbb98b5
commit 62c7efb216
4 changed files with 214 additions and 87 deletions

View file

@ -2,9 +2,10 @@
import logging
from datetime import UTC, datetime
from types import SimpleNamespace
from urllib.parse import urlparse
import httpx
import jwt
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token as google_id_token
@ -12,20 +13,31 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.auth.session_cookies import clear_session, read_refresh, write_session
from app.auth.session_cookies import (
access_expires_at,
clear_session,
issue,
read_refresh,
)
from app.config import config
from app.db import User, async_session_maker, get_async_session
from app.rate_limiter import limiter
from app.schemas.auth import (
DesktopLoginRequest,
DesktopSessionRequest,
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
DesktopSessionRequest,
RefreshTokenRequest,
RefreshTokenResponse,
SessionResponse,
)
from app.users import SECRET, UserManager, get_auth_context, get_jwt_strategy, get_user_manager
from app.users import (
UserManager,
get_auth_context,
get_jwt_strategy,
get_user_manager,
)
from app.utils.refresh_tokens import (
create_refresh_token,
revoke_all_user_tokens,
@ -40,36 +52,13 @@ router = APIRouter(prefix="/auth/jwt", tags=["auth"])
session_router = APIRouter(prefix="/auth", tags=["auth"])
def _access_expires_at(access_token: str) -> int:
payload = jwt.decode(
access_token,
SECRET,
algorithms=["HS256"],
options={"verify_aud": False},
)
return int(payload["exp"])
def _request_access_token(request: Request) -> str | None:
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
if cookie_token:
return cookie_token
auth_header = request.headers.get("Authorization")
if not auth_header:
return None
scheme, _, token = auth_header.partition(" ")
if scheme.lower() == "bearer" and token:
return token
return None
async def _load_user(user_id) -> User | None:
async with async_session_maker() as session:
result = await session.execute(select(User).where(User.id == user_id))
return result.scalars().first()
@router.post("/refresh", response_model=RefreshTokenResponse)
@router.post("/refresh", response_model=None)
@limiter.limit("30/minute")
async def refresh_access_token(
request: Request,
@ -80,7 +69,7 @@ async def refresh_access_token(
Exchange a valid refresh token for a new access token and refresh token.
Implements token rotation for security.
"""
refresh_token = read_refresh(request, body)
refresh_token, mode = read_refresh(request, body)
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -101,19 +90,18 @@ async def refresh_access_token(
detail="User not found",
)
# Generate new access token
strategy = get_jwt_strategy()
access_token = await strategy.write_token(user)
if request.cookies.get(config.REFRESH_COOKIE_NAME) and rotation.refresh_token:
write_session(response, access_token, rotation.refresh_token, request)
logger.info(f"Refreshed token for user {user.id}")
return RefreshTokenResponse(
access_token=access_token,
refresh_token=rotation.refresh_token,
access_expires_at=_access_expires_at(access_token),
return issue(
response,
mode,
access=access_token,
refresh=rotation.refresh_token,
access_expires_at=access_expires_at(access_token),
request=request,
)
@ -127,7 +115,7 @@ async def revoke_token(
Logout current device by revoking the provided refresh token.
Does not require authentication - just the refresh token.
"""
refresh_token = read_refresh(request, body)
refresh_token, _mode = read_refresh(request, body)
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
clear_session(response, request)
if revoked:
@ -158,7 +146,7 @@ async def logout_all_devices(
user = None
if user is None:
refresh_token = read_refresh(request, body)
refresh_token, _mode = read_refresh(request, body)
token_record = await validate_refresh_token(refresh_token) if refresh_token else None
if token_record:
user = await _load_user(token_record.user_id)
@ -178,12 +166,55 @@ async def logout_all_devices(
@session_router.get("/session", response_model=SessionResponse)
async def get_session(
request: Request,
_auth: AuthContext = Depends(get_auth_context),
auth: AuthContext = Depends(get_auth_context),
):
access_token = _request_access_token(request)
if not access_token:
if auth.method == "pat":
return SessionResponse(access_expires_at=None)
access_token = request.cookies.get(config.SESSION_COOKIE_NAME)
if access_token is None:
auth_header = request.headers.get("Authorization")
if auth_header:
scheme, _, token = auth_header.partition(" ")
if scheme.lower() == "bearer" and token:
access_token = token
if access_token is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
return SessionResponse(access_expires_at=_access_expires_at(access_token))
return SessionResponse(access_expires_at=access_expires_at(access_token))
@session_router.post("/desktop/login", response_model=RefreshTokenResponse)
@limiter.limit("5/minute")
async def desktop_password_login(
request: Request,
body: DesktopLoginRequest,
user_manager: UserManager = Depends(get_user_manager),
):
if config.AUTH_TYPE == "GOOGLE":
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
if not config.REGISTRATION_ENABLED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Registration is disabled",
)
credentials = SimpleNamespace(username=body.email, password=body.password)
user = await user_manager.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="LOGIN_BAD_CREDENTIALS",
)
app_access_token = await get_jwt_strategy().write_token(user)
app_refresh_token = await create_refresh_token(user.id)
await user_manager.on_after_login(user, request, None)
return RefreshTokenResponse(
access_token=app_access_token,
refresh_token=app_refresh_token,
access_expires_at=access_expires_at(app_access_token),
)
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
@ -193,7 +224,17 @@ async def create_desktop_session(
body: DesktopSessionRequest,
user_manager: UserManager = Depends(get_user_manager),
):
if not body.redirect_uri.startswith("http://127.0.0.1:"):
parsed_redirect = urlparse(body.redirect_uri)
try:
redirect_port = parsed_redirect.port
except ValueError:
redirect_port = None
if not (
parsed_redirect.scheme == "http"
and parsed_redirect.hostname in {"127.0.0.1", "::1"}
and redirect_port is not None
and parsed_redirect.path == "/callback"
):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI")
if not config.GOOGLE_DESKTOP_CLIENT_ID:
raise HTTPException(
@ -238,6 +279,7 @@ async def create_desktop_session(
if not claims.get("sub") or not claims.get("email"):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
email_verified = bool(claims.get("email_verified"))
user = await user_manager.oauth_callback(
"google",
access_token,
@ -250,13 +292,13 @@ async def create_desktop_session(
),
refresh_token=token_data.get("refresh_token"),
request=request,
associate_by_email=True,
is_verified_by_default=True,
associate_by_email=email_verified,
is_verified_by_default=email_verified,
)
app_access_token = await get_jwt_strategy().write_token(user)
app_refresh_token = await create_refresh_token(user.id)
return RefreshTokenResponse(
access_token=app_access_token,
refresh_token=app_refresh_token,
access_expires_at=_access_expires_at(app_access_token),
access_expires_at=access_expires_at(app_access_token),
)

View file

@ -38,10 +38,15 @@ class LogoutAllResponse(BaseModel):
class SessionResponse(BaseModel):
authenticated: bool = True
access_expires_at: int
access_expires_at: int | None = None
class DesktopSessionRequest(BaseModel):
code: str
code_verifier: str
redirect_uri: str
class DesktopLoginRequest(BaseModel):
email: str
password: str

View file

@ -3,6 +3,7 @@ import uuid
from datetime import UTC, datetime
import httpx
import jwt
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
@ -12,12 +13,12 @@ from fastapi_users.authentication import (
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from fastapi_users.jwt import generate_jwt
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.auth.session_cookies import write_session
from app.auth.session_cookies import access_expires_at, write_session
from app.config import config
from app.db import (
Prompt,
@ -37,13 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
class BearerResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str
access_expires_at: int
SECRET = config.SECRET_KEY
@ -232,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
yield UserManager(user_db)
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
async def write_token(self, user: models.UP) -> str:
data = {
"sub": str(user.id),
"aud": self.token_audience,
"iat": int(datetime.now(UTC).timestamp()),
}
return generate_jwt(
data,
self.encode_key,
self.lifetime_seconds,
algorithm=self.algorithm,
)
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(
return IatJWTStrategy(
secret=SECRET,
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
)
@ -262,48 +271,34 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
import jwt
# Decode JWT to get user_id for refresh token creation
access_expires_at = 0
try:
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
)
access_expires_at = int(payload["exp"])
user_id = uuid.UUID(payload.get("sub"))
refresh_token = await create_refresh_token(user_id)
except Exception as e:
logger.error(f"Failed to create refresh token: {e}")
# Fall back to response without refresh token
refresh_token = ""
bearer_response = BearerResponse(
access_token=token,
refresh_token=refresh_token,
token_type="bearer",
access_expires_at=access_expires_at,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create session",
) from e
if config.AUTH_TYPE == "GOOGLE":
response = RedirectResponse(
f"{config.NEXT_FRONTEND_URL}/auth/callback",
f"{config.NEXT_FRONTEND_URL}/dashboard",
status_code=302,
)
write_session(
response,
bearer_response.access_token,
bearer_response.refresh_token,
)
return response
else:
response = JSONResponse(bearer_response.model_dump())
write_session(
response,
bearer_response.access_token,
bearer_response.refresh_token,
response = JSONResponse(
{
"authenticated": True,
"access_expires_at": access_expires_at(token),
}
)
return response
write_session(response, token, refresh_token)
return response
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
@ -318,6 +313,22 @@ auth_backend = AuthenticationBackend(
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
def _token_meets_epoch(token: str) -> bool:
min_issued_at = config.MIN_ISSUED_AT
if min_issued_at <= 0:
return True
try:
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
)
except jwt.PyJWTError:
return False
issued_at = payload.get("iat")
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
async def get_auth_context(
request: Request,
session: AsyncSession = Depends(get_async_session),
@ -341,7 +352,7 @@ async def get_auth_context(
maybe_touch_last_used(pat)
return AuthContext.pat_auth(pat.user, pat)
if is_bearer:
if is_bearer and _token_meets_epoch(token):
try:
user = await get_jwt_strategy().read_token(token, user_manager)
except Exception:
@ -352,7 +363,7 @@ async def get_auth_context(
return AuthContext.session(user)
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
if cookie_token:
if cookie_token and _token_meets_epoch(cookie_token):
try:
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
except Exception:

View file

@ -0,0 +1,69 @@
"""One-shot cutover helper to revoke every refresh token.
Run with --yes during the auth-hardening cutover, alongside setting
MIN_ISSUED_AT to the deploy epoch.
"""
from __future__ import annotations
import argparse
import asyncio
from sqlalchemy import text
from app.db import async_session_maker
async def _count_active_tokens() -> int:
async with async_session_maker() as session:
result = await session.execute(
text(
"""
SELECT count(*)
FROM refresh_tokens
WHERE revoked_at IS NULL
AND expires_at > NOW()
"""
)
)
return int(result.scalar_one())
async def _revoke_all_tokens() -> int:
async with async_session_maker() as session:
result = await session.execute(
text(
"""
UPDATE refresh_tokens
SET revoked_at = NOW(),
expires_at = NOW()
WHERE revoked_at IS NULL
OR expires_at > NOW()
"""
)
)
await session.commit()
return int(result.rowcount or 0)
async def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--yes",
action="store_true",
help="Actually revoke tokens. Without this flag the command is a dry run.",
)
args = parser.parse_args()
active_count = await _count_active_tokens()
if not args.yes:
print(f"Dry run: {active_count} active refresh token(s) would be revoked.")
print("Re-run with --yes during the auth-hardening cutover to revoke them.")
return
updated_count = await _revoke_all_tokens()
print(f"Revoked {updated_count} refresh token row(s).")
if __name__ == "__main__":
asyncio.run(main())