Add refresh token auth routes and utilities

This commit is contained in:
CREDO23 2026-02-05 17:29:50 +02:00
parent 9bd7d74755
commit f3a9922eb9
8 changed files with 431 additions and 125 deletions

View file

@ -1,8 +1,5 @@
import hashlib
import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
import httpx
from fastapi import Depends, Request, Response
@ -15,11 +12,9 @@ from fastapi_users.authentication import (
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from sqlalchemy import select, update
from app.config import config
from app.db import (
RefreshToken,
SearchSpace,
SearchSpaceMembership,
SearchSpaceRole,
@ -28,6 +23,8 @@ from app.db import (
get_default_roles_config,
get_user_db,
)
from app.utils.auth_cookies import set_refresh_token_cookie
from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
@ -41,123 +38,6 @@ class BearerResponse(BaseModel):
SECRET = config.SECRET_KEY
# Refresh token utilities (multi-session)
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token for secure storage."""
return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
Args:
user_id: The user's ID
family_id: Optional family ID for token rotation
Returns:
The plaintext refresh token
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
)
if family_id is None:
family_id = uuid.uuid4()
async with async_session_maker() as session:
refresh_token = RefreshToken(
user_id=user_id,
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
)
session.add(refresh_token)
await session.commit()
return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
return await create_refresh_token(old_token.user_id, old_token.family_id)
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
)
await session.commit()
if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2
@ -358,11 +238,16 @@ class CustomBearerTransport(BearerTransport):
redirect_url = (
f"{config.NEXT_FRONTEND_URL}/auth/callback"
f"?token={bearer_response.access_token}"
f"&refresh_token={bearer_response.refresh_token}"
)
return RedirectResponse(redirect_url, status_code=302)
response = RedirectResponse(redirect_url, status_code=302)
else:
return JSONResponse(bearer_response.model_dump())
response = JSONResponse(bearer_response.model_dump())
# Set refresh token as HTTP-only cookie
if refresh_token:
set_refresh_token_cookie(response, refresh_token)
return response
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")