# Source: SurfSense/surfsense_backend/app/utils/refresh_tokens.py (206 lines, 6.1 KiB, Python)

"""Utilities for managing refresh tokens."""
import hashlib
import logging
import secrets
import uuid
# (blame timestamp from the source viewer: 2026-06-23 12:49:46 +05:30)
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
from app.config import config
from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
# (blame timestamp from the source viewer: 2026-06-23 12:49:46 +05:30)
@dataclass(frozen=True)
class RefreshRotationResult:
user_id: uuid.UUID
refresh_token: str | None
access_only: bool = False
def generate_refresh_token() -> str:
    """Generate a cryptographically secure refresh token.

    Returns:
        A URL-safe text token produced by ``secrets.token_urlsafe`` from
        32 random bytes.
    """
    return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
    """Hash a token for secure storage.

    Args:
        token: The plaintext token.

    Returns:
        The SHA-256 digest of the encoded token as a lowercase hex string.
    """
    return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
    user_id: uuid.UUID,
    family_id: uuid.UUID | None = None,
    absolute_expiry: datetime | None = None,
) -> str:
    """
    Create and store a new refresh token for a user.

    Only the hash of the token is persisted; the plaintext is returned to
    the caller and is not stored.

    Args:
        user_id: The user's ID
        family_id: Optional family ID for token rotation; a new random
            family ID is generated when omitted
        absolute_expiry: Optional hard upper bound on the token's expiry;
            defaults to now + ``config.REFRESH_ABSOLUTE_LIFETIME_SECONDS``

    Returns:
        The plaintext refresh token
    """
    token = generate_refresh_token()
    token_hash = hash_token(token)

    now = datetime.now(UTC)
    if absolute_expiry is None:
        absolute_expiry = now + timedelta(
            seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS
        )
    # The per-token lifetime is clamped so it never exceeds the absolute cap.
    expires_at = min(
        now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
        absolute_expiry,
    )

    if family_id is None:
        family_id = uuid.uuid4()

    async with async_session_maker() as session:
        refresh_token = RefreshToken(
            user_id=user_id,
            token_hash=token_hash,
            expires_at=expires_at,
            family_id=family_id,
            absolute_expiry=absolute_expiry,
        )
        session.add(refresh_token)
        await session.commit()

    return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
    """Validate an active refresh token without rotating it.

    Args:
        token: The plaintext refresh token.

    Returns:
        The matching ``RefreshToken`` row when it exists, is not revoked,
        and has passed neither ``expires_at`` nor (when set)
        ``absolute_expiry``; otherwise ``None``.
    """
    token_hash = hash_token(token)

    async with async_session_maker() as session:
        result = await session.execute(
            select(RefreshToken).where(RefreshToken.token_hash == token_hash)
        )
        refresh_token = result.scalars().first()

    if not refresh_token:
        return None

    now = datetime.now(UTC)
    if (
        refresh_token.revoked_at is not None
        or now >= refresh_token.expires_at
        or (
            # Rows may carry no absolute cap; only enforce it when present.
            refresh_token.absolute_expiry is not None
            and now >= refresh_token.absolute_expiry
        )
    ):
        return None

    return refresh_token
async def rotate_refresh_token(token: str) -> RefreshRotationResult | None:
    """Atomically rotate a refresh token with access-only grace.

    The lookup, the revocation of the presented token and the insert of
    its replacement all run inside one ``session.begin()`` transaction.

    Args:
        token: The plaintext refresh token presented by the client.

    Returns:
        * A result carrying a new plaintext refresh token when the
          presented token is active and was rotated.
        * A result with ``refresh_token=None`` and ``access_only=True``
          when the token was already revoked, but within
          ``config.REFRESH_ROTATION_GRACE_SECONDS`` of its revocation and
          before its ``expires_at``.
        * ``None`` when the token is unknown, expired, past its absolute
          cap, or reused outside the grace window (in which case every
          token in its family is revoked).
    """
    token_hash = hash_token(token)
    now = datetime.now(UTC)
    grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)

    async with async_session_maker() as session:
        async with session.begin():
            # SELECT ... FOR UPDATE on the token row (where the backend
            # supports it).
            result = await session.execute(
                select(RefreshToken)
                .where(RefreshToken.token_hash == token_hash)
                .with_for_update()
            )
            refresh_token = result.scalars().first()
            if not refresh_token:
                return None

            user_id = refresh_token.user_id

            if refresh_token.revoked_at is not None:
                # Already revoked: allow an access-only result if the
                # revocation is recent and the token's own expiry has not
                # passed.
                if (
                    now - refresh_token.revoked_at <= grace_window
                    and now < refresh_token.expires_at
                ):
                    return RefreshRotationResult(
                        user_id=user_id,
                        refresh_token=None,
                        access_only=True,
                    )

                # Reuse outside the grace window: revoke and expire the
                # whole token family.
                await session.execute(
                    update(RefreshToken)
                    .where(RefreshToken.family_id == refresh_token.family_id)
                    .values(revoked_at=now, expires_at=now)
                )
                logger.warning("Token reuse detected for user %s", user_id)
                return None

            if now >= refresh_token.expires_at:
                return None

            # Rows without an absolute cap get one starting from now.
            family_cap = refresh_token.absolute_expiry or (
                now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
            )
            if now >= family_cap:
                return None

            new_plaintext = generate_refresh_token()
            child = RefreshToken(
                user_id=user_id,
                token_hash=hash_token(new_plaintext),
                # The child's lifetime never exceeds the family cap.
                expires_at=min(
                    now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
                    family_cap,
                ),
                family_id=refresh_token.family_id,
                absolute_expiry=family_cap,
            )
            session.add(child)

            # Mark the presented token as rotated; its expires_at is left
            # untouched, which is what the grace check above relies on.
            refresh_token.revoked_at = now
            refresh_token.absolute_expiry = family_cap

            return RefreshRotationResult(
                user_id=user_id,
                refresh_token=new_plaintext,
                access_only=False,
            )
async def revoke_refresh_token(token: str) -> bool:
    """
    Revoke a single refresh token by its plaintext value.

    Args:
        token: The plaintext refresh token

    Returns:
        True if at least one row matched the token's hash and was updated,
        False otherwise
    """
    token_hash = hash_token(token)
    now = datetime.now(UTC)

    async with async_session_maker() as session:
        # Both revoked_at and expires_at are set to the same instant.
        result = await session.execute(
            update(RefreshToken)
            .where(RefreshToken.token_hash == token_hash)
            .values(revoked_at=now, expires_at=now)
        )
        await session.commit()
        return result.rowcount > 0
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
    """Revoke all refresh tokens for a user (logout all devices).

    Args:
        user_id: ID of the user whose refresh tokens are revoked.
    """
    now = datetime.now(UTC)

    async with async_session_maker() as session:
        # Every row for the user gets revoked_at and expires_at set to now.
        await session.execute(
            update(RefreshToken)
            .where(RefreshToken.user_id == user_id)
            .values(revoked_at=now, expires_at=now)
        )
        await session.commit()