mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Add RefreshToken model and multi-session refresh token logic
This commit is contained in:
parent
048ef7024f
commit
9bd7d74755
2 changed files with 152 additions and 18 deletions
|
|
@ -1444,16 +1444,12 @@ else:
|
|||
class RefreshToken(Base, TimestampMixin):
|
||||
"""
|
||||
Stores refresh tokens for user session management.
|
||||
|
||||
Refresh tokens are long-lived tokens (2 weeks) used to obtain new
|
||||
access tokens without requiring re-authentication.
|
||||
Each row represents one device/session.
|
||||
"""
|
||||
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
# User relationship
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
|
|
@ -1461,27 +1457,17 @@ class RefreshToken(Base, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
user = relationship("User", back_populates="refresh_tokens")
|
||||
|
||||
# Token hash (stored hashed, not plaintext)
|
||||
token_hash = Column(String(256), unique=True, nullable=False, index=True)
|
||||
|
||||
# Token expiration
|
||||
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Revocation flag
|
||||
is_revoked = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Token family for rotation tracking (detect reuse attacks)
|
||||
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the token has expired."""
|
||||
return datetime.now(UTC) >= self.expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_expired and not self.is_revoked
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, Request, Response
|
||||
|
|
@ -12,9 +15,11 @@ from fastapi_users.authentication import (
|
|||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
RefreshToken,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
SearchSpaceRole,
|
||||
|
|
@ -29,11 +34,130 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class BearerResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
||||
|
||||
# Refresh token utilities (multi-session)
|
||||
def generate_refresh_token() -> str:
|
||||
"""Generate a cryptographically secure refresh token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def create_refresh_token(
|
||||
user_id: uuid.UUID,
|
||||
family_id: uuid.UUID | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create and store a new refresh token for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
family_id: Optional family ID for token rotation
|
||||
|
||||
Returns:
|
||||
The plaintext refresh token
|
||||
"""
|
||||
token = generate_refresh_token()
|
||||
token_hash = hash_token(token)
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
|
||||
)
|
||||
|
||||
if family_id is None:
|
||||
family_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
expires_at=expires_at,
|
||||
family_id=family_id,
|
||||
)
|
||||
session.add(refresh_token)
|
||||
await session.commit()
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def validate_refresh_token(token: str) -> RefreshToken | None:
|
||||
"""
|
||||
Validate a refresh token. Handles reuse detection.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token
|
||||
|
||||
Returns:
|
||||
RefreshToken if valid, None otherwise
|
||||
"""
|
||||
token_hash = hash_token(token)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
refresh_token = result.scalars().first()
|
||||
|
||||
if not refresh_token:
|
||||
return None
|
||||
|
||||
# Reuse detection: revoked token used while family has active tokens
|
||||
if refresh_token.is_revoked:
|
||||
active = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.family_id == refresh_token.family_id,
|
||||
RefreshToken.is_revoked == False, # noqa: E712
|
||||
RefreshToken.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
if active.scalars().first():
|
||||
# Revoke entire family
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
|
||||
return None
|
||||
|
||||
if refresh_token.is_expired:
|
||||
return None
|
||||
|
||||
return refresh_token
|
||||
|
||||
|
||||
async def rotate_refresh_token(old_token: RefreshToken) -> str:
|
||||
"""Revoke old token and create new one in same family."""
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.id == old_token.id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return await create_refresh_token(old_token.user_id, old_token.family_id)
|
||||
|
||||
|
||||
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
|
||||
"""Revoke all refresh tokens for a user (logout all devices)."""
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.user_id == user_id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
|
|
@ -183,7 +307,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
|||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
|
||||
return JWTStrategy(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
|
||||
|
|
@ -209,9 +336,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
|||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
bearer_response = BearerResponse(access_token=token, token_type="bearer")
|
||||
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
|
||||
import jwt
|
||||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET, algorithms=["HS256"])
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create refresh token: {e}")
|
||||
# Fall back to response without refresh token
|
||||
refresh_token = ""
|
||||
|
||||
bearer_response = BearerResponse(
|
||||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
redirect_url = (
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback"
|
||||
f"?token={bearer_response.access_token}"
|
||||
f"&refresh_token={bearer_response.refresh_token}"
|
||||
)
|
||||
return RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
return JSONResponse(bearer_response.model_dump())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue