mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Switch refresh token storage from cookies to localStorage
This commit is contained in:
parent
f3a9922eb9
commit
233852b681
7 changed files with 160 additions and 88 deletions
|
|
@ -2,17 +2,18 @@
|
|||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import User, async_session_maker
|
||||
from app.schemas.auth import LogoutAllResponse, LogoutResponse, RefreshTokenResponse
|
||||
from app.users import current_active_user, get_jwt_strategy
|
||||
from app.utils.auth_cookies import (
|
||||
REFRESH_TOKEN_COOKIE_NAME,
|
||||
delete_refresh_token_cookie,
|
||||
set_refresh_token_cookie,
|
||||
from app.schemas.auth import (
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
)
|
||||
from app.users import current_active_user, get_jwt_strategy
|
||||
from app.utils.refresh_tokens import (
|
||||
revoke_all_user_tokens,
|
||||
revoke_refresh_token,
|
||||
|
|
@ -26,21 +27,12 @@ router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
|||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_access_token(
|
||||
response: Response,
|
||||
refresh_token: str | None = Cookie(default=None, alias=REFRESH_TOKEN_COOKIE_NAME),
|
||||
):
|
||||
async def refresh_access_token(request: RefreshTokenRequest):
|
||||
"""
|
||||
Exchange a valid refresh token for a new access token and refresh token.
|
||||
Reads refresh token from HTTP-only cookie. Implements token rotation for security.
|
||||
Implements token rotation for security.
|
||||
"""
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token not found",
|
||||
)
|
||||
|
||||
token_record = await validate_refresh_token(refresh_token)
|
||||
token_record = await validate_refresh_token(request.refresh_token)
|
||||
|
||||
if not token_record:
|
||||
raise HTTPException(
|
||||
|
|
@ -68,9 +60,6 @@ async def refresh_access_token(
|
|||
# Rotate refresh token
|
||||
new_refresh_token = await rotate_refresh_token(token_record)
|
||||
|
||||
# Set the new refresh token in cookie
|
||||
set_refresh_token_cookie(response, new_refresh_token)
|
||||
|
||||
logger.info(f"Refreshed token for user {user.id}")
|
||||
|
||||
return RefreshTokenResponse(
|
||||
|
|
@ -80,36 +69,21 @@ async def refresh_access_token(
|
|||
|
||||
|
||||
@router.post("/logout", response_model=LogoutResponse)
|
||||
async def logout(
|
||||
response: Response,
|
||||
refresh_token: str | None = Cookie(default=None, alias=REFRESH_TOKEN_COOKIE_NAME),
|
||||
):
|
||||
async def logout(request: LogoutRequest):
|
||||
"""
|
||||
Logout current device by revoking the refresh token from cookie.
|
||||
Logout current device by revoking the provided refresh token.
|
||||
"""
|
||||
if refresh_token:
|
||||
await revoke_refresh_token(refresh_token)
|
||||
|
||||
# Always delete the cookie
|
||||
delete_refresh_token_cookie(response)
|
||||
|
||||
await revoke_refresh_token(request.refresh_token)
|
||||
logger.info("User logged out from current device")
|
||||
return LogoutResponse()
|
||||
|
||||
|
||||
@router.post("/logout-all", response_model=LogoutAllResponse)
|
||||
async def logout_all_devices(
|
||||
response: Response,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
async def logout_all_devices(user: User = Depends(current_active_user)):
|
||||
"""
|
||||
Logout from all devices by revoking all refresh tokens for the user.
|
||||
Requires valid access token.
|
||||
"""
|
||||
await revoke_all_user_tokens(user.id)
|
||||
|
||||
# Delete the cookie on current device
|
||||
delete_refresh_token_cookie(response)
|
||||
|
||||
logger.info(f"User {user.id} logged out from all devices")
|
||||
return LogoutAllResponse()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
from .auth import LogoutAllResponse, LogoutResponse, RefreshTokenResponse
|
||||
from .auth import (
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
)
|
||||
from .base import IDModel, TimestampModel
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||
from .documents import (
|
||||
|
|
@ -120,6 +126,7 @@ __all__ = [
|
|||
"LogUpdate",
|
||||
# Auth schemas
|
||||
"LogoutAllResponse",
|
||||
"LogoutRequest",
|
||||
"LogoutResponse",
|
||||
# Search source connector schemas
|
||||
"MCPConnectorCreate",
|
||||
|
|
@ -150,6 +157,7 @@ __all__ = [
|
|||
"PodcastCreate",
|
||||
"PodcastRead",
|
||||
"PodcastUpdate",
|
||||
"RefreshTokenRequest",
|
||||
"RefreshTokenResponse",
|
||||
"RoleCreate",
|
||||
"RoleRead",
|
||||
|
|
|
|||
|
|
@ -3,6 +3,12 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request body for token refresh endpoint."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""Response from token refresh endpoint."""
|
||||
|
||||
|
|
@ -11,6 +17,12 @@ class RefreshTokenResponse(BaseModel):
|
|||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request body for logout endpoint (current device)."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
"""Response from logout endpoint (current device)."""
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from app.db import (
|
|||
get_default_roles_config,
|
||||
get_user_db,
|
||||
)
|
||||
from app.utils.auth_cookies import set_refresh_token_cookie
|
||||
from app.utils.refresh_tokens import create_refresh_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -238,16 +237,11 @@ class CustomBearerTransport(BearerTransport):
|
|||
redirect_url = (
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback"
|
||||
f"?token={bearer_response.access_token}"
|
||||
f"&refresh_token={bearer_response.refresh_token}"
|
||||
)
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
return RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
response = JSONResponse(bearer_response.model_dump())
|
||||
|
||||
# Set refresh token as HTTP-only cookie
|
||||
if refresh_token:
|
||||
set_refresh_token_cookie(response, refresh_token)
|
||||
|
||||
return response
|
||||
return JSONResponse(bearer_response.model_dump())
|
||||
|
||||
|
||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
|
|
|||
|
|
@ -1,29 +0,0 @@
|
|||
"""Utilities for managing authentication cookies."""
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
from app.config import config
|
||||
|
||||
REFRESH_TOKEN_COOKIE_NAME = "refresh_token"
|
||||
|
||||
|
||||
def set_refresh_token_cookie(response: Response, token: str) -> None:
|
||||
"""Set the refresh token as an HTTP-only cookie."""
|
||||
response.set_cookie(
|
||||
key=REFRESH_TOKEN_COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS,
|
||||
httponly=True,
|
||||
secure=True, # Only send over HTTPS
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
|
||||
def delete_refresh_token_cookie(response: Response) -> None:
|
||||
"""Delete the refresh token cookie."""
|
||||
response.delete_cookie(
|
||||
key=REFRESH_TOKEN_COOKIE_NAME,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
import { useSearchParams } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { getAndClearRedirectPath, setBearerToken } from "@/lib/auth-utils";
|
||||
import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils";
|
||||
import { trackLoginSuccess } from "@/lib/posthog/events";
|
||||
|
||||
interface TokenHandlerProps {
|
||||
|
|
@ -35,8 +35,9 @@ const TokenHandler = ({
|
|||
// Only run on client-side
|
||||
if (typeof window === "undefined") return;
|
||||
|
||||
// Get token from URL parameters
|
||||
// Get tokens from URL parameters
|
||||
const token = searchParams.get(tokenParamName);
|
||||
const refreshToken = searchParams.get("refresh_token");
|
||||
|
||||
if (token) {
|
||||
try {
|
||||
|
|
@ -50,10 +51,15 @@ const TokenHandler = ({
|
|||
// Clear the flag for future logins
|
||||
sessionStorage.removeItem("login_success_tracked");
|
||||
|
||||
// Store token in localStorage using both methods for compatibility
|
||||
// Store access token in localStorage using both methods for compatibility
|
||||
localStorage.setItem(storageKey, token);
|
||||
setBearerToken(token);
|
||||
|
||||
// Store refresh token if provided
|
||||
if (refreshToken) {
|
||||
setRefreshToken(refreshToken);
|
||||
}
|
||||
|
||||
// Check if there's a saved redirect path from before the auth flow
|
||||
const savedRedirectPath = getAndClearRedirectPath();
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@
|
|||
|
||||
const REDIRECT_PATH_KEY = "surfsense_redirect_path";
|
||||
const BEARER_TOKEN_KEY = "surfsense_bearer_token";
|
||||
const REFRESH_TOKEN_KEY = "surfsense_refresh_token";
|
||||
|
||||
// Flag to prevent multiple simultaneous refresh attempts
|
||||
let isRefreshing = false;
|
||||
let refreshPromise: Promise<string | null> | null = null;
|
||||
|
||||
/**
|
||||
* Saves the current path and redirects to login page
|
||||
|
|
@ -21,8 +26,9 @@ export function handleUnauthorized(): void {
|
|||
localStorage.setItem(REDIRECT_PATH_KEY, currentPath);
|
||||
}
|
||||
|
||||
// Clear the token
|
||||
// Clear both tokens
|
||||
localStorage.removeItem(BEARER_TOKEN_KEY);
|
||||
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||
|
||||
// Redirect to home page (which has login options)
|
||||
window.location.href = "/login";
|
||||
|
|
@ -66,6 +72,38 @@ export function clearBearerToken(): void {
|
|||
localStorage.removeItem(BEARER_TOKEN_KEY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the refresh token from localStorage
|
||||
*/
|
||||
export function getRefreshToken(): string | null {
|
||||
if (typeof window === "undefined") return null;
|
||||
return localStorage.getItem(REFRESH_TOKEN_KEY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the refresh token in localStorage
|
||||
*/
|
||||
export function setRefreshToken(token: string): void {
|
||||
if (typeof window === "undefined") return;
|
||||
localStorage.setItem(REFRESH_TOKEN_KEY, token);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the refresh token from localStorage
|
||||
*/
|
||||
export function clearRefreshToken(): void {
|
||||
if (typeof window === "undefined") return;
|
||||
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears all auth tokens from localStorage
|
||||
*/
|
||||
export function clearAllTokens(): void {
|
||||
clearBearerToken();
|
||||
clearRefreshToken();
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the user is authenticated (has a token)
|
||||
*/
|
||||
|
|
@ -106,14 +144,66 @@ export function getAuthHeaders(additionalHeaders?: Record<string, string>): Reco
|
|||
}
|
||||
|
||||
/**
|
||||
* Authenticated fetch wrapper that handles 401 responses uniformly
|
||||
* Automatically redirects to login on 401 and saves the current path
|
||||
* Attempts to refresh the access token using the stored refresh token.
|
||||
* Returns the new access token if successful, null otherwise.
|
||||
*/
|
||||
async function refreshAccessToken(): Promise<string | null> {
|
||||
// If already refreshing, wait for that request to complete
|
||||
if (isRefreshing && refreshPromise) {
|
||||
return refreshPromise;
|
||||
}
|
||||
|
||||
const currentRefreshToken = getRefreshToken();
|
||||
if (!currentRefreshToken) {
|
||||
return null;
|
||||
}
|
||||
|
||||
isRefreshing = true;
|
||||
refreshPromise = (async () => {
|
||||
try {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_BACKEND_URL || "";
|
||||
const response = await fetch(`${backendUrl}/auth/jwt/refresh`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ refresh_token: currentRefreshToken }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
// Refresh failed, clear tokens
|
||||
clearAllTokens();
|
||||
return null;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (data.access_token && data.refresh_token) {
|
||||
setBearerToken(data.access_token);
|
||||
setRefreshToken(data.refresh_token);
|
||||
return data.access_token;
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
} finally {
|
||||
isRefreshing = false;
|
||||
refreshPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
return refreshPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Authenticated fetch wrapper that handles 401 responses uniformly.
|
||||
* On 401, attempts to refresh the token and retry the request.
|
||||
* If refresh fails, redirects to login and saves the current path.
|
||||
*/
|
||||
export async function authenticatedFetch(
|
||||
url: string,
|
||||
options?: RequestInit & { skipAuthRedirect?: boolean }
|
||||
options?: RequestInit & { skipAuthRedirect?: boolean; skipRefresh?: boolean }
|
||||
): Promise<Response> {
|
||||
const { skipAuthRedirect = false, ...fetchOptions } = options || {};
|
||||
const { skipAuthRedirect = false, skipRefresh = false, ...fetchOptions } = options || {};
|
||||
|
||||
const headers = getAuthHeaders(fetchOptions.headers as Record<string, string>);
|
||||
|
||||
|
|
@ -124,6 +214,23 @@ export async function authenticatedFetch(
|
|||
|
||||
// Handle 401 Unauthorized
|
||||
if (response.status === 401 && !skipAuthRedirect) {
|
||||
// Try to refresh the token (unless skipRefresh is set to prevent infinite loops)
|
||||
if (!skipRefresh) {
|
||||
const newToken = await refreshAccessToken();
|
||||
if (newToken) {
|
||||
// Retry the original request with the new token
|
||||
const retryHeaders = {
|
||||
...(fetchOptions.headers as Record<string, string>),
|
||||
Authorization: `Bearer ${newToken}`,
|
||||
};
|
||||
return fetch(url, {
|
||||
...fetchOptions,
|
||||
headers: retryHeaders,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh failed or was skipped, redirect to login
|
||||
handleUnauthorized();
|
||||
throw new Error("Unauthorized: Redirecting to login page");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue