Merge pull request #784 from CREDO23/sur-137-bug-oauth-tokens-expire-too-quickly-connectors-and-login

[Fixes] Implement refresh token auth, connector token pre-validation, and logout improvements
This commit is contained in:
Rohan Verma 2026-02-05 10:49:02 -08:00 committed by GitHub
commit 459ffd2b78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 770 additions and 28 deletions

View file

@ -32,6 +32,11 @@ ELECTRIC_DB_PASSWORD=electric_password
SCHEDULE_CHECKER_INTERVAL=5m
SECRET_KEY=SECRET
# JWT Token Lifetimes (optional, defaults shown)
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
NEXT_FRONTEND_URL=http://localhost:3000
# Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS)

View file

@ -0,0 +1,92 @@
"""Add refresh_tokens table for user session management
Revision ID: 92
Revises: 91
Changes:
1. Create refresh_tokens table with columns:
- id (primary key)
- user_id (foreign key to user)
- token_hash (unique, indexed)
- expires_at (indexed)
- is_revoked
- family_id (indexed, for token rotation tracking)
- created_at, updated_at (timestamps)
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "92"
down_revision: str | None = "91"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create refresh_tokens table (idempotent)."""
# Check if table already exists
connection = op.get_bind()
result = connection.execute(
sa.text(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')"
)
)
table_exists = result.scalar()
if not table_exists:
op.create_table(
"refresh_tokens",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
sa.Column("token_hash", sa.String(256), nullable=False),
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False),
sa.Column("family_id", UUID(as_uuid=True), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
)
# Create indexes if they don't exist
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)"
)
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)"
)
def downgrade() -> None:
"""Drop refresh_tokens table (idempotent)."""
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id")
op.execute("DROP TABLE IF EXISTS refresh_tokens")

View file

@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import (
from app.config import config, initialize_llm_router
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
@ -111,6 +112,9 @@ app.include_router(
tags=["users"],
)
# Include custom auth routes (refresh token, logout)
app.include_router(auth_router)
if config.AUTH_TYPE == "GOOGLE":
from fastapi.responses import RedirectResponse

View file

@ -255,6 +255,14 @@ class Config:
# OAuth JWT
SECRET_KEY = os.getenv("SECRET_KEY")
# JWT Token Lifetimes
ACCESS_TOKEN_LIFETIME_SECONDS = int(
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
)
REFRESH_TOKEN_LIFETIME_SECONDS = int(
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
)
# ETL Service
ETL_SERVICE = os.getenv("ETL_SERVICE")

View file

@ -71,6 +71,14 @@ class AirtableHistoryConnector:
config_data = connector.config.copy()
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Airtable access token not found. "
"Please reconnect your Airtable account."
)
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
@ -98,6 +106,14 @@ class AirtableHistoryConnector:
f"Failed to decrypt Airtable credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Airtable access token is invalid or empty. "
"Please reconnect your Airtable account."
)
try:
self._credentials = AirtableAuthCredentialsBase.from_dict(config_data)
except Exception as e:

View file

@ -87,6 +87,14 @@ class ConfluenceHistoryConnector:
if is_oauth:
# OAuth 2.0 authentication
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Confluence access token not found. "
"Please reconnect your Confluence account."
)
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
@ -118,6 +126,14 @@ class ConfluenceHistoryConnector:
f"Failed to decrypt Confluence credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Confluence access token is invalid or empty. "
"Please reconnect your Confluence account."
)
try:
self._credentials = AtlassianAuthCredentialsBase.from_dict(
config_data

View file

@ -86,6 +86,14 @@ class JiraHistoryConnector:
if is_oauth:
# OAuth 2.0 authentication
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Jira access token not found. "
"Please reconnect your Jira account."
)
if not config.SECRET_KEY:
raise ValueError(
"SECRET_KEY not configured but tokens are marked as encrypted"
@ -119,6 +127,14 @@ class JiraHistoryConnector:
f"Failed to decrypt Jira credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Jira access token is invalid or empty. "
"Please reconnect your Jira account."
)
try:
self._credentials = AtlassianAuthCredentialsBase.from_dict(
config_data

View file

@ -116,6 +116,14 @@ class LinearConnector:
config_data = connector.config.copy()
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Linear access token not found. "
"Please reconnect your Linear account."
)
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
@ -143,6 +151,14 @@ class LinearConnector:
f"Failed to decrypt Linear credentials: {e!s}"
) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Linear access token is invalid or empty. "
"Please reconnect your Linear account."
)
try:
self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
except Exception as e:

View file

@ -1361,6 +1361,13 @@ if config.AUTH_TYPE == "GOOGLE":
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
back_populates="user",
cascade="all, delete-orphan",
)
else:
class User(SQLAlchemyBaseUserTableUUID, Base):
@ -1426,6 +1433,43 @@ else:
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
back_populates="user",
cascade="all, delete-orphan",
)
class RefreshToken(Base, TimestampMixin):
"""
Stores refresh tokens for user session management.
Each row represents one device/session.
"""
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user = relationship("User", back_populates="refresh_tokens")
token_hash = Column(String(256), unique=True, nullable=False, index=True)
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
is_revoked = Column(Boolean, default=False, nullable=False)
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
@property
def is_expired(self) -> bool:
return datetime.now(UTC) >= self.expires_at
@property
def is_valid(self) -> bool:
return not self.is_expired and not self.is_revoked
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -0,0 +1,93 @@
"""Authentication routes for refresh token management."""
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from app.db import User, async_session_maker
from app.schemas.auth import (
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.users import current_active_user, get_jwt_strategy
from app.utils.refresh_tokens import (
revoke_all_user_tokens,
revoke_refresh_token,
rotate_refresh_token,
validate_refresh_token,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
@router.post("/refresh", response_model=RefreshTokenResponse)
async def refresh_access_token(request: RefreshTokenRequest):
"""
Exchange a valid refresh token for a new access token and refresh token.
Implements token rotation for security.
"""
token_record = await validate_refresh_token(request.refresh_token)
if not token_record:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
# Get user from token record
async with async_session_maker() as session:
result = await session.execute(
select(User).where(User.id == token_record.user_id)
)
user = result.scalars().first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
# Generate new access token
strategy = get_jwt_strategy()
access_token = await strategy.write_token(user)
# Rotate refresh token
new_refresh_token = await rotate_refresh_token(token_record)
logger.info(f"Refreshed token for user {user.id}")
return RefreshTokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
)
@router.post("/revoke", response_model=LogoutResponse)
async def revoke_token(request: LogoutRequest):
"""
Logout current device by revoking the provided refresh token.
Does not require authentication - just the refresh token.
"""
revoked = await revoke_refresh_token(request.refresh_token)
if revoked:
logger.info("User logged out from current device - token revoked")
else:
logger.warning("Logout called but no matching token found to revoke")
return LogoutResponse()
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(user: User = Depends(current_active_user)):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
await revoke_all_user_tokens(user.id)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()

View file

@ -1,3 +1,10 @@
from .auth import (
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from .base import IDModel, TimestampModel
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import (
@ -117,6 +124,10 @@ __all__ = [
"LogFilter",
"LogRead",
"LogUpdate",
# Auth schemas
"LogoutAllResponse",
"LogoutRequest",
"LogoutResponse",
# Search source connector schemas
"MCPConnectorCreate",
"MCPConnectorRead",
@ -146,6 +157,8 @@ __all__ = [
"PodcastCreate",
"PodcastRead",
"PodcastUpdate",
"RefreshTokenRequest",
"RefreshTokenResponse",
"RoleCreate",
"RoleRead",
"RoleUpdate",

View file

@ -0,0 +1,35 @@
"""Authentication schemas for refresh token endpoints."""
from pydantic import BaseModel
class RefreshTokenRequest(BaseModel):
"""Request body for token refresh endpoint."""
refresh_token: str
class RefreshTokenResponse(BaseModel):
"""Response from token refresh endpoint."""
access_token: str
refresh_token: str
token_type: str = "bearer"
class LogoutRequest(BaseModel):
"""Request body for logout endpoint (current device)."""
refresh_token: str
class LogoutResponse(BaseModel):
"""Response from logout endpoint (current device)."""
detail: str = "Successfully logged out"
class LogoutAllResponse(BaseModel):
"""Response from logout all devices endpoint."""
detail: str = "Successfully logged out from all devices"

View file

@ -23,17 +23,20 @@ from app.db import (
get_default_roles_config,
get_user_db,
)
from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
class BearerResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str
SECRET = config.SECRET_KEY
if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2
@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
return JWTStrategy(
secret=SECRET,
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
)
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
@ -209,9 +215,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
import jwt
# Decode JWT to get user_id for refresh token creation
try:
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
user_id = uuid.UUID(payload.get("sub"))
refresh_token = await create_refresh_token(user_id)
except Exception as e:
logger.error(f"Failed to create refresh token: {e}")
# Fall back to response without refresh token
refresh_token = ""
bearer_response = BearerResponse(
access_token=token,
refresh_token=refresh_token,
token_type="bearer",
)
if config.AUTH_TYPE == "GOOGLE":
redirect_url = (
f"{config.NEXT_FRONTEND_URL}/auth/callback"
f"?token={bearer_response.access_token}"
f"&refresh_token={bearer_response.refresh_token}"
)
return RedirectResponse(redirect_url, status_code=302)
else:
return JSONResponse(bearer_response.model_dump())

View file

@ -0,0 +1,153 @@
"""Utilities for managing refresh tokens."""
import hashlib
import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
from app.config import config
from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token for secure storage."""
return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
Args:
user_id: The user's ID
family_id: Optional family ID for token rotation
Returns:
The plaintext refresh token
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
)
if family_id is None:
family_id = uuid.uuid4()
async with async_session_maker() as session:
refresh_token = RefreshToken(
user_id=user_id,
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
)
session.add(refresh_token)
await session.commit()
return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
return await create_refresh_token(old_token.user_id, old_token.family_id)
async def revoke_refresh_token(token: str) -> bool:
"""
Revoke a single refresh token by its plaintext value.
Args:
token: The plaintext refresh token
Returns:
True if token was found and revoked, False otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
update(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.values(is_revoked=True)
)
await session.commit()
return result.rowcount > 0
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
)
await session.commit()

View file

@ -3,7 +3,7 @@
import { useSearchParams } from "next/navigation";
import { useEffect } from "react";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { getAndClearRedirectPath, setBearerToken } from "@/lib/auth-utils";
import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils";
import { trackLoginSuccess } from "@/lib/posthog/events";
interface TokenHandlerProps {
@ -35,8 +35,9 @@ const TokenHandler = ({
// Only run on client-side
if (typeof window === "undefined") return;
// Get token from URL parameters
// Get tokens from URL parameters
const token = searchParams.get(tokenParamName);
const refreshToken = searchParams.get("refresh_token");
if (token) {
try {
@ -50,10 +51,15 @@ const TokenHandler = ({
// Clear the flag for future logins
sessionStorage.removeItem("login_success_tracked");
// Store token in localStorage using both methods for compatibility
// Store access token in localStorage using both methods for compatibility
localStorage.setItem(storageKey, token);
setBearerToken(token);
// Store refresh token if provided
if (refreshToken) {
setRefreshToken(refreshToken);
}
// Check if there's a saved redirect path from before the auth flow
const savedRedirectPath = getAndClearRedirectPath();

View file

@ -1,7 +1,8 @@
"use client";
import { BadgeCheck, LogOut } from "lucide-react";
import { BadgeCheck, Loader2, LogOut } from "lucide-react";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
import { Button } from "@/components/ui/button";
import {
@ -13,6 +14,7 @@ import {
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { logout } from "@/lib/auth-utils";
import { cleanupElectric } from "@/lib/electric/client";
import { resetUser, trackLogout } from "@/lib/posthog/events";
@ -26,8 +28,11 @@ export function UserDropdown({
};
}) {
const router = useRouter();
const [isLoggingOut, setIsLoggingOut] = useState(false);
const handleLogout = async () => {
if (isLoggingOut) return;
setIsLoggingOut(true);
try {
// Track logout event and reset PostHog identity
trackLogout();
@ -41,15 +46,17 @@ export function UserDropdown({
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
}
// Revoke refresh token on server and clear all tokens from localStorage
await logout();
if (typeof window !== "undefined") {
localStorage.removeItem("surfsense_bearer_token");
window.location.href = "/";
}
} catch (error) {
console.error("Error during logout:", error);
// Optionally, provide user feedback
// Even if there's an error, try to clear tokens and redirect
await logout();
if (typeof window !== "undefined") {
localStorage.removeItem("surfsense_bearer_token");
window.location.href = "/";
}
}
@ -85,9 +92,17 @@ export function UserDropdown({
</DropdownMenuItem>
</DropdownMenuGroup>
<DropdownMenuSeparator />
<DropdownMenuItem onClick={handleLogout} className="text-xs md:text-sm">
<LogOut className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4" />
Log out
<DropdownMenuItem
onClick={handleLogout}
className="text-xs md:text-sm"
disabled={isLoggingOut}
>
{isLoggingOut ? (
<Loader2 className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4 animate-spin" />
) : (
<LogOut className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4" />
)}
{isLoggingOut ? "Logging out..." : "Log out"}
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>

View file

@ -26,6 +26,7 @@ import { isPageLimitExceededMetadata } from "@/contracts/types/inbox.types";
import { useInbox } from "@/hooks/use-inbox";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence";
import { logout } from "@/lib/auth-utils";
import { cleanupElectric } from "@/lib/electric/client";
import { resetUser, trackLogout } from "@/lib/posthog/events";
import { cacheKeys } from "@/lib/query-client/cache-keys";
@ -474,12 +475,15 @@ export function LayoutDataProvider({
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
}
// Revoke refresh token on server and clear all tokens from localStorage
await logout();
if (typeof window !== "undefined") {
localStorage.removeItem("surfsense_bearer_token");
router.push("/");
}
} catch (error) {
console.error("Error during logout:", error);
await logout();
router.push("/");
}
}, [router]);

View file

@ -1,7 +1,8 @@
"use client";
import { Check, ChevronUp, Languages, Laptop, LogOut, Moon, Settings, Sun } from "lucide-react";
import { Check, ChevronUp, Languages, Laptop, Loader2, LogOut, Moon, Settings, Sun } from "lucide-react";
import { useTranslations } from "next-intl";
import { useState } from "react";
import {
DropdownMenu,
DropdownMenuContent,
@ -124,6 +125,7 @@ export function SidebarUserProfile({
}: SidebarUserProfileProps) {
const t = useTranslations("sidebar");
const { locale, setLocale } = useLocaleContext();
const [isLoggingOut, setIsLoggingOut] = useState(false);
const bgColor = stringToColor(user.email);
const initials = getInitials(user.email);
const displayName = user.name || user.email.split("@")[0];
@ -136,6 +138,16 @@ export function SidebarUserProfile({
setTheme?.(newTheme);
};
const handleLogout = async () => {
if (isLoggingOut || !onLogout) return;
setIsLoggingOut(true);
try {
await onLogout();
} finally {
setIsLoggingOut(false);
}
};
// Collapsed view - just show avatar with dropdown
if (isCollapsed) {
return (
@ -242,9 +254,13 @@ export function SidebarUserProfile({
<DropdownMenuSeparator />
<DropdownMenuItem onClick={onLogout}>
<LogOut className="mr-2 h-4 w-4" />
{t("logout")}
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
{isLoggingOut ? (
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
) : (
<LogOut className="mr-2 h-4 w-4" />
)}
{isLoggingOut ? t("loggingOut") : t("logout")}
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
@ -360,9 +376,13 @@ export function SidebarUserProfile({
<DropdownMenuSeparator />
<DropdownMenuItem onClick={onLogout}>
<LogOut className="mr-2 h-4 w-4" />
{t("logout")}
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
{isLoggingOut ? (
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
) : (
<LogOut className="mr-2 h-4 w-4" />
)}
{isLoggingOut ? t("loggingOut") : t("logout")}
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>

View file

@ -1,5 +1,5 @@
import type { ZodType } from "zod";
import { getBearerToken, handleUnauthorized } from "../auth-utils";
import { getBearerToken, handleUnauthorized, refreshAccessToken } from "../auth-utils";
import { AppError, AuthenticationError, AuthorizationError, NotFoundError } from "../error";
enum ResponseType {
@ -17,6 +17,7 @@ export type RequestOptions = {
signal?: AbortSignal;
body?: any;
responseType?: ResponseType;
_isRetry?: boolean; // Internal flag to prevent infinite retry loops
// Add more options as needed
};
@ -135,8 +136,23 @@ class BaseApiService {
throw new AppError("Failed to parse response", response.status, response.statusText);
}
// Handle 401 first before other error handling - ensures token is cleared and user redirected
// Handle 401 - try to refresh token first (only once)
if (response.status === 401) {
if (!options?._isRetry) {
const newToken = await refreshAccessToken();
if (newToken) {
// Retry the request with the new token
return this.request(url, responseSchema, {
...mergedOptions,
headers: {
...mergedOptions.headers,
Authorization: `Bearer ${newToken}`,
},
_isRetry: true,
} as RequestOptions & { responseType?: R });
}
}
// Refresh failed or retry failed, redirect to login
handleUnauthorized();
throw new AuthenticationError(
typeof data === "object" && "detail" in data

View file

@ -4,6 +4,11 @@
const REDIRECT_PATH_KEY = "surfsense_redirect_path";
const BEARER_TOKEN_KEY = "surfsense_bearer_token";
const REFRESH_TOKEN_KEY = "surfsense_refresh_token";
// Flag to prevent multiple simultaneous refresh attempts
let isRefreshing = false;
let refreshPromise: Promise<string | null> | null = null;
/**
* Saves the current path and redirects to login page
@ -21,8 +26,9 @@ export function handleUnauthorized(): void {
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
}
// Clear the token
// Clear both tokens
localStorage.removeItem(BEARER_TOKEN_KEY);
localStorage.removeItem(REFRESH_TOKEN_KEY);
// Redirect to home page (which has login options)
window.location.href = "/login";
@ -66,6 +72,71 @@ export function clearBearerToken(): void {
localStorage.removeItem(BEARER_TOKEN_KEY);
}
/**
* Gets the refresh token from localStorage
*/
export function getRefreshToken(): string | null {
if (typeof window === "undefined") return null;
return localStorage.getItem(REFRESH_TOKEN_KEY);
}
/**
* Sets the refresh token in localStorage
*/
export function setRefreshToken(token: string): void {
if (typeof window === "undefined") return;
localStorage.setItem(REFRESH_TOKEN_KEY, token);
}
/**
* Clears the refresh token from localStorage
*/
export function clearRefreshToken(): void {
if (typeof window === "undefined") return;
localStorage.removeItem(REFRESH_TOKEN_KEY);
}
/**
* Clears all auth tokens from localStorage
*/
export function clearAllTokens(): void {
clearBearerToken();
clearRefreshToken();
}
/**
* Logout the current user by revoking the refresh token and clearing localStorage.
* Returns true if logout was successful (or tokens were cleared), false otherwise.
*/
export async function logout(): Promise<boolean> {
const refreshToken = getRefreshToken();
// Call backend to revoke the refresh token
if (refreshToken) {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const response = await fetch(`${backendUrl}/auth/jwt/revoke`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ refresh_token: refreshToken }),
});
if (!response.ok) {
console.warn("Failed to revoke refresh token:", response.status, await response.text());
}
} catch (error) {
console.warn("Failed to revoke refresh token on server:", error);
// Continue to clear local tokens even if server call fails
}
}
// Clear all tokens from localStorage
clearAllTokens();
return true;
}
/**
* Checks if the user is authenticated (has a token)
*/
@ -106,14 +177,67 @@ export function getAuthHeaders(additionalHeaders?: Record<string, string>): Reco
}
/**
* Authenticated fetch wrapper that handles 401 responses uniformly
* Automatically redirects to login on 401 and saves the current path
* Attempts to refresh the access token using the stored refresh token.
* Returns the new access token if successful, null otherwise.
* Exported for use by API services.
*/
export async function refreshAccessToken(): Promise<string | null> {
// If already refreshing, wait for that request to complete
if (isRefreshing && refreshPromise) {
return refreshPromise;
}
const currentRefreshToken = getRefreshToken();
if (!currentRefreshToken) {
return null;
}
isRefreshing = true;
refreshPromise = (async () => {
try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const response = await fetch(`${backendUrl}/auth/jwt/refresh`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ refresh_token: currentRefreshToken }),
});
if (!response.ok) {
// Refresh failed, clear tokens
clearAllTokens();
return null;
}
const data = await response.json();
if (data.access_token && data.refresh_token) {
setBearerToken(data.access_token);
setRefreshToken(data.refresh_token);
return data.access_token;
}
return null;
} catch {
return null;
} finally {
isRefreshing = false;
refreshPromise = null;
}
})();
return refreshPromise;
}
/**
* Authenticated fetch wrapper that handles 401 responses uniformly.
* On 401, attempts to refresh the token and retry the request.
* If refresh fails, redirects to login and saves the current path.
*/
export async function authenticatedFetch(
url: string,
options?: RequestInit & { skipAuthRedirect?: boolean }
options?: RequestInit & { skipAuthRedirect?: boolean; skipRefresh?: boolean }
): Promise<Response> {
const { skipAuthRedirect = false, ...fetchOptions } = options || {};
const { skipAuthRedirect = false, skipRefresh = false, ...fetchOptions } = options || {};
const headers = getAuthHeaders(fetchOptions.headers as Record<string, string>);
@ -124,6 +248,23 @@ export async function authenticatedFetch(
// Handle 401 Unauthorized
if (response.status === 401 && !skipAuthRedirect) {
// Try to refresh the token (unless skipRefresh is set to prevent infinite loops)
if (!skipRefresh) {
const newToken = await refreshAccessToken();
if (newToken) {
// Retry the original request with the new token
const retryHeaders = {
...(fetchOptions.headers as Record<string, string>),
Authorization: `Bearer ${newToken}`,
};
return fetch(url, {
...fetchOptions,
headers: retryHeaders,
});
}
}
// Refresh failed or was skipped, redirect to login
handleUnauthorized();
throw new Error("Unauthorized: Redirecting to login page");
}

View file

@ -700,6 +700,7 @@
"dark": "Dark",
"system": "System",
"logout": "Logout",
"loggingOut": "Logging out...",
"inbox": "Inbox",
"search_inbox": "Search inbox",
"mark_all_read": "Mark all as read",

View file

@ -685,6 +685,7 @@
"dark": "深色",
"system": "系统",
"logout": "退出登录",
"loggingOut": "正在退出...",
"inbox": "收件箱",
"search_inbox": "搜索收件箱",
"mark_all_read": "全部标记为已读",