diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index a4da1a575..2298e7438 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1444,16 +1444,12 @@ else: class RefreshToken(Base, TimestampMixin): """ Stores refresh tokens for user session management. - - Refresh tokens are long-lived tokens (2 weeks) used to obtain new - access tokens without requiring re-authentication. + Each row represents one device/session. """ __tablename__ = "refresh_tokens" id = Column(Integer, primary_key=True, autoincrement=True) - - # User relationship user_id = Column( UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), @@ -1461,27 +1457,17 @@ class RefreshToken(Base, TimestampMixin): index=True, ) user = relationship("User", back_populates="refresh_tokens") - - # Token hash (stored hashed, not plaintext) token_hash = Column(String(256), unique=True, nullable=False, index=True) - - # Token expiration expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True) - - # Revocation flag is_revoked = Column(Boolean, default=False, nullable=False) - - # Token family for rotation tracking (detect reuse attacks) family_id = Column(UUID(as_uuid=True), nullable=False, index=True) @property def is_expired(self) -> bool: - """Check if the token has expired.""" return datetime.now(UTC) >= self.expires_at @property def is_valid(self) -> bool: - """Check if the token is valid (not expired and not revoked).""" return not self.is_expired and not self.is_revoked diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index 4be2fe525..cbffd359d 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -1,5 +1,8 @@ +import hashlib import logging +import secrets import uuid +from datetime import UTC, datetime, timedelta import httpx from fastapi import Depends, Request, Response @@ -12,9 +15,11 @@ from fastapi_users.authentication import ( ) from fastapi_users.db import SQLAlchemyUserDatabase from pydantic import BaseModel +from sqlalchemy import select, update from app.config import config from app.db import ( + RefreshToken, SearchSpace, SearchSpaceMembership, SearchSpaceRole, @@ -29,11 +34,130 @@ logger = logging.getLogger(__name__) class BearerResponse(BaseModel): access_token: str + refresh_token: str token_type: str SECRET = config.SECRET_KEY + +# Refresh token utilities (multi-session) +def generate_refresh_token() -> str: + """Generate a cryptographically secure refresh token.""" + return secrets.token_urlsafe(32) + + +def hash_token(token: str) -> str: + """Hash a token for secure storage.""" + return hashlib.sha256(token.encode()).hexdigest() + + +async def create_refresh_token( + user_id: uuid.UUID, + family_id: uuid.UUID | None = None, +) -> str: + """ + Create and store a new refresh token for a user. + + Args: + user_id: The user's ID + family_id: Optional family ID for token rotation + + Returns: + The plaintext refresh token + """ + token = generate_refresh_token() + token_hash = hash_token(token) + expires_at = datetime.now(UTC) + timedelta( + seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS + ) + + if family_id is None: + family_id = uuid.uuid4() + + async with async_session_maker() as session: + refresh_token = RefreshToken( + user_id=user_id, + token_hash=token_hash, + expires_at=expires_at, + family_id=family_id, + ) + session.add(refresh_token) + await session.commit() + + return token + + +async def validate_refresh_token(token: str) -> RefreshToken | None: + """ + Validate a refresh token. Handles reuse detection. + + Args: + token: The plaintext refresh token + + Returns: + RefreshToken if valid, None otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + refresh_token = result.scalars().first() + + if not refresh_token: + return None + + # Reuse detection: revoked token used while family has active tokens + if refresh_token.is_revoked: + active = await session.execute( + select(RefreshToken).where( + RefreshToken.family_id == refresh_token.family_id, + RefreshToken.is_revoked == False, # noqa: E712 + RefreshToken.expires_at > datetime.now(UTC), + ) + ) + if active.scalars().first(): + # Revoke entire family + await session.execute( + update(RefreshToken) + .where(RefreshToken.family_id == refresh_token.family_id) + .values(is_revoked=True) + ) + await session.commit() + logger.warning(f"Token reuse detected for user {refresh_token.user_id}") + return None + + if refresh_token.is_expired: + return None + + return refresh_token + + +async def rotate_refresh_token(old_token: RefreshToken) -> str: + """Revoke old token and create new one in same family.""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.id == old_token.id) + .values(is_revoked=True) + ) + await session.commit() + + return await create_refresh_token(old_token.user_id, old_token.family_id) + + +async def revoke_all_user_tokens(user_id: uuid.UUID) -> None: + """Revoke all refresh tokens for a user (logout all devices).""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.user_id == user_id) + .values(is_revoked=True) + ) + await session.commit() + if config.AUTH_TYPE == "GOOGLE": from httpx_oauth.clients.google import GoogleOAuth2 @@ -183,7 +307,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: - return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24) + return JWTStrategy( + secret=SECRET, + lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS, + ) # # COOKIE AUTH | Uncomment if you want to use cookie auth. @@ -209,9 +336,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: # BEARER AUTH CODE. class CustomBearerTransport(BearerTransport): async def get_login_response(self, token: str) -> Response: - bearer_response = BearerResponse(access_token=token, token_type="bearer") - redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}" + import jwt + + # Decode JWT to get user_id for refresh token creation + try: + payload = jwt.decode(token, SECRET, algorithms=["HS256"]) + user_id = uuid.UUID(payload.get("sub")) + refresh_token = await create_refresh_token(user_id) + except Exception as e: + logger.error(f"Failed to create refresh token: {e}") + # Fall back to response without refresh token + refresh_token = "" + + bearer_response = BearerResponse( + access_token=token, + refresh_token=refresh_token, + token_type="bearer", + ) + if config.AUTH_TYPE == "GOOGLE": + redirect_url = ( + f"{config.NEXT_FRONTEND_URL}/auth/callback" + f"?token={bearer_response.access_token}" + f"&refresh_token={bearer_response.refresh_token}" + ) return RedirectResponse(redirect_url, status_code=302) else: return JSONResponse(bearer_response.model_dump())