Add refresh token auth routes and utilities

This commit is contained in:
CREDO23 2026-02-05 17:29:50 +02:00
parent 9bd7d74755
commit f3a9922eb9
8 changed files with 431 additions and 125 deletions

View file

@ -0,0 +1,92 @@
"""Add refresh_tokens table for user session management
Revision ID: 92
Revises: 91
Changes:
1. Create refresh_tokens table with columns:
- id (primary key)
- user_id (foreign key to user)
- token_hash (unique, indexed)
- expires_at (indexed)
- is_revoked
- family_id (indexed, for token rotation tracking)
- created_at, updated_at (timestamps)
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "92"
down_revision: str | None = "91"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create refresh_tokens table (idempotent)."""
# Check if table already exists
connection = op.get_bind()
result = connection.execute(
sa.text(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')"
)
)
table_exists = result.scalar()
if not table_exists:
op.create_table(
"refresh_tokens",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
sa.Column("token_hash", sa.String(256), nullable=False),
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False),
sa.Column("family_id", UUID(as_uuid=True), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
)
# Create indexes if they don't exist
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)"
)
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)"
)
def downgrade() -> None:
"""Drop refresh_tokens table (idempotent)."""
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id")
op.execute("DROP TABLE IF EXISTS refresh_tokens")

View file

@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import (
from app.config import config, initialize_llm_router
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
@ -111,6 +112,9 @@ app.include_router(
tags=["users"],
)
# Include custom auth routes (refresh token, logout)
app.include_router(auth_router)
if config.AUTH_TYPE == "GOOGLE":
from fastapi.responses import RedirectResponse

View file

@ -0,0 +1,115 @@
"""Authentication routes for refresh token management."""
import logging
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
from sqlalchemy import select
from app.db import User, async_session_maker
from app.schemas.auth import LogoutAllResponse, LogoutResponse, RefreshTokenResponse
from app.users import current_active_user, get_jwt_strategy
from app.utils.auth_cookies import (
REFRESH_TOKEN_COOKIE_NAME,
delete_refresh_token_cookie,
set_refresh_token_cookie,
)
from app.utils.refresh_tokens import (
revoke_all_user_tokens,
revoke_refresh_token,
rotate_refresh_token,
validate_refresh_token,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
@router.post("/refresh", response_model=RefreshTokenResponse)
async def refresh_access_token(
response: Response,
refresh_token: str | None = Cookie(default=None, alias=REFRESH_TOKEN_COOKIE_NAME),
):
"""
Exchange a valid refresh token for a new access token and refresh token.
Reads refresh token from HTTP-only cookie. Implements token rotation for security.
"""
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token not found",
)
token_record = await validate_refresh_token(refresh_token)
if not token_record:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
# Get user from token record
async with async_session_maker() as session:
result = await session.execute(
select(User).where(User.id == token_record.user_id)
)
user = result.scalars().first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
# Generate new access token
strategy = get_jwt_strategy()
access_token = await strategy.write_token(user)
# Rotate refresh token
new_refresh_token = await rotate_refresh_token(token_record)
# Set the new refresh token in cookie
set_refresh_token_cookie(response, new_refresh_token)
logger.info(f"Refreshed token for user {user.id}")
return RefreshTokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
)
@router.post("/logout", response_model=LogoutResponse)
async def logout(
response: Response,
refresh_token: str | None = Cookie(default=None, alias=REFRESH_TOKEN_COOKIE_NAME),
):
"""
Logout current device by revoking the refresh token from cookie.
"""
if refresh_token:
await revoke_refresh_token(refresh_token)
# Always delete the cookie
delete_refresh_token_cookie(response)
logger.info("User logged out from current device")
return LogoutResponse()
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(
response: Response,
user: User = Depends(current_active_user),
):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
await revoke_all_user_tokens(user.id)
# Delete the cookie on current device
delete_refresh_token_cookie(response)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()

View file

@ -1,3 +1,4 @@
from .auth import LogoutAllResponse, LogoutResponse, RefreshTokenResponse
from .base import IDModel, TimestampModel
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import (
@ -117,6 +118,9 @@ __all__ = [
"LogFilter",
"LogRead",
"LogUpdate",
# Auth schemas
"LogoutAllResponse",
"LogoutResponse",
# Search source connector schemas
"MCPConnectorCreate",
"MCPConnectorRead",
@ -146,6 +150,7 @@ __all__ = [
"PodcastCreate",
"PodcastRead",
"PodcastUpdate",
"RefreshTokenResponse",
"RoleCreate",
"RoleRead",
"RoleUpdate",

View file

@ -0,0 +1,23 @@
"""Authentication schemas for refresh token endpoints."""
from pydantic import BaseModel
class RefreshTokenResponse(BaseModel):
"""Response from token refresh endpoint."""
access_token: str
refresh_token: str
token_type: str = "bearer"
class LogoutResponse(BaseModel):
"""Response from logout endpoint (current device)."""
detail: str = "Successfully logged out"
class LogoutAllResponse(BaseModel):
"""Response from logout all devices endpoint."""
detail: str = "Successfully logged out from all devices"

View file

@ -1,8 +1,5 @@
import hashlib
import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
import httpx
from fastapi import Depends, Request, Response
@ -15,11 +12,9 @@ from fastapi_users.authentication import (
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from sqlalchemy import select, update
from app.config import config
from app.db import (
RefreshToken,
SearchSpace,
SearchSpaceMembership,
SearchSpaceRole,
@ -28,6 +23,8 @@ from app.db import (
get_default_roles_config,
get_user_db,
)
from app.utils.auth_cookies import set_refresh_token_cookie
from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
@ -41,123 +38,6 @@ class BearerResponse(BaseModel):
SECRET = config.SECRET_KEY
# Refresh token utilities (multi-session)
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token for secure storage."""
return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
Args:
user_id: The user's ID
family_id: Optional family ID for token rotation
Returns:
The plaintext refresh token
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
)
if family_id is None:
family_id = uuid.uuid4()
async with async_session_maker() as session:
refresh_token = RefreshToken(
user_id=user_id,
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
)
session.add(refresh_token)
await session.commit()
return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
return await create_refresh_token(old_token.user_id, old_token.family_id)
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
)
await session.commit()
if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2
@ -358,11 +238,16 @@ class CustomBearerTransport(BearerTransport):
redirect_url = (
f"{config.NEXT_FRONTEND_URL}/auth/callback"
f"?token={bearer_response.access_token}"
f"&refresh_token={bearer_response.refresh_token}"
)
return RedirectResponse(redirect_url, status_code=302)
response = RedirectResponse(redirect_url, status_code=302)
else:
return JSONResponse(bearer_response.model_dump())
response = JSONResponse(bearer_response.model_dump())
# Set refresh token as HTTP-only cookie
if refresh_token:
set_refresh_token_cookie(response, refresh_token)
return response
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")

View file

@ -0,0 +1,29 @@
"""Utilities for managing authentication cookies."""
from fastapi import Response
from app.config import config
REFRESH_TOKEN_COOKIE_NAME = "refresh_token"
def set_refresh_token_cookie(response: Response, token: str) -> None:
"""Set the refresh token as an HTTP-only cookie."""
response.set_cookie(
key=REFRESH_TOKEN_COOKIE_NAME,
value=token,
max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS,
httponly=True,
secure=True, # Only send over HTTPS
samesite="lax",
)
def delete_refresh_token_cookie(response: Response) -> None:
"""Delete the refresh token cookie."""
response.delete_cookie(
key=REFRESH_TOKEN_COOKIE_NAME,
httponly=True,
secure=True,
samesite="lax",
)

View file

@ -0,0 +1,153 @@
"""Utilities for managing refresh tokens."""
import hashlib
import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
from app.config import config
from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token for secure storage."""
return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
Args:
user_id: The user's ID
family_id: Optional family ID for token rotation
Returns:
The plaintext refresh token
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
)
if family_id is None:
family_id = uuid.uuid4()
async with async_session_maker() as session:
refresh_token = RefreshToken(
user_id=user_id,
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
)
session.add(refresh_token)
await session.commit()
return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
return await create_refresh_token(old_token.user_id, old_token.family_id)
async def revoke_refresh_token(token: str) -> bool:
"""
Revoke a single refresh token by its plaintext value.
Args:
token: The plaintext refresh token
Returns:
True if token was found and revoked, False otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
update(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.values(is_revoked=True)
)
await session.commit()
return result.rowcount > 0
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
)
await session.commit()