SurfSense/surfsense_backend/app/utils/refresh_tokens.py
2026-06-23 12:49:46 +05:30

205 lines
6.1 KiB
Python

"""Utilities for managing refresh tokens."""
import hashlib
import logging
import secrets
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
from app.config import config
from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RefreshRotationResult:
user_id: uuid.UUID
refresh_token: str | None
access_only: bool = False
def generate_refresh_token() -> str:
    """Return a fresh, URL-safe random refresh token.

    The token is drawn from the ``secrets`` module using 32 random bytes.
    """
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
def hash_token(token: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``token``.

    The token is encoded with the default ``str.encode()`` encoding (UTF-8)
    before hashing.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode())
    return hasher.hexdigest()
async def create_refresh_token(
    user_id: uuid.UUID,
    family_id: uuid.UUID | None = None,
    absolute_expiry: datetime | None = None,
) -> str:
    """
    Issue a refresh token for a user and persist its hash.

    Only the hash of the token is written to the database; the plaintext is
    handed back to the caller.

    Args:
        user_id: The user's ID
        family_id: Family the token belongs to; a new random family ID is
            generated when omitted
        absolute_expiry: Upper bound for the token's expiry; when omitted it
            is set to now + ``config.REFRESH_ABSOLUTE_LIFETIME_SECONDS``

    Returns:
        The plaintext refresh token
    """
    plaintext = generate_refresh_token()
    digest = hash_token(plaintext)
    issued_at = datetime.now(UTC)

    if absolute_expiry is None:
        absolute_expiry = issued_at + timedelta(
            seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS
        )

    # The per-token lifetime may never outlive the absolute cap.
    sliding_expiry = issued_at + timedelta(
        seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
    )
    expires_at = min(sliding_expiry, absolute_expiry)

    family_id = uuid.uuid4() if family_id is None else family_id

    async with async_session_maker() as session:
        session.add(
            RefreshToken(
                user_id=user_id,
                token_hash=digest,
                expires_at=expires_at,
                family_id=family_id,
                absolute_expiry=absolute_expiry,
            )
        )
        await session.commit()

    return plaintext
async def validate_refresh_token(token: str) -> RefreshToken | None:
    """Look up a refresh token and return it only if it is still usable.

    The token is not modified or rotated.

    Args:
        token: The plaintext refresh token

    Returns:
        The stored ``RefreshToken`` row, or ``None`` when no row matches the
        token's hash, the row is revoked, or the current time has reached
        its ``expires_at`` or (when set) its ``absolute_expiry``.
    """
    digest = hash_token(token)

    async with async_session_maker() as session:
        lookup = select(RefreshToken).where(RefreshToken.token_hash == digest)
        record = (await session.execute(lookup)).scalars().first()
        if not record:
            return None

        checked_at = datetime.now(UTC)

        if record.revoked_at is not None:
            return None
        if checked_at >= record.expires_at:
            return None

        hard_cap = record.absolute_expiry
        if hard_cap is not None and checked_at >= hard_cap:
            return None

        return record
async def rotate_refresh_token(token: str) -> RefreshRotationResult | None:
    """Exchange a refresh token for a new one inside a single transaction.

    The row matching the token's hash is selected ``FOR UPDATE`` and all
    changes happen inside ``session.begin()``.

    Outcomes:
        * No row matches the token: ``None``.
        * The row is already revoked, was revoked no longer ago than
          ``config.REFRESH_ROTATION_GRACE_SECONDS`` and has not reached its
          ``expires_at``: an access-only result without a new refresh token.
        * The row is already revoked otherwise: every token of the same
          family gets ``revoked_at`` and ``expires_at`` set to now, a
          warning is logged, and ``None`` is returned.
        * The row has reached ``expires_at`` or the family cap: ``None``.
        * Otherwise: a child token is added to the same family, the
          presented row is marked revoked, and the child's plaintext is
          returned.

    Args:
        token: The plaintext refresh token

    Returns:
        A ``RefreshRotationResult`` on success, ``None`` otherwise.
    """
    presented_hash = hash_token(token)
    now = datetime.now(UTC)
    grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)

    async with async_session_maker() as session, session.begin():
        locked_lookup = (
            select(RefreshToken)
            .where(RefreshToken.token_hash == presented_hash)
            .with_for_update()
        )
        parent = (await session.execute(locked_lookup)).scalars().first()
        if not parent:
            return None

        user_id = parent.user_id

        if parent.revoked_at is not None:
            revoked_for = now - parent.revoked_at
            if revoked_for <= grace_window and now < parent.expires_at:
                # Recently rotated token presented again: no new refresh token.
                return RefreshRotationResult(
                    user_id=user_id,
                    refresh_token=None,
                    access_only=True,
                )

            # Outside the grace window: revoke the whole family.
            revoke_family = (
                update(RefreshToken)
                .where(RefreshToken.family_id == parent.family_id)
                .values(revoked_at=now, expires_at=now)
            )
            await session.execute(revoke_family)
            logger.warning(f"Token reuse detected for user {user_id}")
            return None

        if now >= parent.expires_at:
            return None

        # Rows without a stored absolute expiry get one computed from now.
        family_cap = parent.absolute_expiry
        if not family_cap:
            family_cap = now + timedelta(
                seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS
            )
        if now >= family_cap:
            return None

        new_plaintext = generate_refresh_token()
        sliding_expiry = now + timedelta(
            seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
        )
        child = RefreshToken(
            user_id=user_id,
            token_hash=hash_token(new_plaintext),
            expires_at=min(sliding_expiry, family_cap),
            family_id=parent.family_id,
            absolute_expiry=family_cap,
        )
        session.add(child)

        # Retire the presented token and persist the family cap on it.
        parent.revoked_at = now
        parent.absolute_expiry = family_cap

        return RefreshRotationResult(
            user_id=user_id,
            refresh_token=new_plaintext,
            access_only=False,
        )
async def revoke_refresh_token(token: str) -> bool:
    """
    Mark the refresh token matching a plaintext value as revoked.

    Both ``revoked_at`` and ``expires_at`` of the matching row are set to
    the current time.

    Args:
        token: The plaintext refresh token

    Returns:
        True if at least one row was updated, False otherwise
    """
    digest = hash_token(token)
    revoked_when = datetime.now(UTC)

    async with async_session_maker() as session:
        revoke_stmt = (
            update(RefreshToken)
            .where(RefreshToken.token_hash == digest)
            .values(revoked_at=revoked_when, expires_at=revoked_when)
        )
        outcome = await session.execute(revoke_stmt)
        await session.commit()
        return outcome.rowcount > 0
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
    """Revoke every refresh token owned by a user (logout all devices).

    Sets ``revoked_at`` and ``expires_at`` to the current time on all rows
    whose ``user_id`` matches, then commits.

    Args:
        user_id: The user's ID
    """
    revoked_when = datetime.now(UTC)

    async with async_session_maker() as session:
        revoke_stmt = (
            update(RefreshToken)
            .where(RefreshToken.user_id == user_id)
            .values(revoked_at=revoked_when, expires_at=revoked_when)
        )
        await session.execute(revoke_stmt)
        await session.commit()