diff --git a/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py b/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py new file mode 100644 index 000000000..c7e133ae9 --- /dev/null +++ b/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py @@ -0,0 +1,92 @@ +"""Add refresh_tokens table for user session management + +Revision ID: 92 +Revises: 91 + +Changes: +1. Create refresh_tokens table with columns: + - id (primary key) + - user_id (foreign key to user) + - token_hash (unique, indexed) + - expires_at (indexed) + - is_revoked + - family_id (indexed, for token rotation tracking) + - created_at, updated_at (timestamps) +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "92" +down_revision: str | None = "91" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Create refresh_tokens table (idempotent).""" + # Check if table already exists + connection = op.get_bind() + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')" + ) + ) + table_exists = result.scalar() + + if not table_exists: + op.create_table( + "refresh_tokens", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("user_id", UUID(as_uuid=True), nullable=False), + sa.Column("token_hash", sa.String(256), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False), + sa.Column("family_id", UUID(as_uuid=True), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ondelete="CASCADE", + ), + ) + + # Create indexes if they don't exist + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)" + ) + op.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)" + ) + + +def downgrade() -> None: + """Drop refresh_tokens table (idempotent).""" + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id") + op.execute("DROP TABLE IF EXISTS refresh_tokens") diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 01dd0da3d..63da4e8ad 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import ( from app.config import config, initialize_llm_router from app.db import User, create_db_and_tables, get_async_session from app.routes import router as crud_router +from app.routes.auth_routes import router as auth_router from app.schemas import UserCreate, UserRead, UserUpdate from app.tasks.surfsense_docs_indexer import seed_surfsense_docs from app.users import SECRET, auth_backend, current_active_user, fastapi_users @@ -111,6 +112,9 @@ app.include_router( tags=["users"], ) +# Include custom auth routes (refresh token, logout) +app.include_router(auth_router) + if config.AUTH_TYPE == "GOOGLE": from fastapi.responses import RedirectResponse diff --git a/surfsense_backend/app/routes/auth_routes.py b/surfsense_backend/app/routes/auth_routes.py new file mode 100644 index 000000000..67abc5482 --- /dev/null +++ b/surfsense_backend/app/routes/auth_routes.py @@ -0,0 +1,115 @@ +"""Authentication routes for refresh token management.""" + +import logging + +from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status +from sqlalchemy import select + +from app.db import User, async_session_maker +from app.schemas.auth import LogoutAllResponse, LogoutResponse, RefreshTokenResponse +from app.users import current_active_user, get_jwt_strategy +from app.utils.auth_cookies import ( + REFRESH_TOKEN_COOKIE_NAME, + delete_refresh_token_cookie, + set_refresh_token_cookie, +) +from app.utils.refresh_tokens import ( + revoke_all_user_tokens, + revoke_refresh_token, + rotate_refresh_token, + validate_refresh_token, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth/jwt", tags=["auth"]) + + +@router.post("/refresh", response_model=RefreshTokenResponse) +async def refresh_access_token( + response: Response, + refresh_token: str | None = Cookie(default=None, alias=REFRESH_TOKEN_COOKIE_NAME), +): + """ + Exchange a valid refresh token for a new access token and refresh token. + Reads refresh token from HTTP-only cookie. Implements token rotation for security. + """ + if not refresh_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token not found", + ) + + token_record = await validate_refresh_token(refresh_token) + + if not token_record: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + ) + + # Get user from token record + async with async_session_maker() as session: + result = await session.execute( + select(User).where(User.id == token_record.user_id) + ) + user = result.scalars().first() + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + ) + + # Generate new access token + strategy = get_jwt_strategy() + access_token = await strategy.write_token(user) + + # Rotate refresh token + new_refresh_token = await rotate_refresh_token(token_record) + + # Set the new refresh token in cookie + set_refresh_token_cookie(response, new_refresh_token) + + logger.info(f"Refreshed token for user {user.id}") + + return RefreshTokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + ) + + +@router.post("/logout", response_model=LogoutResponse) +async def logout( + response: Response, + refresh_token: str | None = Cookie(default=None, alias=REFRESH_TOKEN_COOKIE_NAME), +): + """ + Logout current device by revoking the refresh token from cookie. + """ + if refresh_token: + await revoke_refresh_token(refresh_token) + + # Always delete the cookie + delete_refresh_token_cookie(response) + + logger.info("User logged out from current device") + return LogoutResponse() + + +@router.post("/logout-all", response_model=LogoutAllResponse) +async def logout_all_devices( + response: Response, + user: User = Depends(current_active_user), +): + """ + Logout from all devices by revoking all refresh tokens for the user. + Requires valid access token. + """ + await revoke_all_user_tokens(user.id) + + # Delete the cookie on current device + delete_refresh_token_cookie(response) + + logger.info(f"User {user.id} logged out from all devices") + return LogoutAllResponse() diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 6c9577c46..45dba2ba4 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -1,3 +1,4 @@ +from .auth import LogoutAllResponse, LogoutResponse, RefreshTokenResponse from .base import IDModel, TimestampModel from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate from .documents import ( @@ -117,6 +118,9 @@ __all__ = [ "LogFilter", "LogRead", "LogUpdate", + # Auth schemas + "LogoutAllResponse", + "LogoutResponse", # Search source connector schemas "MCPConnectorCreate", "MCPConnectorRead", @@ -146,6 +150,7 @@ __all__ = [ "PodcastCreate", "PodcastRead", "PodcastUpdate", + "RefreshTokenResponse", "RoleCreate", "RoleRead", "RoleUpdate", diff --git a/surfsense_backend/app/schemas/auth.py b/surfsense_backend/app/schemas/auth.py new file mode 100644 index 000000000..77c61de7e --- /dev/null +++ b/surfsense_backend/app/schemas/auth.py @@ -0,0 +1,23 @@ +"""Authentication schemas for refresh token endpoints.""" + +from pydantic import BaseModel + + +class RefreshTokenResponse(BaseModel): + """Response from token refresh endpoint.""" + + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class LogoutResponse(BaseModel): + """Response from logout endpoint (current device).""" + + detail: str = "Successfully logged out" + + +class LogoutAllResponse(BaseModel): + """Response from logout all devices endpoint.""" + + detail: str = "Successfully logged out from all devices" diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index cbffd359d..ffb6c89e8 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -1,8 +1,5 @@ -import hashlib import logging -import secrets import uuid -from datetime import UTC, datetime, timedelta import httpx from fastapi import Depends, Request, Response @@ -15,11 +12,9 @@ from fastapi_users.authentication import ( ) from fastapi_users.db import SQLAlchemyUserDatabase from pydantic import BaseModel -from sqlalchemy import select, update from app.config import config from app.db import ( - RefreshToken, SearchSpace, SearchSpaceMembership, SearchSpaceRole, @@ -28,6 +23,8 @@ from app.db import ( get_default_roles_config, get_user_db, ) +from app.utils.auth_cookies import set_refresh_token_cookie +from app.utils.refresh_tokens import create_refresh_token logger = logging.getLogger(__name__) @@ -41,123 +38,6 @@ class BearerResponse(BaseModel): SECRET = config.SECRET_KEY -# Refresh token utilities (multi-session) -def generate_refresh_token() -> str: - """Generate a cryptographically secure refresh token.""" - return secrets.token_urlsafe(32) - - -def hash_token(token: str) -> str: - """Hash a token for secure storage.""" - return hashlib.sha256(token.encode()).hexdigest() - - -async def create_refresh_token( - user_id: uuid.UUID, - family_id: uuid.UUID | None = None, -) -> str: - """ - Create and store a new refresh token for a user. - - Args: - user_id: The user's ID - family_id: Optional family ID for token rotation - - Returns: - The plaintext refresh token - """ - token = generate_refresh_token() - token_hash = hash_token(token) - expires_at = datetime.now(UTC) + timedelta( - seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS - ) - - if family_id is None: - family_id = uuid.uuid4() - - async with async_session_maker() as session: - refresh_token = RefreshToken( - user_id=user_id, - token_hash=token_hash, - expires_at=expires_at, - family_id=family_id, - ) - session.add(refresh_token) - await session.commit() - - return token - - -async def validate_refresh_token(token: str) -> RefreshToken | None: - """ - Validate a refresh token. Handles reuse detection. - - Args: - token: The plaintext refresh token - - Returns: - RefreshToken if valid, None otherwise - """ - token_hash = hash_token(token) - - async with async_session_maker() as session: - result = await session.execute( - select(RefreshToken).where(RefreshToken.token_hash == token_hash) - ) - refresh_token = result.scalars().first() - - if not refresh_token: - return None - - # Reuse detection: revoked token used while family has active tokens - if refresh_token.is_revoked: - active = await session.execute( - select(RefreshToken).where( - RefreshToken.family_id == refresh_token.family_id, - RefreshToken.is_revoked == False, # noqa: E712 - RefreshToken.expires_at > datetime.now(UTC), - ) - ) - if active.scalars().first(): - # Revoke entire family - await session.execute( - update(RefreshToken) - .where(RefreshToken.family_id == refresh_token.family_id) - .values(is_revoked=True) - ) - await session.commit() - logger.warning(f"Token reuse detected for user {refresh_token.user_id}") - return None - - if refresh_token.is_expired: - return None - - return refresh_token - - -async def rotate_refresh_token(old_token: RefreshToken) -> str: - """Revoke old token and create new one in same family.""" - async with async_session_maker() as session: - await session.execute( - update(RefreshToken) - .where(RefreshToken.id == old_token.id) - .values(is_revoked=True) - ) - await session.commit() - - return await create_refresh_token(old_token.user_id, old_token.family_id) - - -async def revoke_all_user_tokens(user_id: uuid.UUID) -> None: - """Revoke all refresh tokens for a user (logout all devices).""" - async with async_session_maker() as session: - await session.execute( - update(RefreshToken) - .where(RefreshToken.user_id == user_id) - .values(is_revoked=True) - ) - await session.commit() - if config.AUTH_TYPE == "GOOGLE": from httpx_oauth.clients.google import GoogleOAuth2 @@ -358,11 +238,16 @@ class CustomBearerTransport(BearerTransport): redirect_url = ( f"{config.NEXT_FRONTEND_URL}/auth/callback" f"?token={bearer_response.access_token}" - f"&refresh_token={bearer_response.refresh_token}" ) - return RedirectResponse(redirect_url, status_code=302) + response = RedirectResponse(redirect_url, status_code=302) else: - return JSONResponse(bearer_response.model_dump()) + response = JSONResponse(bearer_response.model_dump()) + + # Set refresh token as HTTP-only cookie + if refresh_token: + set_refresh_token_cookie(response, refresh_token) + + return response bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login") diff --git a/surfsense_backend/app/utils/auth_cookies.py b/surfsense_backend/app/utils/auth_cookies.py new file mode 100644 index 000000000..52da80f9d --- /dev/null +++ b/surfsense_backend/app/utils/auth_cookies.py @@ -0,0 +1,29 @@ +"""Utilities for managing authentication cookies.""" + +from fastapi import Response + +from app.config import config + +REFRESH_TOKEN_COOKIE_NAME = "refresh_token" + + +def set_refresh_token_cookie(response: Response, token: str) -> None: + """Set the refresh token as an HTTP-only cookie.""" + response.set_cookie( + key=REFRESH_TOKEN_COOKIE_NAME, + value=token, + max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS, + httponly=True, + secure=True, # Only send over HTTPS + samesite="lax", + ) + + +def delete_refresh_token_cookie(response: Response) -> None: + """Delete the refresh token cookie.""" + response.delete_cookie( + key=REFRESH_TOKEN_COOKIE_NAME, + httponly=True, + secure=True, + samesite="lax", + ) diff --git a/surfsense_backend/app/utils/refresh_tokens.py b/surfsense_backend/app/utils/refresh_tokens.py new file mode 100644 index 000000000..8c0312ba8 --- /dev/null +++ b/surfsense_backend/app/utils/refresh_tokens.py @@ -0,0 +1,153 @@ +"""Utilities for managing refresh tokens.""" + +import hashlib +import logging +import secrets +import uuid +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select, update + +from app.config import config +from app.db import RefreshToken, async_session_maker + +logger = logging.getLogger(__name__) + + +def generate_refresh_token() -> str: + """Generate a cryptographically secure refresh token.""" + return secrets.token_urlsafe(32) + + +def hash_token(token: str) -> str: + """Hash a token for secure storage.""" + return hashlib.sha256(token.encode()).hexdigest() + + +async def create_refresh_token( + user_id: uuid.UUID, + family_id: uuid.UUID | None = None, +) -> str: + """ + Create and store a new refresh token for a user. + + Args: + user_id: The user's ID + family_id: Optional family ID for token rotation + + Returns: + The plaintext refresh token + """ + token = generate_refresh_token() + token_hash = hash_token(token) + expires_at = datetime.now(UTC) + timedelta( + seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS + ) + + if family_id is None: + family_id = uuid.uuid4() + + async with async_session_maker() as session: + refresh_token = RefreshToken( + user_id=user_id, + token_hash=token_hash, + expires_at=expires_at, + family_id=family_id, + ) + session.add(refresh_token) + await session.commit() + + return token + + +async def validate_refresh_token(token: str) -> RefreshToken | None: + """ + Validate a refresh token. Handles reuse detection. + + Args: + token: The plaintext refresh token + + Returns: + RefreshToken if valid, None otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + refresh_token = result.scalars().first() + + if not refresh_token: + return None + + # Reuse detection: revoked token used while family has active tokens + if refresh_token.is_revoked: + active = await session.execute( + select(RefreshToken).where( + RefreshToken.family_id == refresh_token.family_id, + RefreshToken.is_revoked == False, # noqa: E712 + RefreshToken.expires_at > datetime.now(UTC), + ) + ) + if active.scalars().first(): + # Revoke entire family + await session.execute( + update(RefreshToken) + .where(RefreshToken.family_id == refresh_token.family_id) + .values(is_revoked=True) + ) + await session.commit() + logger.warning(f"Token reuse detected for user {refresh_token.user_id}") + return None + + if refresh_token.is_expired: + return None + + return refresh_token + + +async def rotate_refresh_token(old_token: RefreshToken) -> str: + """Revoke old token and create new one in same family.""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.id == old_token.id) + .values(is_revoked=True) + ) + await session.commit() + + return await create_refresh_token(old_token.user_id, old_token.family_id) + + +async def revoke_refresh_token(token: str) -> bool: + """ + Revoke a single refresh token by its plaintext value. + + Args: + token: The plaintext refresh token + + Returns: + True if token was found and revoked, False otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + update(RefreshToken) + .where(RefreshToken.token_hash == token_hash) + .values(is_revoked=True) + ) + await session.commit() + return result.rowcount > 0 + + +async def revoke_all_user_tokens(user_id: uuid.UUID) -> None: + """Revoke all refresh tokens for a user (logout all devices).""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.user_id == user_id) + .values(is_revoked=True) + ) + await session.commit()