mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
fix(auth):enforce session auth cutover
This commit is contained in:
parent
fbecbb98b5
commit
62c7efb216
4 changed files with 214 additions and 87 deletions
|
|
@ -2,9 +2,10 @@
|
|||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from google.auth.transport import requests as google_requests
|
||||
from google.oauth2 import id_token as google_id_token
|
||||
|
|
@ -12,20 +13,31 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.session_cookies import clear_session, read_refresh, write_session
|
||||
from app.auth.session_cookies import (
|
||||
access_expires_at,
|
||||
clear_session,
|
||||
issue,
|
||||
read_refresh,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import User, async_session_maker, get_async_session
|
||||
from app.rate_limiter import limiter
|
||||
from app.schemas.auth import (
|
||||
DesktopLoginRequest,
|
||||
DesktopSessionRequest,
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
DesktopSessionRequest,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
SessionResponse,
|
||||
)
|
||||
from app.users import SECRET, UserManager, get_auth_context, get_jwt_strategy, get_user_manager
|
||||
from app.users import (
|
||||
UserManager,
|
||||
get_auth_context,
|
||||
get_jwt_strategy,
|
||||
get_user_manager,
|
||||
)
|
||||
from app.utils.refresh_tokens import (
|
||||
create_refresh_token,
|
||||
revoke_all_user_tokens,
|
||||
|
|
@ -40,36 +52,13 @@ router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
|||
session_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _access_expires_at(access_token: str) -> int:
|
||||
payload = jwt.decode(
|
||||
access_token,
|
||||
SECRET,
|
||||
algorithms=["HS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
return int(payload["exp"])
|
||||
|
||||
|
||||
def _request_access_token(request: Request) -> str | None:
|
||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
return cookie_token
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return None
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() == "bearer" and token:
|
||||
return token
|
||||
return None
|
||||
|
||||
|
||||
async def _load_user(user_id) -> User | None:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
@router.post("/refresh", response_model=None)
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_access_token(
|
||||
request: Request,
|
||||
|
|
@ -80,7 +69,7 @@ async def refresh_access_token(
|
|||
Exchange a valid refresh token for a new access token and refresh token.
|
||||
Implements token rotation for security.
|
||||
"""
|
||||
refresh_token = read_refresh(request, body)
|
||||
refresh_token, mode = read_refresh(request, body)
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -101,19 +90,18 @@ async def refresh_access_token(
|
|||
detail="User not found",
|
||||
)
|
||||
|
||||
# Generate new access token
|
||||
strategy = get_jwt_strategy()
|
||||
access_token = await strategy.write_token(user)
|
||||
|
||||
if request.cookies.get(config.REFRESH_COOKIE_NAME) and rotation.refresh_token:
|
||||
write_session(response, access_token, rotation.refresh_token, request)
|
||||
|
||||
logger.info(f"Refreshed token for user {user.id}")
|
||||
|
||||
return RefreshTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=rotation.refresh_token,
|
||||
access_expires_at=_access_expires_at(access_token),
|
||||
return issue(
|
||||
response,
|
||||
mode,
|
||||
access=access_token,
|
||||
refresh=rotation.refresh_token,
|
||||
access_expires_at=access_expires_at(access_token),
|
||||
request=request,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -127,7 +115,7 @@ async def revoke_token(
|
|||
Logout current device by revoking the provided refresh token.
|
||||
Does not require authentication - just the refresh token.
|
||||
"""
|
||||
refresh_token = read_refresh(request, body)
|
||||
refresh_token, _mode = read_refresh(request, body)
|
||||
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
|
||||
clear_session(response, request)
|
||||
if revoked:
|
||||
|
|
@ -158,7 +146,7 @@ async def logout_all_devices(
|
|||
user = None
|
||||
|
||||
if user is None:
|
||||
refresh_token = read_refresh(request, body)
|
||||
refresh_token, _mode = read_refresh(request, body)
|
||||
token_record = await validate_refresh_token(refresh_token) if refresh_token else None
|
||||
if token_record:
|
||||
user = await _load_user(token_record.user_id)
|
||||
|
|
@ -178,12 +166,55 @@ async def logout_all_devices(
|
|||
@session_router.get("/session", response_model=SessionResponse)
|
||||
async def get_session(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
access_token = _request_access_token(request)
|
||||
if not access_token:
|
||||
if auth.method == "pat":
|
||||
return SessionResponse(access_expires_at=None)
|
||||
|
||||
access_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if access_token is None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() == "bearer" and token:
|
||||
access_token = token
|
||||
|
||||
if access_token is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||||
return SessionResponse(access_expires_at=_access_expires_at(access_token))
|
||||
return SessionResponse(access_expires_at=access_expires_at(access_token))
|
||||
|
||||
|
||||
@session_router.post("/desktop/login", response_model=RefreshTokenResponse)
|
||||
@limiter.limit("5/minute")
|
||||
async def desktop_password_login(
|
||||
request: Request,
|
||||
body: DesktopLoginRequest,
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||
if not config.REGISTRATION_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Registration is disabled",
|
||||
)
|
||||
|
||||
credentials = SimpleNamespace(username=body.email, password=body.password)
|
||||
user = await user_manager.authenticate(credentials)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="LOGIN_BAD_CREDENTIALS",
|
||||
)
|
||||
|
||||
app_access_token = await get_jwt_strategy().write_token(user)
|
||||
app_refresh_token = await create_refresh_token(user.id)
|
||||
await user_manager.on_after_login(user, request, None)
|
||||
return RefreshTokenResponse(
|
||||
access_token=app_access_token,
|
||||
refresh_token=app_refresh_token,
|
||||
access_expires_at=access_expires_at(app_access_token),
|
||||
)
|
||||
|
||||
|
||||
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
|
||||
|
|
@ -193,7 +224,17 @@ async def create_desktop_session(
|
|||
body: DesktopSessionRequest,
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
if not body.redirect_uri.startswith("http://127.0.0.1:"):
|
||||
parsed_redirect = urlparse(body.redirect_uri)
|
||||
try:
|
||||
redirect_port = parsed_redirect.port
|
||||
except ValueError:
|
||||
redirect_port = None
|
||||
if not (
|
||||
parsed_redirect.scheme == "http"
|
||||
and parsed_redirect.hostname in {"127.0.0.1", "::1"}
|
||||
and redirect_port is not None
|
||||
and parsed_redirect.path == "/callback"
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI")
|
||||
if not config.GOOGLE_DESKTOP_CLIENT_ID:
|
||||
raise HTTPException(
|
||||
|
|
@ -238,6 +279,7 @@ async def create_desktop_session(
|
|||
if not claims.get("sub") or not claims.get("email"):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
|
||||
|
||||
email_verified = bool(claims.get("email_verified"))
|
||||
user = await user_manager.oauth_callback(
|
||||
"google",
|
||||
access_token,
|
||||
|
|
@ -250,13 +292,13 @@ async def create_desktop_session(
|
|||
),
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
request=request,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
associate_by_email=email_verified,
|
||||
is_verified_by_default=email_verified,
|
||||
)
|
||||
app_access_token = await get_jwt_strategy().write_token(user)
|
||||
app_refresh_token = await create_refresh_token(user.id)
|
||||
return RefreshTokenResponse(
|
||||
access_token=app_access_token,
|
||||
refresh_token=app_refresh_token,
|
||||
access_expires_at=_access_expires_at(app_access_token),
|
||||
access_expires_at=access_expires_at(app_access_token),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,10 +38,15 @@ class LogoutAllResponse(BaseModel):
|
|||
|
||||
class SessionResponse(BaseModel):
|
||||
authenticated: bool = True
|
||||
access_expires_at: int
|
||||
access_expires_at: int | None = None
|
||||
|
||||
|
||||
class DesktopSessionRequest(BaseModel):
|
||||
code: str
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
class DesktopLoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
|
|
@ -12,12 +13,12 @@ from fastapi_users.authentication import (
|
|||
JWTStrategy,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from pydantic import BaseModel
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.session_cookies import write_session
|
||||
from app.auth.session_cookies import access_expires_at, write_session
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Prompt,
|
||||
|
|
@ -37,13 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BearerResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
access_expires_at: int
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
||||
|
||||
|
|
@ -232,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
|||
yield UserManager(user_db)
|
||||
|
||||
|
||||
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
|
||||
async def write_token(self, user: models.UP) -> str:
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"iat": int(datetime.now(UTC).timestamp()),
|
||||
}
|
||||
return generate_jwt(
|
||||
data,
|
||||
self.encode_key,
|
||||
self.lifetime_seconds,
|
||||
algorithm=self.algorithm,
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
return JWTStrategy(
|
||||
return IatJWTStrategy(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
)
|
||||
|
|
@ -262,48 +271,34 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
|||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
import jwt
|
||||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
access_expires_at = 0
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
access_expires_at = int(payload["exp"])
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create refresh token: {e}")
|
||||
# Fall back to response without refresh token
|
||||
refresh_token = ""
|
||||
|
||||
bearer_response = BearerResponse(
|
||||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
access_expires_at=access_expires_at,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create session",
|
||||
) from e
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
response = RedirectResponse(
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback",
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard",
|
||||
status_code=302,
|
||||
)
|
||||
write_session(
|
||||
response,
|
||||
bearer_response.access_token,
|
||||
bearer_response.refresh_token,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
response = JSONResponse(bearer_response.model_dump())
|
||||
write_session(
|
||||
response,
|
||||
bearer_response.access_token,
|
||||
bearer_response.refresh_token,
|
||||
response = JSONResponse(
|
||||
{
|
||||
"authenticated": True,
|
||||
"access_expires_at": access_expires_at(token),
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
write_session(response, token, refresh_token)
|
||||
return response
|
||||
|
||||
|
||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
|
@ -318,6 +313,22 @@ auth_backend = AuthenticationBackend(
|
|||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
|
||||
def _token_meets_epoch(token: str) -> bool:
|
||||
min_issued_at = config.MIN_ISSUED_AT
|
||||
if min_issued_at <= 0:
|
||||
return True
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
return False
|
||||
|
||||
issued_at = payload.get("iat")
|
||||
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
|
||||
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -341,7 +352,7 @@ async def get_auth_context(
|
|||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
|
||||
if is_bearer:
|
||||
if is_bearer and _token_meets_epoch(token):
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
|
|
@ -352,7 +363,7 @@ async def get_auth_context(
|
|||
return AuthContext.session(user)
|
||||
|
||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
if cookie_token and _token_meets_epoch(cookie_token):
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
|
||||
except Exception:
|
||||
|
|
|
|||
69
surfsense_backend/scripts/revoke_refresh_tokens_cutover.py
Normal file
69
surfsense_backend/scripts/revoke_refresh_tokens_cutover.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""One-shot cutover helper to revoke every refresh token.
|
||||
|
||||
Run with --yes during the auth-hardening cutover, alongside setting
|
||||
MIN_ISSUED_AT to the deploy epoch.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
|
||||
async def _count_active_tokens() -> int:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT count(*)
|
||||
FROM refresh_tokens
|
||||
WHERE revoked_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
"""
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
async def _revoke_all_tokens() -> int:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = NOW(),
|
||||
expires_at = NOW()
|
||||
WHERE revoked_at IS NULL
|
||||
OR expires_at > NOW()
|
||||
"""
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
action="store_true",
|
||||
help="Actually revoke tokens. Without this flag the command is a dry run.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
active_count = await _count_active_tokens()
|
||||
if not args.yes:
|
||||
print(f"Dry run: {active_count} active refresh token(s) would be revoked.")
|
||||
print("Re-run with --yes during the auth-hardening cutover to revoke them.")
|
||||
return
|
||||
|
||||
updated_count = await _revoke_all_tokens()
|
||||
print(f"Revoked {updated_count} refresh token row(s).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Add table
Add a link
Reference in a new issue