mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
fix(auth):harden refresh token schema
This commit is contained in:
parent
d395d4dc1c
commit
5ba940f905
2 changed files with 161 additions and 43 deletions
|
|
@ -0,0 +1,66 @@
|
|||
"""harden refresh token schema
|
||||
|
||||
Revision ID: 169
|
||||
Revises: 168
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "169"
|
||||
down_revision: str | None = "168"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"refresh_tokens",
|
||||
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"refresh_tokens",
|
||||
sa.Column("absolute_expiry", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = NOW()
|
||||
WHERE is_revoked = TRUE
|
||||
"""
|
||||
)
|
||||
op.alter_column(
|
||||
"refresh_tokens",
|
||||
"token_hash",
|
||||
existing_type=sa.String(length=256),
|
||||
type_=sa.String(length=64),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.drop_column("refresh_tokens", "is_revoked")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"refresh_tokens",
|
||||
sa.Column("is_revoked", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET is_revoked = TRUE
|
||||
WHERE revoked_at IS NOT NULL
|
||||
"""
|
||||
)
|
||||
op.alter_column("refresh_tokens", "is_revoked", server_default=None)
|
||||
op.alter_column(
|
||||
"refresh_tokens",
|
||||
"token_hash",
|
||||
existing_type=sa.String(length=64),
|
||||
type_=sa.String(length=256),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.drop_column("refresh_tokens", "absolute_expiry")
|
||||
op.drop_column("refresh_tokens", "revoked_at")
|
||||
|
|
@ -4,6 +4,7 @@ import hashlib
|
|||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, update
|
||||
|
|
@ -14,6 +15,13 @@ from app.db import RefreshToken, async_session_maker
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RefreshRotationResult:
|
||||
user_id: uuid.UUID
|
||||
refresh_token: str | None
|
||||
access_only: bool = False
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""Generate a cryptographically secure refresh token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
|
@ -27,6 +35,7 @@ def hash_token(token: str) -> str:
|
|||
async def create_refresh_token(
|
||||
user_id: uuid.UUID,
|
||||
family_id: uuid.UUID | None = None,
|
||||
absolute_expiry: datetime | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create and store a new refresh token for a user.
|
||||
|
|
@ -40,8 +49,12 @@ async def create_refresh_token(
|
|||
"""
|
||||
token = generate_refresh_token()
|
||||
token_hash = hash_token(token)
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
|
||||
now = datetime.now(UTC)
|
||||
if absolute_expiry is None:
|
||||
absolute_expiry = now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
|
||||
expires_at = min(
|
||||
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
|
||||
absolute_expiry,
|
||||
)
|
||||
|
||||
if family_id is None:
|
||||
|
|
@ -53,6 +66,7 @@ async def create_refresh_token(
|
|||
token_hash=token_hash,
|
||||
expires_at=expires_at,
|
||||
family_id=family_id,
|
||||
absolute_expiry=absolute_expiry,
|
||||
)
|
||||
session.add(refresh_token)
|
||||
await session.commit()
|
||||
|
|
@ -61,15 +75,7 @@ async def create_refresh_token(
|
|||
|
||||
|
||||
async def validate_refresh_token(token: str) -> RefreshToken | None:
|
||||
"""
|
||||
Validate a refresh token. Handles reuse detection.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token
|
||||
|
||||
Returns:
|
||||
RefreshToken if valid, None otherwise
|
||||
"""
|
||||
"""Validate an active refresh token without rotating it."""
|
||||
token_hash = hash_token(token)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
|
|
@ -81,43 +87,87 @@ async def validate_refresh_token(token: str) -> RefreshToken | None:
|
|||
if not refresh_token:
|
||||
return None
|
||||
|
||||
# Reuse detection: revoked token used while family has active tokens
|
||||
if refresh_token.is_revoked:
|
||||
active = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.family_id == refresh_token.family_id,
|
||||
RefreshToken.is_revoked == False, # noqa: E712
|
||||
RefreshToken.expires_at > datetime.now(UTC),
|
||||
)
|
||||
now = datetime.now(UTC)
|
||||
if (
|
||||
refresh_token.revoked_at is not None
|
||||
or now >= refresh_token.expires_at
|
||||
or (
|
||||
refresh_token.absolute_expiry is not None
|
||||
and now >= refresh_token.absolute_expiry
|
||||
)
|
||||
if active.scalars().first():
|
||||
# Revoke entire family
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
|
||||
return None
|
||||
|
||||
if refresh_token.is_expired:
|
||||
):
|
||||
return None
|
||||
|
||||
return refresh_token
|
||||
|
||||
|
||||
async def rotate_refresh_token(old_token: RefreshToken) -> str:
|
||||
"""Revoke old token and create new one in same family."""
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.id == old_token.id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
async def rotate_refresh_token(token: str) -> RefreshRotationResult | None:
|
||||
"""Atomically rotate a refresh token with access-only grace."""
|
||||
token_hash = hash_token(token)
|
||||
now = datetime.now(UTC)
|
||||
grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)
|
||||
|
||||
return await create_refresh_token(old_token.user_id, old_token.family_id)
|
||||
async with async_session_maker() as session:
|
||||
async with session.begin():
|
||||
result = await session.execute(
|
||||
select(RefreshToken)
|
||||
.where(RefreshToken.token_hash == token_hash)
|
||||
.with_for_update()
|
||||
)
|
||||
refresh_token = result.scalars().first()
|
||||
|
||||
if not refresh_token:
|
||||
return None
|
||||
user_id = refresh_token.user_id
|
||||
|
||||
if refresh_token.revoked_at is not None:
|
||||
if (
|
||||
now - refresh_token.revoked_at <= grace_window
|
||||
and now < refresh_token.expires_at
|
||||
):
|
||||
return RefreshRotationResult(
|
||||
user_id=user_id,
|
||||
refresh_token=None,
|
||||
access_only=True,
|
||||
)
|
||||
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||
.values(revoked_at=now, expires_at=now)
|
||||
)
|
||||
logger.warning(f"Token reuse detected for user {user_id}")
|
||||
return None
|
||||
|
||||
if now >= refresh_token.expires_at:
|
||||
return None
|
||||
|
||||
family_cap = refresh_token.absolute_expiry or (
|
||||
now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
|
||||
)
|
||||
if now >= family_cap:
|
||||
return None
|
||||
|
||||
new_plaintext = generate_refresh_token()
|
||||
child = RefreshToken(
|
||||
user_id=user_id,
|
||||
token_hash=hash_token(new_plaintext),
|
||||
expires_at=min(
|
||||
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
|
||||
family_cap,
|
||||
),
|
||||
family_id=refresh_token.family_id,
|
||||
absolute_expiry=family_cap,
|
||||
)
|
||||
session.add(child)
|
||||
refresh_token.revoked_at = now
|
||||
refresh_token.absolute_expiry = family_cap
|
||||
|
||||
return RefreshRotationResult(
|
||||
user_id=user_id,
|
||||
refresh_token=new_plaintext,
|
||||
access_only=False,
|
||||
)
|
||||
|
||||
|
||||
async def revoke_refresh_token(token: str) -> bool:
|
||||
|
|
@ -131,12 +181,13 @@ async def revoke_refresh_token(token: str) -> bool:
|
|||
True if token was found and revoked, False otherwise
|
||||
"""
|
||||
token_hash = hash_token(token)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.token_hash == token_hash)
|
||||
.values(is_revoked=True)
|
||||
.values(revoked_at=now, expires_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
|
@ -144,10 +195,11 @@ async def revoke_refresh_token(token: str) -> bool:
|
|||
|
||||
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
|
||||
"""Revoke all refresh tokens for a user (logout all devices)."""
|
||||
now = datetime.now(UTC)
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.user_id == user_id)
|
||||
.values(is_revoked=True)
|
||||
.values(revoked_at=now, expires_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue