mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
fix(auth):enforce session auth cutover
This commit is contained in:
parent
fbecbb98b5
commit
62c7efb216
4 changed files with 214 additions and 87 deletions
|
|
@ -2,9 +2,10 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import jwt
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
from google.auth.transport import requests as google_requests
|
from google.auth.transport import requests as google_requests
|
||||||
from google.oauth2 import id_token as google_id_token
|
from google.oauth2 import id_token as google_id_token
|
||||||
|
|
@ -12,20 +13,31 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.auth.context import AuthContext
|
from app.auth.context import AuthContext
|
||||||
from app.auth.session_cookies import clear_session, read_refresh, write_session
|
from app.auth.session_cookies import (
|
||||||
|
access_expires_at,
|
||||||
|
clear_session,
|
||||||
|
issue,
|
||||||
|
read_refresh,
|
||||||
|
)
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import User, async_session_maker, get_async_session
|
from app.db import User, async_session_maker, get_async_session
|
||||||
from app.rate_limiter import limiter
|
from app.rate_limiter import limiter
|
||||||
from app.schemas.auth import (
|
from app.schemas.auth import (
|
||||||
|
DesktopLoginRequest,
|
||||||
|
DesktopSessionRequest,
|
||||||
LogoutAllResponse,
|
LogoutAllResponse,
|
||||||
LogoutRequest,
|
LogoutRequest,
|
||||||
LogoutResponse,
|
LogoutResponse,
|
||||||
DesktopSessionRequest,
|
|
||||||
RefreshTokenRequest,
|
RefreshTokenRequest,
|
||||||
RefreshTokenResponse,
|
RefreshTokenResponse,
|
||||||
SessionResponse,
|
SessionResponse,
|
||||||
)
|
)
|
||||||
from app.users import SECRET, UserManager, get_auth_context, get_jwt_strategy, get_user_manager
|
from app.users import (
|
||||||
|
UserManager,
|
||||||
|
get_auth_context,
|
||||||
|
get_jwt_strategy,
|
||||||
|
get_user_manager,
|
||||||
|
)
|
||||||
from app.utils.refresh_tokens import (
|
from app.utils.refresh_tokens import (
|
||||||
create_refresh_token,
|
create_refresh_token,
|
||||||
revoke_all_user_tokens,
|
revoke_all_user_tokens,
|
||||||
|
|
@ -40,36 +52,13 @@ router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
||||||
session_router = APIRouter(prefix="/auth", tags=["auth"])
|
session_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
def _access_expires_at(access_token: str) -> int:
|
|
||||||
payload = jwt.decode(
|
|
||||||
access_token,
|
|
||||||
SECRET,
|
|
||||||
algorithms=["HS256"],
|
|
||||||
options={"verify_aud": False},
|
|
||||||
)
|
|
||||||
return int(payload["exp"])
|
|
||||||
|
|
||||||
|
|
||||||
def _request_access_token(request: Request) -> str | None:
|
|
||||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
|
||||||
if cookie_token:
|
|
||||||
return cookie_token
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
if not auth_header:
|
|
||||||
return None
|
|
||||||
scheme, _, token = auth_header.partition(" ")
|
|
||||||
if scheme.lower() == "bearer" and token:
|
|
||||||
return token
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_user(user_id) -> User | None:
|
async def _load_user(user_id) -> User | None:
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
result = await session.execute(select(User).where(User.id == user_id))
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
@router.post("/refresh", response_model=None)
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def refresh_access_token(
|
async def refresh_access_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
@ -80,7 +69,7 @@ async def refresh_access_token(
|
||||||
Exchange a valid refresh token for a new access token and refresh token.
|
Exchange a valid refresh token for a new access token and refresh token.
|
||||||
Implements token rotation for security.
|
Implements token rotation for security.
|
||||||
"""
|
"""
|
||||||
refresh_token = read_refresh(request, body)
|
refresh_token, mode = read_refresh(request, body)
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -101,19 +90,18 @@ async def refresh_access_token(
|
||||||
detail="User not found",
|
detail="User not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate new access token
|
|
||||||
strategy = get_jwt_strategy()
|
strategy = get_jwt_strategy()
|
||||||
access_token = await strategy.write_token(user)
|
access_token = await strategy.write_token(user)
|
||||||
|
|
||||||
if request.cookies.get(config.REFRESH_COOKIE_NAME) and rotation.refresh_token:
|
|
||||||
write_session(response, access_token, rotation.refresh_token, request)
|
|
||||||
|
|
||||||
logger.info(f"Refreshed token for user {user.id}")
|
logger.info(f"Refreshed token for user {user.id}")
|
||||||
|
|
||||||
return RefreshTokenResponse(
|
return issue(
|
||||||
access_token=access_token,
|
response,
|
||||||
refresh_token=rotation.refresh_token,
|
mode,
|
||||||
access_expires_at=_access_expires_at(access_token),
|
access=access_token,
|
||||||
|
refresh=rotation.refresh_token,
|
||||||
|
access_expires_at=access_expires_at(access_token),
|
||||||
|
request=request,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -127,7 +115,7 @@ async def revoke_token(
|
||||||
Logout current device by revoking the provided refresh token.
|
Logout current device by revoking the provided refresh token.
|
||||||
Does not require authentication - just the refresh token.
|
Does not require authentication - just the refresh token.
|
||||||
"""
|
"""
|
||||||
refresh_token = read_refresh(request, body)
|
refresh_token, _mode = read_refresh(request, body)
|
||||||
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
|
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
|
||||||
clear_session(response, request)
|
clear_session(response, request)
|
||||||
if revoked:
|
if revoked:
|
||||||
|
|
@ -158,7 +146,7 @@ async def logout_all_devices(
|
||||||
user = None
|
user = None
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
refresh_token = read_refresh(request, body)
|
refresh_token, _mode = read_refresh(request, body)
|
||||||
token_record = await validate_refresh_token(refresh_token) if refresh_token else None
|
token_record = await validate_refresh_token(refresh_token) if refresh_token else None
|
||||||
if token_record:
|
if token_record:
|
||||||
user = await _load_user(token_record.user_id)
|
user = await _load_user(token_record.user_id)
|
||||||
|
|
@ -178,12 +166,55 @@ async def logout_all_devices(
|
||||||
@session_router.get("/session", response_model=SessionResponse)
|
@session_router.get("/session", response_model=SessionResponse)
|
||||||
async def get_session(
|
async def get_session(
|
||||||
request: Request,
|
request: Request,
|
||||||
_auth: AuthContext = Depends(get_auth_context),
|
auth: AuthContext = Depends(get_auth_context),
|
||||||
):
|
):
|
||||||
access_token = _request_access_token(request)
|
if auth.method == "pat":
|
||||||
if not access_token:
|
return SessionResponse(access_expires_at=None)
|
||||||
|
|
||||||
|
access_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||||
|
if access_token is None:
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if auth_header:
|
||||||
|
scheme, _, token = auth_header.partition(" ")
|
||||||
|
if scheme.lower() == "bearer" and token:
|
||||||
|
access_token = token
|
||||||
|
|
||||||
|
if access_token is None:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||||||
return SessionResponse(access_expires_at=_access_expires_at(access_token))
|
return SessionResponse(access_expires_at=access_expires_at(access_token))
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.post("/desktop/login", response_model=RefreshTokenResponse)
|
||||||
|
@limiter.limit("5/minute")
|
||||||
|
async def desktop_password_login(
|
||||||
|
request: Request,
|
||||||
|
body: DesktopLoginRequest,
|
||||||
|
user_manager: UserManager = Depends(get_user_manager),
|
||||||
|
):
|
||||||
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||||
|
if not config.REGISTRATION_ENABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Registration is disabled",
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials = SimpleNamespace(username=body.email, password=body.password)
|
||||||
|
user = await user_manager.authenticate(credentials)
|
||||||
|
if user is None or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="LOGIN_BAD_CREDENTIALS",
|
||||||
|
)
|
||||||
|
|
||||||
|
app_access_token = await get_jwt_strategy().write_token(user)
|
||||||
|
app_refresh_token = await create_refresh_token(user.id)
|
||||||
|
await user_manager.on_after_login(user, request, None)
|
||||||
|
return RefreshTokenResponse(
|
||||||
|
access_token=app_access_token,
|
||||||
|
refresh_token=app_refresh_token,
|
||||||
|
access_expires_at=access_expires_at(app_access_token),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
|
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
|
||||||
|
|
@ -193,7 +224,17 @@ async def create_desktop_session(
|
||||||
body: DesktopSessionRequest,
|
body: DesktopSessionRequest,
|
||||||
user_manager: UserManager = Depends(get_user_manager),
|
user_manager: UserManager = Depends(get_user_manager),
|
||||||
):
|
):
|
||||||
if not body.redirect_uri.startswith("http://127.0.0.1:"):
|
parsed_redirect = urlparse(body.redirect_uri)
|
||||||
|
try:
|
||||||
|
redirect_port = parsed_redirect.port
|
||||||
|
except ValueError:
|
||||||
|
redirect_port = None
|
||||||
|
if not (
|
||||||
|
parsed_redirect.scheme == "http"
|
||||||
|
and parsed_redirect.hostname in {"127.0.0.1", "::1"}
|
||||||
|
and redirect_port is not None
|
||||||
|
and parsed_redirect.path == "/callback"
|
||||||
|
):
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI")
|
||||||
if not config.GOOGLE_DESKTOP_CLIENT_ID:
|
if not config.GOOGLE_DESKTOP_CLIENT_ID:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -238,6 +279,7 @@ async def create_desktop_session(
|
||||||
if not claims.get("sub") or not claims.get("email"):
|
if not claims.get("sub") or not claims.get("email"):
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
|
||||||
|
|
||||||
|
email_verified = bool(claims.get("email_verified"))
|
||||||
user = await user_manager.oauth_callback(
|
user = await user_manager.oauth_callback(
|
||||||
"google",
|
"google",
|
||||||
access_token,
|
access_token,
|
||||||
|
|
@ -250,13 +292,13 @@ async def create_desktop_session(
|
||||||
),
|
),
|
||||||
refresh_token=token_data.get("refresh_token"),
|
refresh_token=token_data.get("refresh_token"),
|
||||||
request=request,
|
request=request,
|
||||||
associate_by_email=True,
|
associate_by_email=email_verified,
|
||||||
is_verified_by_default=True,
|
is_verified_by_default=email_verified,
|
||||||
)
|
)
|
||||||
app_access_token = await get_jwt_strategy().write_token(user)
|
app_access_token = await get_jwt_strategy().write_token(user)
|
||||||
app_refresh_token = await create_refresh_token(user.id)
|
app_refresh_token = await create_refresh_token(user.id)
|
||||||
return RefreshTokenResponse(
|
return RefreshTokenResponse(
|
||||||
access_token=app_access_token,
|
access_token=app_access_token,
|
||||||
refresh_token=app_refresh_token,
|
refresh_token=app_refresh_token,
|
||||||
access_expires_at=_access_expires_at(app_access_token),
|
access_expires_at=access_expires_at(app_access_token),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -38,10 +38,15 @@ class LogoutAllResponse(BaseModel):
|
||||||
|
|
||||||
class SessionResponse(BaseModel):
|
class SessionResponse(BaseModel):
|
||||||
authenticated: bool = True
|
authenticated: bool = True
|
||||||
access_expires_at: int
|
access_expires_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class DesktopSessionRequest(BaseModel):
|
class DesktopSessionRequest(BaseModel):
|
||||||
code: str
|
code: str
|
||||||
code_verifier: str
|
code_verifier: str
|
||||||
redirect_uri: str
|
redirect_uri: str
|
||||||
|
|
||||||
|
|
||||||
|
class DesktopLoginRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import jwt
|
||||||
from fastapi import Depends, HTTPException, Request, Response, status
|
from fastapi import Depends, HTTPException, Request, Response, status
|
||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||||
|
|
@ -12,12 +13,12 @@ from fastapi_users.authentication import (
|
||||||
JWTStrategy,
|
JWTStrategy,
|
||||||
)
|
)
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
from pydantic import BaseModel
|
from fastapi_users.jwt import generate_jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.auth.context import AuthContext
|
from app.auth.context import AuthContext
|
||||||
from app.auth.session_cookies import write_session
|
from app.auth.session_cookies import access_expires_at, write_session
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
Prompt,
|
Prompt,
|
||||||
|
|
@ -37,13 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BearerResponse(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
refresh_token: str
|
|
||||||
token_type: str
|
|
||||||
access_expires_at: int
|
|
||||||
|
|
||||||
|
|
||||||
SECRET = config.SECRET_KEY
|
SECRET = config.SECRET_KEY
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -232,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
||||||
yield UserManager(user_db)
|
yield UserManager(user_db)
|
||||||
|
|
||||||
|
|
||||||
|
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
|
||||||
|
async def write_token(self, user: models.UP) -> str:
|
||||||
|
data = {
|
||||||
|
"sub": str(user.id),
|
||||||
|
"aud": self.token_audience,
|
||||||
|
"iat": int(datetime.now(UTC).timestamp()),
|
||||||
|
}
|
||||||
|
return generate_jwt(
|
||||||
|
data,
|
||||||
|
self.encode_key,
|
||||||
|
self.lifetime_seconds,
|
||||||
|
algorithm=self.algorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||||
return JWTStrategy(
|
return IatJWTStrategy(
|
||||||
secret=SECRET,
|
secret=SECRET,
|
||||||
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||||
)
|
)
|
||||||
|
|
@ -262,48 +271,34 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||||
# BEARER AUTH CODE.
|
# BEARER AUTH CODE.
|
||||||
class CustomBearerTransport(BearerTransport):
|
class CustomBearerTransport(BearerTransport):
|
||||||
async def get_login_response(self, token: str) -> Response:
|
async def get_login_response(self, token: str) -> Response:
|
||||||
import jwt
|
|
||||||
|
|
||||||
# Decode JWT to get user_id for refresh token creation
|
|
||||||
access_expires_at = 0
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||||
)
|
)
|
||||||
access_expires_at = int(payload["exp"])
|
|
||||||
user_id = uuid.UUID(payload.get("sub"))
|
user_id = uuid.UUID(payload.get("sub"))
|
||||||
refresh_token = await create_refresh_token(user_id)
|
refresh_token = await create_refresh_token(user_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create refresh token: {e}")
|
logger.error(f"Failed to create refresh token: {e}")
|
||||||
# Fall back to response without refresh token
|
raise HTTPException(
|
||||||
refresh_token = ""
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create session",
|
||||||
bearer_response = BearerResponse(
|
) from e
|
||||||
access_token=token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
token_type="bearer",
|
|
||||||
access_expires_at=access_expires_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
response = RedirectResponse(
|
response = RedirectResponse(
|
||||||
f"{config.NEXT_FRONTEND_URL}/auth/callback",
|
f"{config.NEXT_FRONTEND_URL}/dashboard",
|
||||||
status_code=302,
|
status_code=302,
|
||||||
)
|
)
|
||||||
write_session(
|
|
||||||
response,
|
|
||||||
bearer_response.access_token,
|
|
||||||
bearer_response.refresh_token,
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
else:
|
else:
|
||||||
response = JSONResponse(bearer_response.model_dump())
|
response = JSONResponse(
|
||||||
write_session(
|
{
|
||||||
response,
|
"authenticated": True,
|
||||||
bearer_response.access_token,
|
"access_expires_at": access_expires_at(token),
|
||||||
bearer_response.refresh_token,
|
}
|
||||||
)
|
)
|
||||||
return response
|
|
||||||
|
write_session(response, token, refresh_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||||
|
|
@ -318,6 +313,22 @@ auth_backend = AuthenticationBackend(
|
||||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||||
|
|
||||||
|
|
||||||
|
def _token_meets_epoch(token: str) -> bool:
|
||||||
|
min_issued_at = config.MIN_ISSUED_AT
|
||||||
|
if min_issued_at <= 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||||
|
)
|
||||||
|
except jwt.PyJWTError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
issued_at = payload.get("iat")
|
||||||
|
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
|
||||||
|
|
||||||
|
|
||||||
async def get_auth_context(
|
async def get_auth_context(
|
||||||
request: Request,
|
request: Request,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
|
@ -341,7 +352,7 @@ async def get_auth_context(
|
||||||
maybe_touch_last_used(pat)
|
maybe_touch_last_used(pat)
|
||||||
return AuthContext.pat_auth(pat.user, pat)
|
return AuthContext.pat_auth(pat.user, pat)
|
||||||
|
|
||||||
if is_bearer:
|
if is_bearer and _token_meets_epoch(token):
|
||||||
try:
|
try:
|
||||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -352,7 +363,7 @@ async def get_auth_context(
|
||||||
return AuthContext.session(user)
|
return AuthContext.session(user)
|
||||||
|
|
||||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||||
if cookie_token:
|
if cookie_token and _token_meets_epoch(cookie_token):
|
||||||
try:
|
try:
|
||||||
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
|
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
69
surfsense_backend/scripts/revoke_refresh_tokens_cutover.py
Normal file
69
surfsense_backend/scripts/revoke_refresh_tokens_cutover.py
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""One-shot cutover helper to revoke every refresh token.
|
||||||
|
|
||||||
|
Run with --yes during the auth-hardening cutover, alongside setting
|
||||||
|
MIN_ISSUED_AT to the deploy epoch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
|
||||||
|
async def _count_active_tokens() -> int:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT count(*)
|
||||||
|
FROM refresh_tokens
|
||||||
|
WHERE revoked_at IS NULL
|
||||||
|
AND expires_at > NOW()
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _revoke_all_tokens() -> int:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
UPDATE refresh_tokens
|
||||||
|
SET revoked_at = NOW(),
|
||||||
|
expires_at = NOW()
|
||||||
|
WHERE revoked_at IS NULL
|
||||||
|
OR expires_at > NOW()
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return int(result.rowcount or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--yes",
|
||||||
|
action="store_true",
|
||||||
|
help="Actually revoke tokens. Without this flag the command is a dry run.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
active_count = await _count_active_tokens()
|
||||||
|
if not args.yes:
|
||||||
|
print(f"Dry run: {active_count} active refresh token(s) would be revoked.")
|
||||||
|
print("Re-run with --yes during the auth-hardening cutover to revoke them.")
|
||||||
|
return
|
||||||
|
|
||||||
|
updated_count = await _revoke_all_tokens()
|
||||||
|
print(f"Revoked {updated_count} refresh token row(s).")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Loading…
Add table
Add a link
Reference in a new issue