diff --git a/surfsense_backend/alembic/versions/169_harden_refresh_token_schema.py b/surfsense_backend/alembic/versions/169_harden_refresh_token_schema.py new file mode 100644 index 000000000..acdfafa68 --- /dev/null +++ b/surfsense_backend/alembic/versions/169_harden_refresh_token_schema.py @@ -0,0 +1,66 @@ +"""harden refresh token schema + +Revision ID: 169 +Revises: 168 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "169" +down_revision: str | None = "168" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "refresh_tokens", + sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True), + ) + op.add_column( + "refresh_tokens", + sa.Column("absolute_expiry", sa.TIMESTAMP(timezone=True), nullable=True), + ) + op.execute( + """ + UPDATE refresh_tokens + SET revoked_at = NOW() + WHERE is_revoked = TRUE + """ + ) + op.alter_column( + "refresh_tokens", + "token_hash", + existing_type=sa.String(length=256), + type_=sa.String(length=64), + existing_nullable=False, + ) + op.drop_column("refresh_tokens", "is_revoked") + + +def downgrade() -> None: + op.add_column( + "refresh_tokens", + sa.Column("is_revoked", sa.Boolean(), nullable=False, server_default="false"), + ) + op.execute( + """ + UPDATE refresh_tokens + SET is_revoked = TRUE + WHERE revoked_at IS NOT NULL + """ + ) + op.alter_column("refresh_tokens", "is_revoked", server_default=None) + op.alter_column( + "refresh_tokens", + "token_hash", + existing_type=sa.String(length=64), + type_=sa.String(length=256), + existing_nullable=False, + ) + op.drop_column("refresh_tokens", "absolute_expiry") + op.drop_column("refresh_tokens", "revoked_at") diff --git a/surfsense_backend/app/utils/refresh_tokens.py b/surfsense_backend/app/utils/refresh_tokens.py index 8c0312ba8..a1c5b658f 100644 --- a/surfsense_backend/app/utils/refresh_tokens.py +++ b/surfsense_backend/app/utils/refresh_tokens.py @@ -4,6 +4,7 @@ import hashlib import logging import secrets import uuid +from dataclasses import dataclass from datetime import UTC, datetime, timedelta from sqlalchemy import select, update @@ -14,6 +15,13 @@ from app.db import RefreshToken, async_session_maker logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class RefreshRotationResult: + user_id: uuid.UUID + refresh_token: str | None + access_only: bool = False + + def generate_refresh_token() -> str: """Generate a cryptographically secure refresh token.""" return secrets.token_urlsafe(32) @@ -27,6 +35,7 @@ def hash_token(token: str) -> str: async def create_refresh_token( user_id: uuid.UUID, family_id: uuid.UUID | None = None, + absolute_expiry: datetime | None = None, ) -> str: """ Create and store a new refresh token for a user. @@ -40,8 +49,12 @@ async def create_refresh_token( """ token = generate_refresh_token() token_hash = hash_token(token) - expires_at = datetime.now(UTC) + timedelta( - seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS + now = datetime.now(UTC) + if absolute_expiry is None: + absolute_expiry = now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS) + expires_at = min( + now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS), + absolute_expiry, ) if family_id is None: @@ -53,6 +66,7 @@ async def create_refresh_token( token_hash=token_hash, expires_at=expires_at, family_id=family_id, + absolute_expiry=absolute_expiry, ) session.add(refresh_token) await session.commit() @@ -61,15 +75,7 @@ async def create_refresh_token( async def validate_refresh_token(token: str) -> RefreshToken | None: - """ - Validate a refresh token. Handles reuse detection. - - Args: - token: The plaintext refresh token - - Returns: - RefreshToken if valid, None otherwise - """ + """Validate an active refresh token without rotating it.""" token_hash = hash_token(token) async with async_session_maker() as session: @@ -81,43 +87,87 @@ async def validate_refresh_token(token: str) -> RefreshToken | None: if not refresh_token: return None - # Reuse detection: revoked token used while family has active tokens - if refresh_token.is_revoked: - active = await session.execute( - select(RefreshToken).where( - RefreshToken.family_id == refresh_token.family_id, - RefreshToken.is_revoked == False, # noqa: E712 - RefreshToken.expires_at > datetime.now(UTC), - ) + now = datetime.now(UTC) + if ( + refresh_token.revoked_at is not None + or now >= refresh_token.expires_at + or ( + refresh_token.absolute_expiry is not None + and now >= refresh_token.absolute_expiry ) - if active.scalars().first(): - # Revoke entire family - await session.execute( - update(RefreshToken) - .where(RefreshToken.family_id == refresh_token.family_id) - .values(is_revoked=True) - ) - await session.commit() - logger.warning(f"Token reuse detected for user {refresh_token.user_id}") - return None - - if refresh_token.is_expired: + ): return None return refresh_token -async def rotate_refresh_token(old_token: RefreshToken) -> str: - """Revoke old token and create new one in same family.""" - async with async_session_maker() as session: - await session.execute( - update(RefreshToken) - .where(RefreshToken.id == old_token.id) - .values(is_revoked=True) - ) - await session.commit() +async def rotate_refresh_token(token: str) -> RefreshRotationResult | None: + """Atomically rotate a refresh token with access-only grace.""" + token_hash = hash_token(token) + now = datetime.now(UTC) + grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS) - return await create_refresh_token(old_token.user_id, old_token.family_id) + async with async_session_maker() as session: + async with session.begin(): + result = await session.execute( + select(RefreshToken) + .where(RefreshToken.token_hash == token_hash) + .with_for_update() + ) + refresh_token = result.scalars().first() + + if not refresh_token: + return None + user_id = refresh_token.user_id + + if refresh_token.revoked_at is not None: + if ( + now - refresh_token.revoked_at <= grace_window + and now < refresh_token.expires_at + ): + return RefreshRotationResult( + user_id=user_id, + refresh_token=None, + access_only=True, + ) + + await session.execute( + update(RefreshToken) + .where(RefreshToken.family_id == refresh_token.family_id) + .values(revoked_at=now, expires_at=now) + ) + logger.warning(f"Token reuse detected for user {user_id}") + return None + + if now >= refresh_token.expires_at: + return None + + family_cap = refresh_token.absolute_expiry or ( + now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS) + ) + if now >= family_cap: + return None + + new_plaintext = generate_refresh_token() + child = RefreshToken( + user_id=user_id, + token_hash=hash_token(new_plaintext), + expires_at=min( + now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS), + family_cap, + ), + family_id=refresh_token.family_id, + absolute_expiry=family_cap, + ) + session.add(child) + refresh_token.revoked_at = now + refresh_token.absolute_expiry = family_cap + + return RefreshRotationResult( + user_id=user_id, + refresh_token=new_plaintext, + access_only=False, + ) async def revoke_refresh_token(token: str) -> bool: @@ -131,12 +181,13 @@ async def revoke_refresh_token(token: str) -> bool: True if token was found and revoked, False otherwise """ token_hash = hash_token(token) + now = datetime.now(UTC) async with async_session_maker() as session: result = await session.execute( update(RefreshToken) .where(RefreshToken.token_hash == token_hash) - .values(is_revoked=True) + .values(revoked_at=now, expires_at=now) ) await session.commit() return result.rowcount > 0 @@ -144,10 +195,11 @@ async def revoke_refresh_token(token: str) -> bool: async def revoke_all_user_tokens(user_id: uuid.UUID) -> None: """Revoke all refresh tokens for a user (logout all devices).""" + now = datetime.now(UTC) async with async_session_maker() as session: await session.execute( update(RefreshToken) .where(RefreshToken.user_id == user_id) - .values(is_revoked=True) + .values(revoked_at=now, expires_at=now) ) await session.commit()