mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-02 22:01:05 +02:00
Merge pull request #784 from CREDO23/sur-137-bug-oauth-tokens-expire-too-quickly-connectors-and-login
[Fixes] Implement refresh token auth, connector token pre-validation, and logout improvements
This commit is contained in:
commit
459ffd2b78
22 changed files with 770 additions and 28 deletions
|
|
@ -32,6 +32,11 @@ ELECTRIC_DB_PASSWORD=electric_password
|
||||||
SCHEDULE_CHECKER_INTERVAL=5m
|
SCHEDULE_CHECKER_INTERVAL=5m
|
||||||
|
|
||||||
SECRET_KEY=SECRET
|
SECRET_KEY=SECRET
|
||||||
|
|
||||||
|
# JWT Token Lifetimes (optional, defaults shown)
|
||||||
|
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
|
||||||
|
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
|
||||||
|
|
||||||
NEXT_FRONTEND_URL=http://localhost:3000
|
NEXT_FRONTEND_URL=http://localhost:3000
|
||||||
|
|
||||||
# Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS)
|
# Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
"""Add refresh_tokens table for user session management
|
||||||
|
|
||||||
|
Revision ID: 92
|
||||||
|
Revises: 91
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
1. Create refresh_tokens table with columns:
|
||||||
|
- id (primary key)
|
||||||
|
- user_id (foreign key to user)
|
||||||
|
- token_hash (unique, indexed)
|
||||||
|
- expires_at (indexed)
|
||||||
|
- is_revoked
|
||||||
|
- family_id (indexed, for token rotation tracking)
|
||||||
|
- created_at, updated_at (timestamps)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "92"
|
||||||
|
down_revision: str | None = "91"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Create refresh_tokens table (idempotent)."""
|
||||||
|
# Check if table already exists
|
||||||
|
connection = op.get_bind()
|
||||||
|
result = connection.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
table_exists = result.scalar()
|
||||||
|
|
||||||
|
if not table_exists:
|
||||||
|
op.create_table(
|
||||||
|
"refresh_tokens",
|
||||||
|
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(256), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False),
|
||||||
|
sa.Column("family_id", UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["user.id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes if they don't exist
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Drop refresh_tokens table (idempotent)."""
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id")
|
||||||
|
op.execute("DROP TABLE IF EXISTS refresh_tokens")
|
||||||
|
|
@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import (
|
||||||
from app.config import config, initialize_llm_router
|
from app.config import config, initialize_llm_router
|
||||||
from app.db import User, create_db_and_tables, get_async_session
|
from app.db import User, create_db_and_tables, get_async_session
|
||||||
from app.routes import router as crud_router
|
from app.routes import router as crud_router
|
||||||
|
from app.routes.auth_routes import router as auth_router
|
||||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||||
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
||||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||||
|
|
@ -111,6 +112,9 @@ app.include_router(
|
||||||
tags=["users"],
|
tags=["users"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Include custom auth routes (refresh token, logout)
|
||||||
|
app.include_router(auth_router)
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -255,6 +255,14 @@ class Config:
|
||||||
# OAuth JWT
|
# OAuth JWT
|
||||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||||
|
|
||||||
|
# JWT Token Lifetimes
|
||||||
|
ACCESS_TOKEN_LIFETIME_SECONDS = int(
|
||||||
|
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
|
||||||
|
)
|
||||||
|
REFRESH_TOKEN_LIFETIME_SECONDS = int(
|
||||||
|
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
|
||||||
|
)
|
||||||
|
|
||||||
# ETL Service
|
# ETL Service
|
||||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,14 @@ class AirtableHistoryConnector:
|
||||||
|
|
||||||
config_data = connector.config.copy()
|
config_data = connector.config.copy()
|
||||||
|
|
||||||
|
# Check if access_token exists before processing
|
||||||
|
raw_access_token = config_data.get("access_token")
|
||||||
|
if not raw_access_token:
|
||||||
|
raise ValueError(
|
||||||
|
"Airtable access token not found. "
|
||||||
|
"Please reconnect your Airtable account."
|
||||||
|
)
|
||||||
|
|
||||||
# Decrypt credentials if they are encrypted
|
# Decrypt credentials if they are encrypted
|
||||||
token_encrypted = config_data.get("_token_encrypted", False)
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
if token_encrypted and config.SECRET_KEY:
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
|
@ -98,6 +106,14 @@ class AirtableHistoryConnector:
|
||||||
f"Failed to decrypt Airtable credentials: {e!s}"
|
f"Failed to decrypt Airtable credentials: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
# Final validation after decryption
|
||||||
|
final_token = config_data.get("access_token")
|
||||||
|
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||||
|
raise ValueError(
|
||||||
|
"Airtable access token is invalid or empty. "
|
||||||
|
"Please reconnect your Airtable account."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._credentials = AirtableAuthCredentialsBase.from_dict(config_data)
|
self._credentials = AirtableAuthCredentialsBase.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,14 @@ class ConfluenceHistoryConnector:
|
||||||
|
|
||||||
if is_oauth:
|
if is_oauth:
|
||||||
# OAuth 2.0 authentication
|
# OAuth 2.0 authentication
|
||||||
|
# Check if access_token exists before processing
|
||||||
|
raw_access_token = config_data.get("access_token")
|
||||||
|
if not raw_access_token:
|
||||||
|
raise ValueError(
|
||||||
|
"Confluence access token not found. "
|
||||||
|
"Please reconnect your Confluence account."
|
||||||
|
)
|
||||||
|
|
||||||
# Decrypt credentials if they are encrypted
|
# Decrypt credentials if they are encrypted
|
||||||
token_encrypted = config_data.get("_token_encrypted", False)
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
if token_encrypted and config.SECRET_KEY:
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
|
@ -118,6 +126,14 @@ class ConfluenceHistoryConnector:
|
||||||
f"Failed to decrypt Confluence credentials: {e!s}"
|
f"Failed to decrypt Confluence credentials: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
# Final validation after decryption
|
||||||
|
final_token = config_data.get("access_token")
|
||||||
|
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||||
|
raise ValueError(
|
||||||
|
"Confluence access token is invalid or empty. "
|
||||||
|
"Please reconnect your Confluence account."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._credentials = AtlassianAuthCredentialsBase.from_dict(
|
self._credentials = AtlassianAuthCredentialsBase.from_dict(
|
||||||
config_data
|
config_data
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,14 @@ class JiraHistoryConnector:
|
||||||
|
|
||||||
if is_oauth:
|
if is_oauth:
|
||||||
# OAuth 2.0 authentication
|
# OAuth 2.0 authentication
|
||||||
|
# Check if access_token exists before processing
|
||||||
|
raw_access_token = config_data.get("access_token")
|
||||||
|
if not raw_access_token:
|
||||||
|
raise ValueError(
|
||||||
|
"Jira access token not found. "
|
||||||
|
"Please reconnect your Jira account."
|
||||||
|
)
|
||||||
|
|
||||||
if not config.SECRET_KEY:
|
if not config.SECRET_KEY:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"SECRET_KEY not configured but tokens are marked as encrypted"
|
"SECRET_KEY not configured but tokens are marked as encrypted"
|
||||||
|
|
@ -119,6 +127,14 @@ class JiraHistoryConnector:
|
||||||
f"Failed to decrypt Jira credentials: {e!s}"
|
f"Failed to decrypt Jira credentials: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
# Final validation after decryption
|
||||||
|
final_token = config_data.get("access_token")
|
||||||
|
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||||
|
raise ValueError(
|
||||||
|
"Jira access token is invalid or empty. "
|
||||||
|
"Please reconnect your Jira account."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._credentials = AtlassianAuthCredentialsBase.from_dict(
|
self._credentials = AtlassianAuthCredentialsBase.from_dict(
|
||||||
config_data
|
config_data
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,14 @@ class LinearConnector:
|
||||||
|
|
||||||
config_data = connector.config.copy()
|
config_data = connector.config.copy()
|
||||||
|
|
||||||
|
# Check if access_token exists before processing
|
||||||
|
raw_access_token = config_data.get("access_token")
|
||||||
|
if not raw_access_token:
|
||||||
|
raise ValueError(
|
||||||
|
"Linear access token not found. "
|
||||||
|
"Please reconnect your Linear account."
|
||||||
|
)
|
||||||
|
|
||||||
# Decrypt credentials if they are encrypted
|
# Decrypt credentials if they are encrypted
|
||||||
token_encrypted = config_data.get("_token_encrypted", False)
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
if token_encrypted and config.SECRET_KEY:
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
|
@ -143,6 +151,14 @@ class LinearConnector:
|
||||||
f"Failed to decrypt Linear credentials: {e!s}"
|
f"Failed to decrypt Linear credentials: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
# Final validation after decryption
|
||||||
|
final_token = config_data.get("access_token")
|
||||||
|
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||||
|
raise ValueError(
|
||||||
|
"Linear access token is invalid or empty. "
|
||||||
|
"Please reconnect your Linear account."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
|
self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1361,6 +1361,13 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
display_name = Column(String, nullable=True)
|
display_name = Column(String, nullable=True)
|
||||||
avatar_url = Column(String, nullable=True)
|
avatar_url = Column(String, nullable=True)
|
||||||
|
|
||||||
|
# Refresh tokens for this user
|
||||||
|
refresh_tokens = relationship(
|
||||||
|
"RefreshToken",
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||||
|
|
@ -1426,6 +1433,43 @@ else:
|
||||||
display_name = Column(String, nullable=True)
|
display_name = Column(String, nullable=True)
|
||||||
avatar_url = Column(String, nullable=True)
|
avatar_url = Column(String, nullable=True)
|
||||||
|
|
||||||
|
# Refresh tokens for this user
|
||||||
|
refresh_tokens = relationship(
|
||||||
|
"RefreshToken",
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Stores refresh tokens for user session management.
|
||||||
|
Each row represents one device/session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("user.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
user = relationship("User", back_populates="refresh_tokens")
|
||||||
|
token_hash = Column(String(256), unique=True, nullable=False, index=True)
|
||||||
|
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
|
||||||
|
is_revoked = Column(Boolean, default=False, nullable=False)
|
||||||
|
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return datetime.now(UTC) >= self.expires_at
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
return not self.is_expired and not self.is_revoked
|
||||||
|
|
||||||
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
engine = create_async_engine(DATABASE_URL)
|
||||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
|
||||||
93
surfsense_backend/app/routes/auth_routes.py
Normal file
93
surfsense_backend/app/routes/auth_routes.py
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
"""Authentication routes for refresh token management."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.db import User, async_session_maker
|
||||||
|
from app.schemas.auth import (
|
||||||
|
LogoutAllResponse,
|
||||||
|
LogoutRequest,
|
||||||
|
LogoutResponse,
|
||||||
|
RefreshTokenRequest,
|
||||||
|
RefreshTokenResponse,
|
||||||
|
)
|
||||||
|
from app.users import current_active_user, get_jwt_strategy
|
||||||
|
from app.utils.refresh_tokens import (
|
||||||
|
revoke_all_user_tokens,
|
||||||
|
revoke_refresh_token,
|
||||||
|
rotate_refresh_token,
|
||||||
|
validate_refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||||
|
async def refresh_access_token(request: RefreshTokenRequest):
|
||||||
|
"""
|
||||||
|
Exchange a valid refresh token for a new access token and refresh token.
|
||||||
|
Implements token rotation for security.
|
||||||
|
"""
|
||||||
|
token_record = await validate_refresh_token(request.refresh_token)
|
||||||
|
|
||||||
|
if not token_record:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user from token record
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.id == token_record.user_id)
|
||||||
|
)
|
||||||
|
user = result.scalars().first()
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate new access token
|
||||||
|
strategy = get_jwt_strategy()
|
||||||
|
access_token = await strategy.write_token(user)
|
||||||
|
|
||||||
|
# Rotate refresh token
|
||||||
|
new_refresh_token = await rotate_refresh_token(token_record)
|
||||||
|
|
||||||
|
logger.info(f"Refreshed token for user {user.id}")
|
||||||
|
|
||||||
|
return RefreshTokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=new_refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/revoke", response_model=LogoutResponse)
|
||||||
|
async def revoke_token(request: LogoutRequest):
|
||||||
|
"""
|
||||||
|
Logout current device by revoking the provided refresh token.
|
||||||
|
Does not require authentication - just the refresh token.
|
||||||
|
"""
|
||||||
|
revoked = await revoke_refresh_token(request.refresh_token)
|
||||||
|
if revoked:
|
||||||
|
logger.info("User logged out from current device - token revoked")
|
||||||
|
else:
|
||||||
|
logger.warning("Logout called but no matching token found to revoke")
|
||||||
|
return LogoutResponse()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout-all", response_model=LogoutAllResponse)
|
||||||
|
async def logout_all_devices(user: User = Depends(current_active_user)):
|
||||||
|
"""
|
||||||
|
Logout from all devices by revoking all refresh tokens for the user.
|
||||||
|
Requires valid access token.
|
||||||
|
"""
|
||||||
|
await revoke_all_user_tokens(user.id)
|
||||||
|
logger.info(f"User {user.id} logged out from all devices")
|
||||||
|
return LogoutAllResponse()
|
||||||
|
|
@ -1,3 +1,10 @@
|
||||||
|
from .auth import (
|
||||||
|
LogoutAllResponse,
|
||||||
|
LogoutRequest,
|
||||||
|
LogoutResponse,
|
||||||
|
RefreshTokenRequest,
|
||||||
|
RefreshTokenResponse,
|
||||||
|
)
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||||
from .documents import (
|
from .documents import (
|
||||||
|
|
@ -117,6 +124,10 @@ __all__ = [
|
||||||
"LogFilter",
|
"LogFilter",
|
||||||
"LogRead",
|
"LogRead",
|
||||||
"LogUpdate",
|
"LogUpdate",
|
||||||
|
# Auth schemas
|
||||||
|
"LogoutAllResponse",
|
||||||
|
"LogoutRequest",
|
||||||
|
"LogoutResponse",
|
||||||
# Search source connector schemas
|
# Search source connector schemas
|
||||||
"MCPConnectorCreate",
|
"MCPConnectorCreate",
|
||||||
"MCPConnectorRead",
|
"MCPConnectorRead",
|
||||||
|
|
@ -146,6 +157,8 @@ __all__ = [
|
||||||
"PodcastCreate",
|
"PodcastCreate",
|
||||||
"PodcastRead",
|
"PodcastRead",
|
||||||
"PodcastUpdate",
|
"PodcastUpdate",
|
||||||
|
"RefreshTokenRequest",
|
||||||
|
"RefreshTokenResponse",
|
||||||
"RoleCreate",
|
"RoleCreate",
|
||||||
"RoleRead",
|
"RoleRead",
|
||||||
"RoleUpdate",
|
"RoleUpdate",
|
||||||
|
|
|
||||||
35
surfsense_backend/app/schemas/auth.py
Normal file
35
surfsense_backend/app/schemas/auth.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""Authentication schemas for refresh token endpoints."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenRequest(BaseModel):
|
||||||
|
"""Request body for token refresh endpoint."""
|
||||||
|
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenResponse(BaseModel):
|
||||||
|
"""Response from token refresh endpoint."""
|
||||||
|
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class LogoutRequest(BaseModel):
|
||||||
|
"""Request body for logout endpoint (current device)."""
|
||||||
|
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class LogoutResponse(BaseModel):
|
||||||
|
"""Response from logout endpoint (current device)."""
|
||||||
|
|
||||||
|
detail: str = "Successfully logged out"
|
||||||
|
|
||||||
|
|
||||||
|
class LogoutAllResponse(BaseModel):
|
||||||
|
"""Response from logout all devices endpoint."""
|
||||||
|
|
||||||
|
detail: str = "Successfully logged out from all devices"
|
||||||
|
|
@ -23,17 +23,20 @@ from app.db import (
|
||||||
get_default_roles_config,
|
get_default_roles_config,
|
||||||
get_user_db,
|
get_user_db,
|
||||||
)
|
)
|
||||||
|
from app.utils.refresh_tokens import create_refresh_token
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BearerResponse(BaseModel):
|
class BearerResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
SECRET = config.SECRET_KEY
|
SECRET = config.SECRET_KEY
|
||||||
|
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
from httpx_oauth.clients.google import GoogleOAuth2
|
from httpx_oauth.clients.google import GoogleOAuth2
|
||||||
|
|
||||||
|
|
@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
||||||
|
|
||||||
|
|
||||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||||
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
|
return JWTStrategy(
|
||||||
|
secret=SECRET,
|
||||||
|
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
|
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
|
||||||
|
|
@ -209,9 +215,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||||
# BEARER AUTH CODE.
|
# BEARER AUTH CODE.
|
||||||
class CustomBearerTransport(BearerTransport):
|
class CustomBearerTransport(BearerTransport):
|
||||||
async def get_login_response(self, token: str) -> Response:
|
async def get_login_response(self, token: str) -> Response:
|
||||||
bearer_response = BearerResponse(access_token=token, token_type="bearer")
|
import jwt
|
||||||
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
|
|
||||||
|
# Decode JWT to get user_id for refresh token creation
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
|
||||||
|
user_id = uuid.UUID(payload.get("sub"))
|
||||||
|
refresh_token = await create_refresh_token(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create refresh token: {e}")
|
||||||
|
# Fall back to response without refresh token
|
||||||
|
refresh_token = ""
|
||||||
|
|
||||||
|
bearer_response = BearerResponse(
|
||||||
|
access_token=token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_type="bearer",
|
||||||
|
)
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
redirect_url = (
|
||||||
|
f"{config.NEXT_FRONTEND_URL}/auth/callback"
|
||||||
|
f"?token={bearer_response.access_token}"
|
||||||
|
f"&refresh_token={bearer_response.refresh_token}"
|
||||||
|
)
|
||||||
return RedirectResponse(redirect_url, status_code=302)
|
return RedirectResponse(redirect_url, status_code=302)
|
||||||
else:
|
else:
|
||||||
return JSONResponse(bearer_response.model_dump())
|
return JSONResponse(bearer_response.model_dump())
|
||||||
|
|
|
||||||
153
surfsense_backend/app/utils/refresh_tokens.py
Normal file
153
surfsense_backend/app/utils/refresh_tokens.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""Utilities for managing refresh tokens."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import RefreshToken, async_session_maker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_refresh_token() -> str:
|
||||||
|
"""Generate a cryptographically secure refresh token."""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_token(token: str) -> str:
|
||||||
|
"""Hash a token for secure storage."""
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_refresh_token(
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
family_id: uuid.UUID | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create and store a new refresh token for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
family_id: Optional family ID for token rotation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The plaintext refresh token
|
||||||
|
"""
|
||||||
|
token = generate_refresh_token()
|
||||||
|
token_hash = hash_token(token)
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
|
||||||
|
)
|
||||||
|
|
||||||
|
if family_id is None:
|
||||||
|
family_id = uuid.uuid4()
|
||||||
|
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
refresh_token = RefreshToken(
|
||||||
|
user_id=user_id,
|
||||||
|
token_hash=token_hash,
|
||||||
|
expires_at=expires_at,
|
||||||
|
family_id=family_id,
|
||||||
|
)
|
||||||
|
session.add(refresh_token)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_refresh_token(token: str) -> RefreshToken | None:
|
||||||
|
"""
|
||||||
|
Validate a refresh token. Handles reuse detection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The plaintext refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RefreshToken if valid, None otherwise
|
||||||
|
"""
|
||||||
|
token_hash = hash_token(token)
|
||||||
|
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
refresh_token = result.scalars().first()
|
||||||
|
|
||||||
|
if not refresh_token:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Reuse detection: revoked token used while family has active tokens
|
||||||
|
if refresh_token.is_revoked:
|
||||||
|
active = await session.execute(
|
||||||
|
select(RefreshToken).where(
|
||||||
|
RefreshToken.family_id == refresh_token.family_id,
|
||||||
|
RefreshToken.is_revoked == False, # noqa: E712
|
||||||
|
RefreshToken.expires_at > datetime.now(UTC),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if active.scalars().first():
|
||||||
|
# Revoke entire family
|
||||||
|
await session.execute(
|
||||||
|
update(RefreshToken)
|
||||||
|
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||||
|
.values(is_revoked=True)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if refresh_token.is_expired:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return refresh_token
|
||||||
|
|
||||||
|
|
||||||
|
async def rotate_refresh_token(old_token: RefreshToken) -> str:
|
||||||
|
"""Revoke old token and create new one in same family."""
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
await session.execute(
|
||||||
|
update(RefreshToken)
|
||||||
|
.where(RefreshToken.id == old_token.id)
|
||||||
|
.values(is_revoked=True)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return await create_refresh_token(old_token.user_id, old_token.family_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_refresh_token(token: str) -> bool:
|
||||||
|
"""
|
||||||
|
Revoke a single refresh token by its plaintext value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The plaintext refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token was found and revoked, False otherwise
|
||||||
|
"""
|
||||||
|
token_hash = hash_token(token)
|
||||||
|
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
update(RefreshToken)
|
||||||
|
.where(RefreshToken.token_hash == token_hash)
|
||||||
|
.values(is_revoked=True)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
|
||||||
|
"""Revoke all refresh tokens for a user (logout all devices)."""
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
await session.execute(
|
||||||
|
update(RefreshToken)
|
||||||
|
.where(RefreshToken.user_id == user_id)
|
||||||
|
.values(is_revoked=True)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import { useSearchParams } from "next/navigation";
|
import { useSearchParams } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||||
import { getAndClearRedirectPath, setBearerToken } from "@/lib/auth-utils";
|
import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils";
|
||||||
import { trackLoginSuccess } from "@/lib/posthog/events";
|
import { trackLoginSuccess } from "@/lib/posthog/events";
|
||||||
|
|
||||||
interface TokenHandlerProps {
|
interface TokenHandlerProps {
|
||||||
|
|
@ -35,8 +35,9 @@ const TokenHandler = ({
|
||||||
// Only run on client-side
|
// Only run on client-side
|
||||||
if (typeof window === "undefined") return;
|
if (typeof window === "undefined") return;
|
||||||
|
|
||||||
// Get token from URL parameters
|
// Get tokens from URL parameters
|
||||||
const token = searchParams.get(tokenParamName);
|
const token = searchParams.get(tokenParamName);
|
||||||
|
const refreshToken = searchParams.get("refresh_token");
|
||||||
|
|
||||||
if (token) {
|
if (token) {
|
||||||
try {
|
try {
|
||||||
|
|
@ -50,10 +51,15 @@ const TokenHandler = ({
|
||||||
// Clear the flag for future logins
|
// Clear the flag for future logins
|
||||||
sessionStorage.removeItem("login_success_tracked");
|
sessionStorage.removeItem("login_success_tracked");
|
||||||
|
|
||||||
// Store token in localStorage using both methods for compatibility
|
// Store access token in localStorage using both methods for compatibility
|
||||||
localStorage.setItem(storageKey, token);
|
localStorage.setItem(storageKey, token);
|
||||||
setBearerToken(token);
|
setBearerToken(token);
|
||||||
|
|
||||||
|
// Store refresh token if provided
|
||||||
|
if (refreshToken) {
|
||||||
|
setRefreshToken(refreshToken);
|
||||||
|
}
|
||||||
|
|
||||||
// Check if there's a saved redirect path from before the auth flow
|
// Check if there's a saved redirect path from before the auth flow
|
||||||
const savedRedirectPath = getAndClearRedirectPath();
|
const savedRedirectPath = getAndClearRedirectPath();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { BadgeCheck, LogOut } from "lucide-react";
|
import { BadgeCheck, Loader2, LogOut } from "lucide-react";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
|
import { useState } from "react";
|
||||||
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
|
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import {
|
import {
|
||||||
|
|
@ -13,6 +14,7 @@ import {
|
||||||
DropdownMenuSeparator,
|
DropdownMenuSeparator,
|
||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
import { logout } from "@/lib/auth-utils";
|
||||||
import { cleanupElectric } from "@/lib/electric/client";
|
import { cleanupElectric } from "@/lib/electric/client";
|
||||||
import { resetUser, trackLogout } from "@/lib/posthog/events";
|
import { resetUser, trackLogout } from "@/lib/posthog/events";
|
||||||
|
|
||||||
|
|
@ -26,8 +28,11 @@ export function UserDropdown({
|
||||||
};
|
};
|
||||||
}) {
|
}) {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
|
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||||
|
|
||||||
const handleLogout = async () => {
|
const handleLogout = async () => {
|
||||||
|
if (isLoggingOut) return;
|
||||||
|
setIsLoggingOut(true);
|
||||||
try {
|
try {
|
||||||
// Track logout event and reset PostHog identity
|
// Track logout event and reset PostHog identity
|
||||||
trackLogout();
|
trackLogout();
|
||||||
|
|
@ -41,15 +46,17 @@ export function UserDropdown({
|
||||||
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
|
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Revoke refresh token on server and clear all tokens from localStorage
|
||||||
|
await logout();
|
||||||
|
|
||||||
if (typeof window !== "undefined") {
|
if (typeof window !== "undefined") {
|
||||||
localStorage.removeItem("surfsense_bearer_token");
|
|
||||||
window.location.href = "/";
|
window.location.href = "/";
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error during logout:", error);
|
console.error("Error during logout:", error);
|
||||||
// Optionally, provide user feedback
|
// Even if there's an error, try to clear tokens and redirect
|
||||||
|
await logout();
|
||||||
if (typeof window !== "undefined") {
|
if (typeof window !== "undefined") {
|
||||||
localStorage.removeItem("surfsense_bearer_token");
|
|
||||||
window.location.href = "/";
|
window.location.href = "/";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -85,9 +92,17 @@ export function UserDropdown({
|
||||||
</DropdownMenuItem>
|
</DropdownMenuItem>
|
||||||
</DropdownMenuGroup>
|
</DropdownMenuGroup>
|
||||||
<DropdownMenuSeparator />
|
<DropdownMenuSeparator />
|
||||||
<DropdownMenuItem onClick={handleLogout} className="text-xs md:text-sm">
|
<DropdownMenuItem
|
||||||
<LogOut className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4" />
|
onClick={handleLogout}
|
||||||
Log out
|
className="text-xs md:text-sm"
|
||||||
|
disabled={isLoggingOut}
|
||||||
|
>
|
||||||
|
{isLoggingOut ? (
|
||||||
|
<Loader2 className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<LogOut className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4" />
|
||||||
|
)}
|
||||||
|
{isLoggingOut ? "Logging out..." : "Log out"}
|
||||||
</DropdownMenuItem>
|
</DropdownMenuItem>
|
||||||
</DropdownMenuContent>
|
</DropdownMenuContent>
|
||||||
</DropdownMenu>
|
</DropdownMenu>
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ import { isPageLimitExceededMetadata } from "@/contracts/types/inbox.types";
|
||||||
import { useInbox } from "@/hooks/use-inbox";
|
import { useInbox } from "@/hooks/use-inbox";
|
||||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||||
import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence";
|
import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence";
|
||||||
|
import { logout } from "@/lib/auth-utils";
|
||||||
import { cleanupElectric } from "@/lib/electric/client";
|
import { cleanupElectric } from "@/lib/electric/client";
|
||||||
import { resetUser, trackLogout } from "@/lib/posthog/events";
|
import { resetUser, trackLogout } from "@/lib/posthog/events";
|
||||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||||
|
|
@ -474,12 +475,15 @@ export function LayoutDataProvider({
|
||||||
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
|
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Revoke refresh token on server and clear all tokens from localStorage
|
||||||
|
await logout();
|
||||||
|
|
||||||
if (typeof window !== "undefined") {
|
if (typeof window !== "undefined") {
|
||||||
localStorage.removeItem("surfsense_bearer_token");
|
|
||||||
router.push("/");
|
router.push("/");
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error during logout:", error);
|
console.error("Error during logout:", error);
|
||||||
|
await logout();
|
||||||
router.push("/");
|
router.push("/");
|
||||||
}
|
}
|
||||||
}, [router]);
|
}, [router]);
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { Check, ChevronUp, Languages, Laptop, LogOut, Moon, Settings, Sun } from "lucide-react";
|
import { Check, ChevronUp, Languages, Laptop, Loader2, LogOut, Moon, Settings, Sun } from "lucide-react";
|
||||||
import { useTranslations } from "next-intl";
|
import { useTranslations } from "next-intl";
|
||||||
|
import { useState } from "react";
|
||||||
import {
|
import {
|
||||||
DropdownMenu,
|
DropdownMenu,
|
||||||
DropdownMenuContent,
|
DropdownMenuContent,
|
||||||
|
|
@ -124,6 +125,7 @@ export function SidebarUserProfile({
|
||||||
}: SidebarUserProfileProps) {
|
}: SidebarUserProfileProps) {
|
||||||
const t = useTranslations("sidebar");
|
const t = useTranslations("sidebar");
|
||||||
const { locale, setLocale } = useLocaleContext();
|
const { locale, setLocale } = useLocaleContext();
|
||||||
|
const [isLoggingOut, setIsLoggingOut] = useState(false);
|
||||||
const bgColor = stringToColor(user.email);
|
const bgColor = stringToColor(user.email);
|
||||||
const initials = getInitials(user.email);
|
const initials = getInitials(user.email);
|
||||||
const displayName = user.name || user.email.split("@")[0];
|
const displayName = user.name || user.email.split("@")[0];
|
||||||
|
|
@ -136,6 +138,16 @@ export function SidebarUserProfile({
|
||||||
setTheme?.(newTheme);
|
setTheme?.(newTheme);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleLogout = async () => {
|
||||||
|
if (isLoggingOut || !onLogout) return;
|
||||||
|
setIsLoggingOut(true);
|
||||||
|
try {
|
||||||
|
await onLogout();
|
||||||
|
} finally {
|
||||||
|
setIsLoggingOut(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Collapsed view - just show avatar with dropdown
|
// Collapsed view - just show avatar with dropdown
|
||||||
if (isCollapsed) {
|
if (isCollapsed) {
|
||||||
return (
|
return (
|
||||||
|
|
@ -242,9 +254,13 @@ export function SidebarUserProfile({
|
||||||
|
|
||||||
<DropdownMenuSeparator />
|
<DropdownMenuSeparator />
|
||||||
|
|
||||||
<DropdownMenuItem onClick={onLogout}>
|
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||||
<LogOut className="mr-2 h-4 w-4" />
|
{isLoggingOut ? (
|
||||||
{t("logout")}
|
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<LogOut className="mr-2 h-4 w-4" />
|
||||||
|
)}
|
||||||
|
{isLoggingOut ? t("loggingOut") : t("logout")}
|
||||||
</DropdownMenuItem>
|
</DropdownMenuItem>
|
||||||
</DropdownMenuContent>
|
</DropdownMenuContent>
|
||||||
</DropdownMenu>
|
</DropdownMenu>
|
||||||
|
|
@ -360,9 +376,13 @@ export function SidebarUserProfile({
|
||||||
|
|
||||||
<DropdownMenuSeparator />
|
<DropdownMenuSeparator />
|
||||||
|
|
||||||
<DropdownMenuItem onClick={onLogout}>
|
<DropdownMenuItem onClick={handleLogout} disabled={isLoggingOut}>
|
||||||
<LogOut className="mr-2 h-4 w-4" />
|
{isLoggingOut ? (
|
||||||
{t("logout")}
|
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<LogOut className="mr-2 h-4 w-4" />
|
||||||
|
)}
|
||||||
|
{isLoggingOut ? t("loggingOut") : t("logout")}
|
||||||
</DropdownMenuItem>
|
</DropdownMenuItem>
|
||||||
</DropdownMenuContent>
|
</DropdownMenuContent>
|
||||||
</DropdownMenu>
|
</DropdownMenu>
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import type { ZodType } from "zod";
|
import type { ZodType } from "zod";
|
||||||
import { getBearerToken, handleUnauthorized } from "../auth-utils";
|
import { getBearerToken, handleUnauthorized, refreshAccessToken } from "../auth-utils";
|
||||||
import { AppError, AuthenticationError, AuthorizationError, NotFoundError } from "../error";
|
import { AppError, AuthenticationError, AuthorizationError, NotFoundError } from "../error";
|
||||||
|
|
||||||
enum ResponseType {
|
enum ResponseType {
|
||||||
|
|
@ -17,6 +17,7 @@ export type RequestOptions = {
|
||||||
signal?: AbortSignal;
|
signal?: AbortSignal;
|
||||||
body?: any;
|
body?: any;
|
||||||
responseType?: ResponseType;
|
responseType?: ResponseType;
|
||||||
|
_isRetry?: boolean; // Internal flag to prevent infinite retry loops
|
||||||
// Add more options as needed
|
// Add more options as needed
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -135,8 +136,23 @@ class BaseApiService {
|
||||||
throw new AppError("Failed to parse response", response.status, response.statusText);
|
throw new AppError("Failed to parse response", response.status, response.statusText);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle 401 first before other error handling - ensures token is cleared and user redirected
|
// Handle 401 - try to refresh token first (only once)
|
||||||
if (response.status === 401) {
|
if (response.status === 401) {
|
||||||
|
if (!options?._isRetry) {
|
||||||
|
const newToken = await refreshAccessToken();
|
||||||
|
if (newToken) {
|
||||||
|
// Retry the request with the new token
|
||||||
|
return this.request(url, responseSchema, {
|
||||||
|
...mergedOptions,
|
||||||
|
headers: {
|
||||||
|
...mergedOptions.headers,
|
||||||
|
Authorization: `Bearer ${newToken}`,
|
||||||
|
},
|
||||||
|
_isRetry: true,
|
||||||
|
} as RequestOptions & { responseType?: R });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Refresh failed or retry failed, redirect to login
|
||||||
handleUnauthorized();
|
handleUnauthorized();
|
||||||
throw new AuthenticationError(
|
throw new AuthenticationError(
|
||||||
typeof data === "object" && "detail" in data
|
typeof data === "object" && "detail" in data
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,11 @@
|
||||||
|
|
||||||
const REDIRECT_PATH_KEY = "surfsense_redirect_path";
|
const REDIRECT_PATH_KEY = "surfsense_redirect_path";
|
||||||
const BEARER_TOKEN_KEY = "surfsense_bearer_token";
|
const BEARER_TOKEN_KEY = "surfsense_bearer_token";
|
||||||
|
const REFRESH_TOKEN_KEY = "surfsense_refresh_token";
|
||||||
|
|
||||||
|
// Flag to prevent multiple simultaneous refresh attempts
|
||||||
|
let isRefreshing = false;
|
||||||
|
let refreshPromise: Promise<string | null> | null = null;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Saves the current path and redirects to login page
|
* Saves the current path and redirects to login page
|
||||||
|
|
@ -21,8 +26,9 @@ export function handleUnauthorized(): void {
|
||||||
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear the token
|
// Clear both tokens
|
||||||
localStorage.removeItem(BEARER_TOKEN_KEY);
|
localStorage.removeItem(BEARER_TOKEN_KEY);
|
||||||
|
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||||
|
|
||||||
// Redirect to home page (which has login options)
|
// Redirect to home page (which has login options)
|
||||||
window.location.href = "/login";
|
window.location.href = "/login";
|
||||||
|
|
@ -66,6 +72,71 @@ export function clearBearerToken(): void {
|
||||||
localStorage.removeItem(BEARER_TOKEN_KEY);
|
localStorage.removeItem(BEARER_TOKEN_KEY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the refresh token from localStorage
|
||||||
|
*/
|
||||||
|
export function getRefreshToken(): string | null {
|
||||||
|
if (typeof window === "undefined") return null;
|
||||||
|
return localStorage.getItem(REFRESH_TOKEN_KEY);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the refresh token in localStorage
|
||||||
|
*/
|
||||||
|
export function setRefreshToken(token: string): void {
|
||||||
|
if (typeof window === "undefined") return;
|
||||||
|
localStorage.setItem(REFRESH_TOKEN_KEY, token);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clears the refresh token from localStorage
|
||||||
|
*/
|
||||||
|
export function clearRefreshToken(): void {
|
||||||
|
if (typeof window === "undefined") return;
|
||||||
|
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clears all auth tokens from localStorage
|
||||||
|
*/
|
||||||
|
export function clearAllTokens(): void {
|
||||||
|
clearBearerToken();
|
||||||
|
clearRefreshToken();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Logout the current user by revoking the refresh token and clearing localStorage.
|
||||||
|
* Returns true if logout was successful (or tokens were cleared), false otherwise.
|
||||||
|
*/
|
||||||
|
export async function logout(): Promise<boolean> {
|
||||||
|
const refreshToken = getRefreshToken();
|
||||||
|
|
||||||
|
// Call backend to revoke the refresh token
|
||||||
|
if (refreshToken) {
|
||||||
|
try {
|
||||||
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
const response = await fetch(`${backendUrl}/auth/jwt/revoke`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ refresh_token: refreshToken }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
console.warn("Failed to revoke refresh token:", response.status, await response.text());
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.warn("Failed to revoke refresh token on server:", error);
|
||||||
|
// Continue to clear local tokens even if server call fails
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear all tokens from localStorage
|
||||||
|
clearAllTokens();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if the user is authenticated (has a token)
|
* Checks if the user is authenticated (has a token)
|
||||||
*/
|
*/
|
||||||
|
|
@ -106,14 +177,67 @@ export function getAuthHeaders(additionalHeaders?: Record<string, string>): Reco
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Authenticated fetch wrapper that handles 401 responses uniformly
|
* Attempts to refresh the access token using the stored refresh token.
|
||||||
* Automatically redirects to login on 401 and saves the current path
|
* Returns the new access token if successful, null otherwise.
|
||||||
|
* Exported for use by API services.
|
||||||
|
*/
|
||||||
|
export async function refreshAccessToken(): Promise<string | null> {
|
||||||
|
// If already refreshing, wait for that request to complete
|
||||||
|
if (isRefreshing && refreshPromise) {
|
||||||
|
return refreshPromise;
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentRefreshToken = getRefreshToken();
|
||||||
|
if (!currentRefreshToken) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
isRefreshing = true;
|
||||||
|
refreshPromise = (async () => {
|
||||||
|
try {
|
||||||
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
const response = await fetch(`${backendUrl}/auth/jwt/refresh`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ refresh_token: currentRefreshToken }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
// Refresh failed, clear tokens
|
||||||
|
clearAllTokens();
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
if (data.access_token && data.refresh_token) {
|
||||||
|
setBearerToken(data.access_token);
|
||||||
|
setRefreshToken(data.refresh_token);
|
||||||
|
return data.access_token;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
} finally {
|
||||||
|
isRefreshing = false;
|
||||||
|
refreshPromise = null;
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
return refreshPromise;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Authenticated fetch wrapper that handles 401 responses uniformly.
|
||||||
|
* On 401, attempts to refresh the token and retry the request.
|
||||||
|
* If refresh fails, redirects to login and saves the current path.
|
||||||
*/
|
*/
|
||||||
export async function authenticatedFetch(
|
export async function authenticatedFetch(
|
||||||
url: string,
|
url: string,
|
||||||
options?: RequestInit & { skipAuthRedirect?: boolean }
|
options?: RequestInit & { skipAuthRedirect?: boolean; skipRefresh?: boolean }
|
||||||
): Promise<Response> {
|
): Promise<Response> {
|
||||||
const { skipAuthRedirect = false, ...fetchOptions } = options || {};
|
const { skipAuthRedirect = false, skipRefresh = false, ...fetchOptions } = options || {};
|
||||||
|
|
||||||
const headers = getAuthHeaders(fetchOptions.headers as Record<string, string>);
|
const headers = getAuthHeaders(fetchOptions.headers as Record<string, string>);
|
||||||
|
|
||||||
|
|
@ -124,6 +248,23 @@ export async function authenticatedFetch(
|
||||||
|
|
||||||
// Handle 401 Unauthorized
|
// Handle 401 Unauthorized
|
||||||
if (response.status === 401 && !skipAuthRedirect) {
|
if (response.status === 401 && !skipAuthRedirect) {
|
||||||
|
// Try to refresh the token (unless skipRefresh is set to prevent infinite loops)
|
||||||
|
if (!skipRefresh) {
|
||||||
|
const newToken = await refreshAccessToken();
|
||||||
|
if (newToken) {
|
||||||
|
// Retry the original request with the new token
|
||||||
|
const retryHeaders = {
|
||||||
|
...(fetchOptions.headers as Record<string, string>),
|
||||||
|
Authorization: `Bearer ${newToken}`,
|
||||||
|
};
|
||||||
|
return fetch(url, {
|
||||||
|
...fetchOptions,
|
||||||
|
headers: retryHeaders,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh failed or was skipped, redirect to login
|
||||||
handleUnauthorized();
|
handleUnauthorized();
|
||||||
throw new Error("Unauthorized: Redirecting to login page");
|
throw new Error("Unauthorized: Redirecting to login page");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -700,6 +700,7 @@
|
||||||
"dark": "Dark",
|
"dark": "Dark",
|
||||||
"system": "System",
|
"system": "System",
|
||||||
"logout": "Logout",
|
"logout": "Logout",
|
||||||
|
"loggingOut": "Logging out...",
|
||||||
"inbox": "Inbox",
|
"inbox": "Inbox",
|
||||||
"search_inbox": "Search inbox",
|
"search_inbox": "Search inbox",
|
||||||
"mark_all_read": "Mark all as read",
|
"mark_all_read": "Mark all as read",
|
||||||
|
|
|
||||||
|
|
@ -685,6 +685,7 @@
|
||||||
"dark": "深色",
|
"dark": "深色",
|
||||||
"system": "系统",
|
"system": "系统",
|
||||||
"logout": "退出登录",
|
"logout": "退出登录",
|
||||||
|
"loggingOut": "正在退出...",
|
||||||
"inbox": "收件箱",
|
"inbox": "收件箱",
|
||||||
"search_inbox": "搜索收件箱",
|
"search_inbox": "搜索收件箱",
|
||||||
"mark_all_read": "全部标记为已读",
|
"mark_all_read": "全部标记为已读",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue