fix(auth):harden refresh token schema

This commit is contained in:
Anish Sarkar 2026-06-23 12:49:46 +05:30
parent d395d4dc1c
commit 5ba940f905
2 changed files with 161 additions and 43 deletions

View file

@ -0,0 +1,66 @@
"""harden refresh token schema
Revision ID: 169
Revises: 168
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "169"
down_revision: str | None = "168"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column(
"refresh_tokens",
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
)
op.add_column(
"refresh_tokens",
sa.Column("absolute_expiry", sa.TIMESTAMP(timezone=True), nullable=True),
)
op.execute(
"""
UPDATE refresh_tokens
SET revoked_at = NOW()
WHERE is_revoked = TRUE
"""
)
op.alter_column(
"refresh_tokens",
"token_hash",
existing_type=sa.String(length=256),
type_=sa.String(length=64),
existing_nullable=False,
)
op.drop_column("refresh_tokens", "is_revoked")
def downgrade() -> None:
op.add_column(
"refresh_tokens",
sa.Column("is_revoked", sa.Boolean(), nullable=False, server_default="false"),
)
op.execute(
"""
UPDATE refresh_tokens
SET is_revoked = TRUE
WHERE revoked_at IS NOT NULL
"""
)
op.alter_column("refresh_tokens", "is_revoked", server_default=None)
op.alter_column(
"refresh_tokens",
"token_hash",
existing_type=sa.String(length=64),
type_=sa.String(length=256),
existing_nullable=False,
)
op.drop_column("refresh_tokens", "absolute_expiry")
op.drop_column("refresh_tokens", "revoked_at")

View file

@ -4,6 +4,7 @@ import hashlib
import logging
import secrets
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
@ -14,6 +15,13 @@ from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RefreshRotationResult:
user_id: uuid.UUID
refresh_token: str | None
access_only: bool = False
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
@ -27,6 +35,7 @@ def hash_token(token: str) -> str:
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
absolute_expiry: datetime | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
@ -40,8 +49,12 @@ async def create_refresh_token(
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
now = datetime.now(UTC)
if absolute_expiry is None:
absolute_expiry = now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
expires_at = min(
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
absolute_expiry,
)
if family_id is None:
@ -53,6 +66,7 @@ async def create_refresh_token(
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
absolute_expiry=absolute_expiry,
)
session.add(refresh_token)
await session.commit()
@ -61,15 +75,7 @@ async def create_refresh_token(
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
"""Validate an active refresh token without rotating it."""
token_hash = hash_token(token)
async with async_session_maker() as session:
@ -81,43 +87,87 @@ async def validate_refresh_token(token: str) -> RefreshToken | None:
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
now = datetime.now(UTC)
if (
refresh_token.revoked_at is not None
or now >= refresh_token.expires_at
or (
refresh_token.absolute_expiry is not None
and now >= refresh_token.absolute_expiry
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
):
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
async def rotate_refresh_token(token: str) -> RefreshRotationResult | None:
"""Atomically rotate a refresh token with access-only grace."""
token_hash = hash_token(token)
now = datetime.now(UTC)
grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)
return await create_refresh_token(old_token.user_id, old_token.family_id)
async with async_session_maker() as session:
async with session.begin():
result = await session.execute(
select(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.with_for_update()
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
user_id = refresh_token.user_id
if refresh_token.revoked_at is not None:
if (
now - refresh_token.revoked_at <= grace_window
and now < refresh_token.expires_at
):
return RefreshRotationResult(
user_id=user_id,
refresh_token=None,
access_only=True,
)
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(revoked_at=now, expires_at=now)
)
logger.warning(f"Token reuse detected for user {user_id}")
return None
if now >= refresh_token.expires_at:
return None
family_cap = refresh_token.absolute_expiry or (
now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
)
if now >= family_cap:
return None
new_plaintext = generate_refresh_token()
child = RefreshToken(
user_id=user_id,
token_hash=hash_token(new_plaintext),
expires_at=min(
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
family_cap,
),
family_id=refresh_token.family_id,
absolute_expiry=family_cap,
)
session.add(child)
refresh_token.revoked_at = now
refresh_token.absolute_expiry = family_cap
return RefreshRotationResult(
user_id=user_id,
refresh_token=new_plaintext,
access_only=False,
)
async def revoke_refresh_token(token: str) -> bool:
@ -131,12 +181,13 @@ async def revoke_refresh_token(token: str) -> bool:
True if token was found and revoked, False otherwise
"""
token_hash = hash_token(token)
now = datetime.now(UTC)
async with async_session_maker() as session:
result = await session.execute(
update(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.values(is_revoked=True)
.values(revoked_at=now, expires_at=now)
)
await session.commit()
return result.rowcount > 0
@ -144,10 +195,11 @@ async def revoke_refresh_token(token: str) -> bool:
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
now = datetime.now(UTC)
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
.values(revoked_at=now, expires_at=now)
)
await session.commit()