Merge pull request #784 from CREDO23/sur-137-bug-oauth-tokens-expire-too-quickly-connectors-and-login

[Fixes] Implement refresh token auth, connector token pre-validation, and logout improvements
This commit is contained in:
Rohan Verma 2026-02-05 10:49:02 -08:00 committed by GitHub
commit 459ffd2b78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 770 additions and 28 deletions

View file

@ -32,6 +32,11 @@ ELECTRIC_DB_PASSWORD=electric_password
SCHEDULE_CHECKER_INTERVAL=5m
SECRET_KEY=SECRET
# JWT Token Lifetimes (optional, defaults shown)
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
NEXT_FRONTEND_URL=http://localhost:3000
# Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS)

View file

@ -0,0 +1,92 @@
"""Add refresh_tokens table for user session management
Revision ID: 92
Revises: 91
Changes:
1. Create refresh_tokens table with columns:
- id (primary key)
- user_id (foreign key to user)
- token_hash (unique, indexed)
- expires_at (indexed)
- is_revoked
- family_id (indexed, for token rotation tracking)
- created_at, updated_at (timestamps)
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "92"
down_revision: str | None = "91"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create refresh_tokens table (idempotent)."""
# Check if table already exists
connection = op.get_bind()
result = connection.execute(
sa.text(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')"
)
)
table_exists = result.scalar()
if not table_exists:
op.create_table(
"refresh_tokens",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
sa.Column("token_hash", sa.String(256), nullable=False),
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False),
sa.Column("family_id", UUID(as_uuid=True), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
)
# Create indexes if they don't exist
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)"
)
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)"
)
def downgrade() -> None:
"""Drop refresh_tokens table (idempotent)."""
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id")
op.execute("DROP TABLE IF EXISTS refresh_tokens")

View file

@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import (
from app.config import config, initialize_llm_router
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
@ -111,6 +112,9 @@ app.include_router(
tags=["users"],
)
# Include custom auth routes (refresh token, logout)
app.include_router(auth_router)
if config.AUTH_TYPE == "GOOGLE":
from fastapi.responses import RedirectResponse

View file

@ -255,6 +255,14 @@ class Config:
# OAuth JWT
SECRET_KEY = os.getenv("SECRET_KEY")
# JWT Token Lifetimes
ACCESS_TOKEN_LIFETIME_SECONDS = int(
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
)
REFRESH_TOKEN_LIFETIME_SECONDS = int(
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
)
# ETL Service
ETL_SERVICE = os.getenv("ETL_SERVICE")

View file

@ -71,6 +71,14 @@ class AirtableHistoryConnector:
config_data = connector.config.copy()
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Airtable access token not found. "
"Please reconnect your Airtable account."
)
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
@ -98,6 +106,14 @@ class AirtableHistoryConnector:
f"Failed to decrypt Airtable credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Airtable access token is invalid or empty. "
"Please reconnect your Airtable account."
)
try:
self._credentials = AirtableAuthCredentialsBase.from_dict(config_data)
except Exception as e:

View file

@ -87,6 +87,14 @@ class ConfluenceHistoryConnector:
if is_oauth:
# OAuth 2.0 authentication
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Confluence access token not found. "
"Please reconnect your Confluence account."
)
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
@ -118,6 +126,14 @@ class ConfluenceHistoryConnector:
f"Failed to decrypt Confluence credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Confluence access token is invalid or empty. "
"Please reconnect your Confluence account."
)
try:
self._credentials = AtlassianAuthCredentialsBase.from_dict(
config_data

View file

@ -86,6 +86,14 @@ class JiraHistoryConnector:
if is_oauth:
# OAuth 2.0 authentication
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Jira access token not found. "
"Please reconnect your Jira account."
)
if not config.SECRET_KEY:
raise ValueError(
"SECRET_KEY not configured but tokens are marked as encrypted"
@ -119,6 +127,14 @@ class JiraHistoryConnector:
f"Failed to decrypt Jira credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Jira access token is invalid or empty. "
"Please reconnect your Jira account."
)
try:
self._credentials = AtlassianAuthCredentialsBase.from_dict(
config_data

View file

@ -116,6 +116,14 @@ class LinearConnector:
config_data = connector.config.copy()
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Linear access token not found. "
"Please reconnect your Linear account."
)
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
@ -143,6 +151,14 @@ class LinearConnector:
f"Failed to decrypt Linear credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Linear access token is invalid or empty. "
"Please reconnect your Linear account."
)
try:
self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
except Exception as e:

View file

@ -1361,6 +1361,13 @@ if config.AUTH_TYPE == "GOOGLE":
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
back_populates="user",
cascade="all, delete-orphan",
)
else:
class User(SQLAlchemyBaseUserTableUUID, Base):
@ -1426,6 +1433,43 @@ else:
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
back_populates="user",
cascade="all, delete-orphan",
)
class RefreshToken(Base, TimestampMixin):
"""
Stores refresh tokens for user session management.
Each row represents one device/session.
"""
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user = relationship("User", back_populates="refresh_tokens")
token_hash = Column(String(256), unique=True, nullable=False, index=True)
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
is_revoked = Column(Boolean, default=False, nullable=False)
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
@property
def is_expired(self) -> bool:
return datetime.now(UTC) >= self.expires_at
@property
def is_valid(self) -> bool:
return not self.is_expired and not self.is_revoked
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -0,0 +1,93 @@
"""Authentication routes for refresh token management."""
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from app.db import User, async_session_maker
from app.schemas.auth import (
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.users import current_active_user, get_jwt_strategy
from app.utils.refresh_tokens import (
revoke_all_user_tokens,
revoke_refresh_token,
rotate_refresh_token,
validate_refresh_token,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
@router.post("/refresh", response_model=RefreshTokenResponse)
async def refresh_access_token(request: RefreshTokenRequest):
"""
Exchange a valid refresh token for a new access token and refresh token.
Implements token rotation for security.
"""
token_record = await validate_refresh_token(request.refresh_token)
if not token_record:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
# Get user from token record
async with async_session_maker() as session:
result = await session.execute(
select(User).where(User.id == token_record.user_id)
)
user = result.scalars().first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
# Generate new access token
strategy = get_jwt_strategy()
access_token = await strategy.write_token(user)
# Rotate refresh token
new_refresh_token = await rotate_refresh_token(token_record)
logger.info(f"Refreshed token for user {user.id}")
return RefreshTokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
)
@router.post("/revoke", response_model=LogoutResponse)
async def revoke_token(request: LogoutRequest):
"""
Logout current device by revoking the provided refresh token.
Does not require authentication - just the refresh token.
"""
revoked = await revoke_refresh_token(request.refresh_token)
if revoked:
logger.info("User logged out from current device - token revoked")
else:
logger.warning("Logout called but no matching token found to revoke")
return LogoutResponse()
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(user: User = Depends(current_active_user)):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
await revoke_all_user_tokens(user.id)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()

View file

@ -1,3 +1,10 @@
from .auth import (
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from .base import IDModel, TimestampModel
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import (
@ -117,6 +124,10 @@ __all__ = [
"LogFilter",
"LogRead",
"LogUpdate",
# Auth schemas
"LogoutAllResponse",
"LogoutRequest",
"LogoutResponse",
# Search source connector schemas
"MCPConnectorCreate",
"MCPConnectorRead",
@ -146,6 +157,8 @@ __all__ = [
"PodcastCreate",
"PodcastRead",
"PodcastUpdate",
"RefreshTokenRequest",
"RefreshTokenResponse",
"RoleCreate",
"RoleRead",
"RoleUpdate",

View file

@ -0,0 +1,35 @@
"""Authentication schemas for refresh token endpoints."""
from pydantic import BaseModel
class RefreshTokenRequest(BaseModel):
"""Request body for token refresh endpoint."""
refresh_token: str
class RefreshTokenResponse(BaseModel):
"""Response from token refresh endpoint."""
access_token: str
refresh_token: str
token_type: str = "bearer"
class LogoutRequest(BaseModel):
"""Request body for logout endpoint (current device)."""
refresh_token: str
class LogoutResponse(BaseModel):
"""Response from logout endpoint (current device)."""
detail: str = "Successfully logged out"
class LogoutAllResponse(BaseModel):
"""Response from logout all devices endpoint."""
detail: str = "Successfully logged out from all devices"

View file

@ -23,17 +23,20 @@ from app.db import (
get_default_roles_config,
get_user_db,
)
from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
class BearerResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str
SECRET = config.SECRET_KEY
if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2
@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
return JWTStrategy(
secret=SECRET,
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
)
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
@ -209,9 +215,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
import jwt
# Decode JWT to get user_id for refresh token creation
try:
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
user_id = uuid.UUID(payload.get("sub"))
refresh_token = await create_refresh_token(user_id)
except Exception as e:
logger.error(f"Failed to create refresh token: {e}")
# Fall back to response without refresh token
refresh_token = ""
bearer_response = BearerResponse(
access_token=token,
refresh_token=refresh_token,
token_type="bearer",
)
if config.AUTH_TYPE == "GOOGLE":
redirect_url = (
f"{config.NEXT_FRONTEND_URL}/auth/callback"
f"?token={bearer_response.access_token}"
f"&refresh_token={bearer_response.refresh_token}"
)
return RedirectResponse(redirect_url, status_code=302)
else:
return JSONResponse(bearer_response.model_dump())

View file

@ -0,0 +1,153 @@
"""Utilities for managing refresh tokens."""
import hashlib
import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
from app.config import config
from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token for secure storage."""
return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
Args:
user_id: The user's ID
family_id: Optional family ID for token rotation
Returns:
The plaintext refresh token
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
)
if family_id is None:
family_id = uuid.uuid4()
async with async_session_maker() as session:
refresh_token = RefreshToken(
user_id=user_id,
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
)
session.add(refresh_token)
await session.commit()
return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
return await create_refresh_token(old_token.user_id, old_token.family_id)
async def revoke_refresh_token(token: str) -> bool:
"""
Revoke a single refresh token by its plaintext value.
Args:
token: The plaintext refresh token
Returns:
True if token was found and revoked, False otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
update(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.values(is_revoked=True)
)
await session.commit()
return result.rowcount > 0
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
)
await session.commit()