diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 2e10f4e36..628329917 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -32,6 +32,11 @@ ELECTRIC_DB_PASSWORD=electric_password SCHEDULE_CHECKER_INTERVAL=5m SECRET_KEY=SECRET + +# JWT Token Lifetimes (optional, defaults shown) +# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day +# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks + NEXT_FRONTEND_URL=http://localhost:3000 # Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS) diff --git a/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py b/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py new file mode 100644 index 000000000..c7e133ae9 --- /dev/null +++ b/surfsense_backend/alembic/versions/92_add_refresh_tokens_table.py @@ -0,0 +1,92 @@ +"""Add refresh_tokens table for user session management + +Revision ID: 92 +Revises: 91 + +Changes: +1. Create refresh_tokens table with columns: + - id (primary key) + - user_id (foreign key to user) + - token_hash (unique, indexed) + - expires_at (indexed) + - is_revoked + - family_id (indexed, for token rotation tracking) + - created_at, updated_at (timestamps) +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "92" +down_revision: str | None = "91" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Create refresh_tokens table (idempotent).""" + # Check if table already exists + connection = op.get_bind() + result = connection.execute( + sa.text( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')" + ) + ) + table_exists = result.scalar() + + if not table_exists: + op.create_table( + "refresh_tokens", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("user_id", UUID(as_uuid=True), nullable=False), + sa.Column("token_hash", sa.String(256), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False), + sa.Column("family_id", UUID(as_uuid=True), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ondelete="CASCADE", + ), + ) + + # Create indexes if they don't exist + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)" + ) + op.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)" + ) + + +def downgrade() -> None: + """Drop refresh_tokens table (idempotent).""" + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash") + op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id") + op.execute("DROP TABLE IF EXISTS refresh_tokens") diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 01dd0da3d..63da4e8ad 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import ( from app.config import config, initialize_llm_router from app.db import User, create_db_and_tables, get_async_session from app.routes import router as crud_router +from app.routes.auth_routes import router as auth_router from app.schemas import UserCreate, UserRead, UserUpdate from app.tasks.surfsense_docs_indexer import seed_surfsense_docs from app.users import SECRET, auth_backend, current_active_user, fastapi_users @@ -111,6 +112,9 @@ app.include_router( tags=["users"], ) +# Include custom auth routes (refresh token, logout) +app.include_router(auth_router) + if config.AUTH_TYPE == "GOOGLE": from fastapi.responses import RedirectResponse diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 149fedd39..121e5d3b2 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -255,6 +255,14 @@ class Config: # OAuth JWT SECRET_KEY = os.getenv("SECRET_KEY") + # JWT Token Lifetimes + ACCESS_TOKEN_LIFETIME_SECONDS = int( + os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day + ) + REFRESH_TOKEN_LIFETIME_SECONDS = int( + os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks + ) + # ETL Service ETL_SERVICE = os.getenv("ETL_SERVICE") diff --git a/surfsense_backend/app/connectors/airtable_history.py b/surfsense_backend/app/connectors/airtable_history.py index 64f6465fe..092485f77 100644 --- a/surfsense_backend/app/connectors/airtable_history.py +++ b/surfsense_backend/app/connectors/airtable_history.py @@ -71,6 +71,14 @@ class AirtableHistoryConnector: config_data = connector.config.copy() + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Airtable access token not found. " + "Please reconnect your Airtable account." + ) + # Decrypt credentials if they are encrypted token_encrypted = config_data.get("_token_encrypted", False) if token_encrypted and config.SECRET_KEY: @@ -98,6 +106,14 @@ class AirtableHistoryConnector: f"Failed to decrypt Airtable credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or (isinstance(final_token, str) and not final_token.strip()): + raise ValueError( + "Airtable access token is invalid or empty. " + "Please reconnect your Airtable account." + ) + try: self._credentials = AirtableAuthCredentialsBase.from_dict(config_data) except Exception as e: diff --git a/surfsense_backend/app/connectors/confluence_history.py b/surfsense_backend/app/connectors/confluence_history.py index 9e10ffcf1..908f532db 100644 --- a/surfsense_backend/app/connectors/confluence_history.py +++ b/surfsense_backend/app/connectors/confluence_history.py @@ -87,6 +87,14 @@ class ConfluenceHistoryConnector: if is_oauth: # OAuth 2.0 authentication + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Confluence access token not found. " + "Please reconnect your Confluence account." + ) + # Decrypt credentials if they are encrypted token_encrypted = config_data.get("_token_encrypted", False) if token_encrypted and config.SECRET_KEY: @@ -118,6 +126,14 @@ class ConfluenceHistoryConnector: f"Failed to decrypt Confluence credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or (isinstance(final_token, str) and not final_token.strip()): + raise ValueError( + "Confluence access token is invalid or empty. " + "Please reconnect your Confluence account." + ) + try: self._credentials = AtlassianAuthCredentialsBase.from_dict( config_data diff --git a/surfsense_backend/app/connectors/jira_history.py b/surfsense_backend/app/connectors/jira_history.py index 6e04ec2a4..46a28324d 100644 --- a/surfsense_backend/app/connectors/jira_history.py +++ b/surfsense_backend/app/connectors/jira_history.py @@ -86,6 +86,14 @@ class JiraHistoryConnector: if is_oauth: # OAuth 2.0 authentication + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Jira access token not found. " + "Please reconnect your Jira account." + ) + if not config.SECRET_KEY: raise ValueError( "SECRET_KEY not configured but tokens are marked as encrypted" @@ -119,6 +127,14 @@ class JiraHistoryConnector: f"Failed to decrypt Jira credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or (isinstance(final_token, str) and not final_token.strip()): + raise ValueError( + "Jira access token is invalid or empty. " + "Please reconnect your Jira account." + ) + try: self._credentials = AtlassianAuthCredentialsBase.from_dict( config_data diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index b8206a40d..6500b9027 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -116,6 +116,14 @@ class LinearConnector: config_data = connector.config.copy() + # Check if access_token exists before processing + raw_access_token = config_data.get("access_token") + if not raw_access_token: + raise ValueError( + "Linear access token not found. " + "Please reconnect your Linear account." + ) + # Decrypt credentials if they are encrypted token_encrypted = config_data.get("_token_encrypted", False) if token_encrypted and config.SECRET_KEY: @@ -143,6 +151,14 @@ class LinearConnector: f"Failed to decrypt Linear credentials: {e!s}" ) from e + # Final validation after decryption + final_token = config_data.get("access_token") + if not final_token or (isinstance(final_token, str) and not final_token.strip()): + raise ValueError( + "Linear access token is invalid or empty. " + "Please reconnect your Linear account." + ) + try: self._credentials = LinearAuthCredentialsBase.from_dict(config_data) except Exception as e: diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 5cdb712db..2298e7438 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1361,6 +1361,13 @@ if config.AUTH_TYPE == "GOOGLE": display_name = Column(String, nullable=True) avatar_url = Column(String, nullable=True) + # Refresh tokens for this user + refresh_tokens = relationship( + "RefreshToken", + back_populates="user", + cascade="all, delete-orphan", + ) + else: class User(SQLAlchemyBaseUserTableUUID, Base): @@ -1426,6 +1433,43 @@ else: display_name = Column(String, nullable=True) avatar_url = Column(String, nullable=True) + # Refresh tokens for this user + refresh_tokens = relationship( + "RefreshToken", + back_populates="user", + cascade="all, delete-orphan", + ) + + +class RefreshToken(Base, TimestampMixin): + """ + Stores refresh tokens for user session management. + Each row represents one device/session. + """ + + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user = relationship("User", back_populates="refresh_tokens") + token_hash = Column(String(256), unique=True, nullable=False, index=True) + expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True) + is_revoked = Column(Boolean, default=False, nullable=False) + family_id = Column(UUID(as_uuid=True), nullable=False, index=True) + + @property + def is_expired(self) -> bool: + return datetime.now(UTC) >= self.expires_at + + @property + def is_valid(self) -> bool: + return not self.is_expired and not self.is_revoked + engine = create_async_engine(DATABASE_URL) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/surfsense_backend/app/routes/auth_routes.py b/surfsense_backend/app/routes/auth_routes.py new file mode 100644 index 000000000..b1cbaf2a5 --- /dev/null +++ b/surfsense_backend/app/routes/auth_routes.py @@ -0,0 +1,93 @@ +"""Authentication routes for refresh token management.""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import select + +from app.db import User, async_session_maker +from app.schemas.auth import ( + LogoutAllResponse, + LogoutRequest, + LogoutResponse, + RefreshTokenRequest, + RefreshTokenResponse, +) +from app.users import current_active_user, get_jwt_strategy +from app.utils.refresh_tokens import ( + revoke_all_user_tokens, + revoke_refresh_token, + rotate_refresh_token, + validate_refresh_token, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth/jwt", tags=["auth"]) + + +@router.post("/refresh", response_model=RefreshTokenResponse) +async def refresh_access_token(request: RefreshTokenRequest): + """ + Exchange a valid refresh token for a new access token and refresh token. + Implements token rotation for security. + """ + token_record = await validate_refresh_token(request.refresh_token) + + if not token_record: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + ) + + # Get user from token record + async with async_session_maker() as session: + result = await session.execute( + select(User).where(User.id == token_record.user_id) + ) + user = result.scalars().first() + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + ) + + # Generate new access token + strategy = get_jwt_strategy() + access_token = await strategy.write_token(user) + + # Rotate refresh token + new_refresh_token = await rotate_refresh_token(token_record) + + logger.info(f"Refreshed token for user {user.id}") + + return RefreshTokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + ) + + +@router.post("/revoke", response_model=LogoutResponse) +async def revoke_token(request: LogoutRequest): + """ + Logout current device by revoking the provided refresh token. + Does not require authentication - just the refresh token. + """ + revoked = await revoke_refresh_token(request.refresh_token) + if revoked: + logger.info("User logged out from current device - token revoked") + else: + logger.warning("Logout called but no matching token found to revoke") + return LogoutResponse() + + +@router.post("/logout-all", response_model=LogoutAllResponse) +async def logout_all_devices(user: User = Depends(current_active_user)): + """ + Logout from all devices by revoking all refresh tokens for the user. + Requires valid access token. + """ + await revoke_all_user_tokens(user.id) + logger.info(f"User {user.id} logged out from all devices") + return LogoutAllResponse() diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 6c9577c46..5ff166733 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -1,3 +1,10 @@ +from .auth import ( + LogoutAllResponse, + LogoutRequest, + LogoutResponse, + RefreshTokenRequest, + RefreshTokenResponse, +) from .base import IDModel, TimestampModel from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate from .documents import ( @@ -117,6 +124,10 @@ __all__ = [ "LogFilter", "LogRead", "LogUpdate", + # Auth schemas + "LogoutAllResponse", + "LogoutRequest", + "LogoutResponse", # Search source connector schemas "MCPConnectorCreate", "MCPConnectorRead", @@ -146,6 +157,8 @@ __all__ = [ "PodcastCreate", "PodcastRead", "PodcastUpdate", + "RefreshTokenRequest", + "RefreshTokenResponse", "RoleCreate", "RoleRead", "RoleUpdate", diff --git a/surfsense_backend/app/schemas/auth.py b/surfsense_backend/app/schemas/auth.py new file mode 100644 index 000000000..0d958a6d2 --- /dev/null +++ b/surfsense_backend/app/schemas/auth.py @@ -0,0 +1,35 @@ +"""Authentication schemas for refresh token endpoints.""" + +from pydantic import BaseModel + + +class RefreshTokenRequest(BaseModel): + """Request body for token refresh endpoint.""" + + refresh_token: str + + +class RefreshTokenResponse(BaseModel): + """Response from token refresh endpoint.""" + + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class LogoutRequest(BaseModel): + """Request body for logout endpoint (current device).""" + + refresh_token: str + + +class LogoutResponse(BaseModel): + """Response from logout endpoint (current device).""" + + detail: str = "Successfully logged out" + + +class LogoutAllResponse(BaseModel): + """Response from logout all devices endpoint.""" + + detail: str = "Successfully logged out from all devices" diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index 4be2fe525..696cdf25e 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -23,17 +23,20 @@ from app.db import ( get_default_roles_config, get_user_db, ) +from app.utils.refresh_tokens import create_refresh_token logger = logging.getLogger(__name__) class BearerResponse(BaseModel): access_token: str + refresh_token: str token_type: str SECRET = config.SECRET_KEY + if config.AUTH_TYPE == "GOOGLE": from httpx_oauth.clients.google import GoogleOAuth2 @@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: - return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24) + return JWTStrategy( + secret=SECRET, + lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS, + ) # # COOKIE AUTH | Uncomment if you want to use cookie auth. @@ -209,9 +215,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: # BEARER AUTH CODE. class CustomBearerTransport(BearerTransport): async def get_login_response(self, token: str) -> Response: - bearer_response = BearerResponse(access_token=token, token_type="bearer") - redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}" + import jwt + + # Decode JWT to get user_id for refresh token creation + try: + payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False}) + user_id = uuid.UUID(payload.get("sub")) + refresh_token = await create_refresh_token(user_id) + except Exception as e: + logger.error(f"Failed to create refresh token: {e}") + # Fall back to response without refresh token + refresh_token = "" + + bearer_response = BearerResponse( + access_token=token, + refresh_token=refresh_token, + token_type="bearer", + ) + if config.AUTH_TYPE == "GOOGLE": + redirect_url = ( + f"{config.NEXT_FRONTEND_URL}/auth/callback" + f"?token={bearer_response.access_token}" + f"&refresh_token={bearer_response.refresh_token}" + ) return RedirectResponse(redirect_url, status_code=302) else: return JSONResponse(bearer_response.model_dump()) diff --git a/surfsense_backend/app/utils/refresh_tokens.py b/surfsense_backend/app/utils/refresh_tokens.py new file mode 100644 index 000000000..8c0312ba8 --- /dev/null +++ b/surfsense_backend/app/utils/refresh_tokens.py @@ -0,0 +1,153 @@ +"""Utilities for managing refresh tokens.""" + +import hashlib +import logging +import secrets +import uuid +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select, update + +from app.config import config +from app.db import RefreshToken, async_session_maker + +logger = logging.getLogger(__name__) + + +def generate_refresh_token() -> str: + """Generate a cryptographically secure refresh token.""" + return secrets.token_urlsafe(32) + + +def hash_token(token: str) -> str: + """Hash a token for secure storage.""" + return hashlib.sha256(token.encode()).hexdigest() + + +async def create_refresh_token( + user_id: uuid.UUID, + family_id: uuid.UUID | None = None, +) -> str: + """ + Create and store a new refresh token for a user. + + Args: + user_id: The user's ID + family_id: Optional family ID for token rotation + + Returns: + The plaintext refresh token + """ + token = generate_refresh_token() + token_hash = hash_token(token) + expires_at = datetime.now(UTC) + timedelta( + seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS + ) + + if family_id is None: + family_id = uuid.uuid4() + + async with async_session_maker() as session: + refresh_token = RefreshToken( + user_id=user_id, + token_hash=token_hash, + expires_at=expires_at, + family_id=family_id, + ) + session.add(refresh_token) + await session.commit() + + return token + + +async def validate_refresh_token(token: str) -> RefreshToken | None: + """ + Validate a refresh token. Handles reuse detection. + + Args: + token: The plaintext refresh token + + Returns: + RefreshToken if valid, None otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + refresh_token = result.scalars().first() + + if not refresh_token: + return None + + # Reuse detection: revoked token used while family has active tokens + if refresh_token.is_revoked: + active = await session.execute( + select(RefreshToken).where( + RefreshToken.family_id == refresh_token.family_id, + RefreshToken.is_revoked == False, # noqa: E712 + RefreshToken.expires_at > datetime.now(UTC), + ) + ) + if active.scalars().first(): + # Revoke entire family + await session.execute( + update(RefreshToken) + .where(RefreshToken.family_id == refresh_token.family_id) + .values(is_revoked=True) + ) + await session.commit() + logger.warning(f"Token reuse detected for user {refresh_token.user_id}") + return None + + if refresh_token.is_expired: + return None + + return refresh_token + + +async def rotate_refresh_token(old_token: RefreshToken) -> str: + """Revoke old token and create new one in same family.""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.id == old_token.id) + .values(is_revoked=True) + ) + await session.commit() + + return await create_refresh_token(old_token.user_id, old_token.family_id) + + +async def revoke_refresh_token(token: str) -> bool: + """ + Revoke a single refresh token by its plaintext value. + + Args: + token: The plaintext refresh token + + Returns: + True if token was found and revoked, False otherwise + """ + token_hash = hash_token(token) + + async with async_session_maker() as session: + result = await session.execute( + update(RefreshToken) + .where(RefreshToken.token_hash == token_hash) + .values(is_revoked=True) + ) + await session.commit() + return result.rowcount > 0 + + +async def revoke_all_user_tokens(user_id: uuid.UUID) -> None: + """Revoke all refresh tokens for a user (logout all devices).""" + async with async_session_maker() as session: + await session.execute( + update(RefreshToken) + .where(RefreshToken.user_id == user_id) + .values(is_revoked=True) + ) + await session.commit() diff --git a/surfsense_web/components/TokenHandler.tsx b/surfsense_web/components/TokenHandler.tsx index e3295df7c..230cda81a 100644 --- a/surfsense_web/components/TokenHandler.tsx +++ b/surfsense_web/components/TokenHandler.tsx @@ -3,7 +3,7 @@ import { useSearchParams } from "next/navigation"; import { useEffect } from "react"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; -import { getAndClearRedirectPath, setBearerToken } from "@/lib/auth-utils"; +import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils"; import { trackLoginSuccess } from "@/lib/posthog/events"; interface TokenHandlerProps { @@ -35,8 +35,9 @@ const TokenHandler = ({ // Only run on client-side if (typeof window === "undefined") return; - // Get token from URL parameters + // Get tokens from URL parameters const token = searchParams.get(tokenParamName); + const refreshToken = searchParams.get("refresh_token"); if (token) { try { @@ -50,10 +51,15 @@ const TokenHandler = ({ // Clear the flag for future logins sessionStorage.removeItem("login_success_tracked"); - // Store token in localStorage using both methods for compatibility + // Store access token in localStorage using both methods for compatibility localStorage.setItem(storageKey, token); setBearerToken(token); + // Store refresh token if provided + if (refreshToken) { + setRefreshToken(refreshToken); + } + // Check if there's a saved redirect path from before the auth flow const savedRedirectPath = getAndClearRedirectPath(); diff --git a/surfsense_web/components/UserDropdown.tsx b/surfsense_web/components/UserDropdown.tsx index 3dac745cf..233a41a1f 100644 --- a/surfsense_web/components/UserDropdown.tsx +++ b/surfsense_web/components/UserDropdown.tsx @@ -1,7 +1,8 @@ "use client"; -import { BadgeCheck, LogOut } from "lucide-react"; +import { BadgeCheck, Loader2, LogOut } from "lucide-react"; import { useRouter } from "next/navigation"; +import { useState } from "react"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; import { @@ -13,6 +14,7 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { logout } from "@/lib/auth-utils"; import { cleanupElectric } from "@/lib/electric/client"; import { resetUser, trackLogout } from "@/lib/posthog/events"; @@ -26,8 +28,11 @@ export function UserDropdown({ }; }) { const router = useRouter(); + const [isLoggingOut, setIsLoggingOut] = useState(false); const handleLogout = async () => { + if (isLoggingOut) return; + setIsLoggingOut(true); try { // Track logout event and reset PostHog identity trackLogout(); @@ -41,15 +46,17 @@ export function UserDropdown({ console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err); } + // Revoke refresh token on server and clear all tokens from localStorage + await logout(); + if (typeof window !== "undefined") { - localStorage.removeItem("surfsense_bearer_token"); window.location.href = "/"; } } catch (error) { console.error("Error during logout:", error); - // Optionally, provide user feedback + // Even if there's an error, try to clear tokens and redirect + await logout(); if (typeof window !== "undefined") { - localStorage.removeItem("surfsense_bearer_token"); window.location.href = "/"; } } @@ -85,9 +92,17 @@ export function UserDropdown({ - - - Log out + + {isLoggingOut ? ( + + ) : ( + + )} + {isLoggingOut ? "Logging out..." : "Log out"} diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index c1a9c18c3..2c2af7d46 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -26,6 +26,7 @@ import { isPageLimitExceededMetadata } from "@/contracts/types/inbox.types"; import { useInbox } from "@/hooks/use-inbox"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence"; +import { logout } from "@/lib/auth-utils"; import { cleanupElectric } from "@/lib/electric/client"; import { resetUser, trackLogout } from "@/lib/posthog/events"; import { cacheKeys } from "@/lib/query-client/cache-keys"; @@ -474,12 +475,15 @@ export function LayoutDataProvider({ console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err); } + // Revoke refresh token on server and clear all tokens from localStorage + await logout(); + if (typeof window !== "undefined") { - localStorage.removeItem("surfsense_bearer_token"); router.push("/"); } } catch (error) { console.error("Error during logout:", error); + await logout(); router.push("/"); } }, [router]); diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx index 7c96b1dcb..38b3028d2 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx @@ -1,7 +1,8 @@ "use client"; -import { Check, ChevronUp, Languages, Laptop, LogOut, Moon, Settings, Sun } from "lucide-react"; +import { Check, ChevronUp, Languages, Laptop, Loader2, LogOut, Moon, Settings, Sun } from "lucide-react"; import { useTranslations } from "next-intl"; +import { useState } from "react"; import { DropdownMenu, DropdownMenuContent, @@ -124,6 +125,7 @@ export function SidebarUserProfile({ }: SidebarUserProfileProps) { const t = useTranslations("sidebar"); const { locale, setLocale } = useLocaleContext(); + const [isLoggingOut, setIsLoggingOut] = useState(false); const bgColor = stringToColor(user.email); const initials = getInitials(user.email); const displayName = user.name || user.email.split("@")[0]; @@ -136,6 +138,16 @@ export function SidebarUserProfile({ setTheme?.(newTheme); }; + const handleLogout = async () => { + if (isLoggingOut || !onLogout) return; + setIsLoggingOut(true); + try { + await onLogout(); + } finally { + setIsLoggingOut(false); + } + }; + // Collapsed view - just show avatar with dropdown if (isCollapsed) { return ( @@ -242,9 +254,13 @@ export function SidebarUserProfile({ - - - {t("logout")} + + {isLoggingOut ? ( + + ) : ( + + )} + {isLoggingOut ? t("loggingOut") : t("logout")} @@ -360,9 +376,13 @@ export function SidebarUserProfile({ - - - {t("logout")} + + {isLoggingOut ? ( + + ) : ( + + )} + {isLoggingOut ? t("loggingOut") : t("logout")} diff --git a/surfsense_web/lib/apis/base-api.service.ts b/surfsense_web/lib/apis/base-api.service.ts index a87d4deaf..933e54656 100644 --- a/surfsense_web/lib/apis/base-api.service.ts +++ b/surfsense_web/lib/apis/base-api.service.ts @@ -1,5 +1,5 @@ import type { ZodType } from "zod"; -import { getBearerToken, handleUnauthorized } from "../auth-utils"; +import { getBearerToken, handleUnauthorized, refreshAccessToken } from "../auth-utils"; import { AppError, AuthenticationError, AuthorizationError, NotFoundError } from "../error"; enum ResponseType { @@ -17,6 +17,7 @@ export type RequestOptions = { signal?: AbortSignal; body?: any; responseType?: ResponseType; + _isRetry?: boolean; // Internal flag to prevent infinite retry loops // Add more options as needed }; @@ -135,8 +136,23 @@ class BaseApiService { throw new AppError("Failed to parse response", response.status, response.statusText); } - // Handle 401 first before other error handling - ensures token is cleared and user redirected + // Handle 401 - try to refresh token first (only once) if (response.status === 401) { + if (!options?._isRetry) { + const newToken = await refreshAccessToken(); + if (newToken) { + // Retry the request with the new token + return this.request(url, responseSchema, { + ...mergedOptions, + headers: { + ...mergedOptions.headers, + Authorization: `Bearer ${newToken}`, + }, + _isRetry: true, + } as RequestOptions & { responseType?: R }); + } + } + // Refresh failed or retry failed, redirect to login handleUnauthorized(); throw new AuthenticationError( typeof data === "object" && "detail" in data diff --git a/surfsense_web/lib/auth-utils.ts b/surfsense_web/lib/auth-utils.ts index 604843292..8c067a4b7 100644 --- a/surfsense_web/lib/auth-utils.ts +++ b/surfsense_web/lib/auth-utils.ts @@ -4,6 +4,11 @@ const REDIRECT_PATH_KEY = "surfsense_redirect_path"; const BEARER_TOKEN_KEY = "surfsense_bearer_token"; +const REFRESH_TOKEN_KEY = "surfsense_refresh_token"; + +// Flag to prevent multiple simultaneous refresh attempts +let isRefreshing = false; +let refreshPromise: Promise | null = null; /** * Saves the current path and redirects to login page @@ -21,8 +26,9 @@ export function handleUnauthorized(): void { localStorage.setItem(REDIRECT_PATH_KEY, currentPath); } - // Clear the token + // Clear both tokens localStorage.removeItem(BEARER_TOKEN_KEY); + localStorage.removeItem(REFRESH_TOKEN_KEY); // Redirect to home page (which has login options) window.location.href = "/login"; @@ -66,6 +72,71 @@ export function clearBearerToken(): void { localStorage.removeItem(BEARER_TOKEN_KEY); } +/** + * Gets the refresh token from localStorage + */ +export function getRefreshToken(): string | null { + if (typeof window === "undefined") return null; + return localStorage.getItem(REFRESH_TOKEN_KEY); +} + +/** + * Sets the refresh token in localStorage + */ +export function setRefreshToken(token: string): void { + if (typeof window === "undefined") return; + localStorage.setItem(REFRESH_TOKEN_KEY, token); +} + +/** + * Clears the refresh token from localStorage + */ +export function clearRefreshToken(): void { + if (typeof window === "undefined") return; + localStorage.removeItem(REFRESH_TOKEN_KEY); +} + +/** + * Clears all auth tokens from localStorage + */ +export function clearAllTokens(): void { + clearBearerToken(); + clearRefreshToken(); +} + +/** + * Logout the current user by revoking the refresh token and clearing localStorage. + * Returns true if logout was successful (or tokens were cleared), false otherwise. + */ +export async function logout(): Promise { + const refreshToken = getRefreshToken(); + + // Call backend to revoke the refresh token + if (refreshToken) { + try { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const response = await fetch(`${backendUrl}/auth/jwt/revoke`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ refresh_token: refreshToken }), + }); + + if (!response.ok) { + console.warn("Failed to revoke refresh token:", response.status, await response.text()); + } + } catch (error) { + console.warn("Failed to revoke refresh token on server:", error); + // Continue to clear local tokens even if server call fails + } + } + + // Clear all tokens from localStorage + clearAllTokens(); + return true; +} + /** * Checks if the user is authenticated (has a token) */ @@ -106,14 +177,67 @@ export function getAuthHeaders(additionalHeaders?: Record): Reco } /** - * Authenticated fetch wrapper that handles 401 responses uniformly - * Automatically redirects to login on 401 and saves the current path + * Attempts to refresh the access token using the stored refresh token. + * Returns the new access token if successful, null otherwise. + * Exported for use by API services. + */ +export async function refreshAccessToken(): Promise { + // If already refreshing, wait for that request to complete + if (isRefreshing && refreshPromise) { + return refreshPromise; + } + + const currentRefreshToken = getRefreshToken(); + if (!currentRefreshToken) { + return null; + } + + isRefreshing = true; + refreshPromise = (async () => { + try { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const response = await fetch(`${backendUrl}/auth/jwt/refresh`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ refresh_token: currentRefreshToken }), + }); + + if (!response.ok) { + // Refresh failed, clear tokens + clearAllTokens(); + return null; + } + + const data = await response.json(); + if (data.access_token && data.refresh_token) { + setBearerToken(data.access_token); + setRefreshToken(data.refresh_token); + return data.access_token; + } + return null; + } catch { + return null; + } finally { + isRefreshing = false; + refreshPromise = null; + } + })(); + + return refreshPromise; +} + +/** + * Authenticated fetch wrapper that handles 401 responses uniformly. + * On 401, attempts to refresh the token and retry the request. + * If refresh fails, redirects to login and saves the current path. */ export async function authenticatedFetch( url: string, - options?: RequestInit & { skipAuthRedirect?: boolean } + options?: RequestInit & { skipAuthRedirect?: boolean; skipRefresh?: boolean } ): Promise { - const { skipAuthRedirect = false, ...fetchOptions } = options || {}; + const { skipAuthRedirect = false, skipRefresh = false, ...fetchOptions } = options || {}; const headers = getAuthHeaders(fetchOptions.headers as Record); @@ -124,6 +248,23 @@ export async function authenticatedFetch( // Handle 401 Unauthorized if (response.status === 401 && !skipAuthRedirect) { + // Try to refresh the token (unless skipRefresh is set to prevent infinite loops) + if (!skipRefresh) { + const newToken = await refreshAccessToken(); + if (newToken) { + // Retry the original request with the new token + const retryHeaders = { + ...(fetchOptions.headers as Record), + Authorization: `Bearer ${newToken}`, + }; + return fetch(url, { + ...fetchOptions, + headers: retryHeaders, + }); + } + } + + // Refresh failed or was skipped, redirect to login handleUnauthorized(); throw new Error("Unauthorized: Redirecting to login page"); } diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index 3020f6289..a1ef1f248 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -700,6 +700,7 @@ "dark": "Dark", "system": "System", "logout": "Logout", + "loggingOut": "Logging out...", "inbox": "Inbox", "search_inbox": "Search inbox", "mark_all_read": "Mark all as read", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 8c112da03..60a0d279f 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -685,6 +685,7 @@ "dark": "深色", "system": "系统", "logout": "退出登录", + "loggingOut": "正在退出...", "inbox": "收件箱", "search_inbox": "搜索收件箱", "mark_all_read": "全部标记为已读",