mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 08:46:22 +02:00
154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
|
|
"""Utilities for managing refresh tokens."""
|
||
|
|
|
||
|
|
import hashlib
|
||
|
|
import logging
|
||
|
|
import secrets
|
||
|
|
import uuid
|
||
|
|
from datetime import UTC, datetime, timedelta
|
||
|
|
|
||
|
|
from sqlalchemy import select, update
|
||
|
|
|
||
|
|
from app.config import config
|
||
|
|
from app.db import RefreshToken, async_session_maker
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
def generate_refresh_token() -> str:
|
||
|
|
"""Generate a cryptographically secure refresh token."""
|
||
|
|
return secrets.token_urlsafe(32)
|
||
|
|
|
||
|
|
|
||
|
|
def hash_token(token: str) -> str:
|
||
|
|
"""Hash a token for secure storage."""
|
||
|
|
return hashlib.sha256(token.encode()).hexdigest()
|
||
|
|
|
||
|
|
|
||
|
|
async def create_refresh_token(
|
||
|
|
user_id: uuid.UUID,
|
||
|
|
family_id: uuid.UUID | None = None,
|
||
|
|
) -> str:
|
||
|
|
"""
|
||
|
|
Create and store a new refresh token for a user.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
user_id: The user's ID
|
||
|
|
family_id: Optional family ID for token rotation
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
The plaintext refresh token
|
||
|
|
"""
|
||
|
|
token = generate_refresh_token()
|
||
|
|
token_hash = hash_token(token)
|
||
|
|
expires_at = datetime.now(UTC) + timedelta(
|
||
|
|
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
|
||
|
|
)
|
||
|
|
|
||
|
|
if family_id is None:
|
||
|
|
family_id = uuid.uuid4()
|
||
|
|
|
||
|
|
async with async_session_maker() as session:
|
||
|
|
refresh_token = RefreshToken(
|
||
|
|
user_id=user_id,
|
||
|
|
token_hash=token_hash,
|
||
|
|
expires_at=expires_at,
|
||
|
|
family_id=family_id,
|
||
|
|
)
|
||
|
|
session.add(refresh_token)
|
||
|
|
await session.commit()
|
||
|
|
|
||
|
|
return token
|
||
|
|
|
||
|
|
|
||
|
|
async def validate_refresh_token(token: str) -> RefreshToken | None:
|
||
|
|
"""
|
||
|
|
Validate a refresh token. Handles reuse detection.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
token: The plaintext refresh token
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
RefreshToken if valid, None otherwise
|
||
|
|
"""
|
||
|
|
token_hash = hash_token(token)
|
||
|
|
|
||
|
|
async with async_session_maker() as session:
|
||
|
|
result = await session.execute(
|
||
|
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||
|
|
)
|
||
|
|
refresh_token = result.scalars().first()
|
||
|
|
|
||
|
|
if not refresh_token:
|
||
|
|
return None
|
||
|
|
|
||
|
|
# Reuse detection: revoked token used while family has active tokens
|
||
|
|
if refresh_token.is_revoked:
|
||
|
|
active = await session.execute(
|
||
|
|
select(RefreshToken).where(
|
||
|
|
RefreshToken.family_id == refresh_token.family_id,
|
||
|
|
RefreshToken.is_revoked == False, # noqa: E712
|
||
|
|
RefreshToken.expires_at > datetime.now(UTC),
|
||
|
|
)
|
||
|
|
)
|
||
|
|
if active.scalars().first():
|
||
|
|
# Revoke entire family
|
||
|
|
await session.execute(
|
||
|
|
update(RefreshToken)
|
||
|
|
.where(RefreshToken.family_id == refresh_token.family_id)
|
||
|
|
.values(is_revoked=True)
|
||
|
|
)
|
||
|
|
await session.commit()
|
||
|
|
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
|
||
|
|
return None
|
||
|
|
|
||
|
|
if refresh_token.is_expired:
|
||
|
|
return None
|
||
|
|
|
||
|
|
return refresh_token
|
||
|
|
|
||
|
|
|
||
|
|
async def rotate_refresh_token(old_token: RefreshToken) -> str:
|
||
|
|
"""Revoke old token and create new one in same family."""
|
||
|
|
async with async_session_maker() as session:
|
||
|
|
await session.execute(
|
||
|
|
update(RefreshToken)
|
||
|
|
.where(RefreshToken.id == old_token.id)
|
||
|
|
.values(is_revoked=True)
|
||
|
|
)
|
||
|
|
await session.commit()
|
||
|
|
|
||
|
|
return await create_refresh_token(old_token.user_id, old_token.family_id)
|
||
|
|
|
||
|
|
|
||
|
|
async def revoke_refresh_token(token: str) -> bool:
|
||
|
|
"""
|
||
|
|
Revoke a single refresh token by its plaintext value.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
token: The plaintext refresh token
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
True if token was found and revoked, False otherwise
|
||
|
|
"""
|
||
|
|
token_hash = hash_token(token)
|
||
|
|
|
||
|
|
async with async_session_maker() as session:
|
||
|
|
result = await session.execute(
|
||
|
|
update(RefreshToken)
|
||
|
|
.where(RefreshToken.token_hash == token_hash)
|
||
|
|
.values(is_revoked=True)
|
||
|
|
)
|
||
|
|
await session.commit()
|
||
|
|
return result.rowcount > 0
|
||
|
|
|
||
|
|
|
||
|
|
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
|
||
|
|
"""Revoke all refresh tokens for a user (logout all devices)."""
|
||
|
|
async with async_session_maker() as session:
|
||
|
|
await session.execute(
|
||
|
|
update(RefreshToken)
|
||
|
|
.where(RefreshToken.user_id == user_id)
|
||
|
|
.values(is_revoked=True)
|
||
|
|
)
|
||
|
|
await session.commit()
|