diff --git a/surfsense_backend/app/routes/auth_routes.py b/surfsense_backend/app/routes/auth_routes.py index 17b6e922f..3ec475f68 100644 --- a/surfsense_backend/app/routes/auth_routes.py +++ b/surfsense_backend/app/routes/auth_routes.py @@ -2,9 +2,10 @@ import logging from datetime import UTC, datetime +from types import SimpleNamespace +from urllib.parse import urlparse import httpx -import jwt from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from google.auth.transport import requests as google_requests from google.oauth2 import id_token as google_id_token @@ -12,20 +13,31 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.auth.context import AuthContext -from app.auth.session_cookies import clear_session, read_refresh, write_session +from app.auth.session_cookies import ( + access_expires_at, + clear_session, + issue, + read_refresh, +) from app.config import config from app.db import User, async_session_maker, get_async_session from app.rate_limiter import limiter from app.schemas.auth import ( + DesktopLoginRequest, + DesktopSessionRequest, LogoutAllResponse, LogoutRequest, LogoutResponse, - DesktopSessionRequest, RefreshTokenRequest, RefreshTokenResponse, SessionResponse, ) -from app.users import SECRET, UserManager, get_auth_context, get_jwt_strategy, get_user_manager +from app.users import ( + UserManager, + get_auth_context, + get_jwt_strategy, + get_user_manager, +) from app.utils.refresh_tokens import ( create_refresh_token, revoke_all_user_tokens, @@ -40,36 +52,13 @@ router = APIRouter(prefix="/auth/jwt", tags=["auth"]) session_router = APIRouter(prefix="/auth", tags=["auth"]) -def _access_expires_at(access_token: str) -> int: - payload = jwt.decode( - access_token, - SECRET, - algorithms=["HS256"], - options={"verify_aud": False}, - ) - return int(payload["exp"]) - - -def _request_access_token(request: Request) -> str | None: - cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME) - if cookie_token: - return cookie_token - auth_header = request.headers.get("Authorization") - if not auth_header: - return None - scheme, _, token = auth_header.partition(" ") - if scheme.lower() == "bearer" and token: - return token - return None - - async def _load_user(user_id) -> User | None: async with async_session_maker() as session: result = await session.execute(select(User).where(User.id == user_id)) return result.scalars().first() -@router.post("/refresh", response_model=RefreshTokenResponse) +@router.post("/refresh", response_model=None) @limiter.limit("30/minute") async def refresh_access_token( request: Request, @@ -80,7 +69,7 @@ async def refresh_access_token( Exchange a valid refresh token for a new access token and refresh token. Implements token rotation for security. """ - refresh_token = read_refresh(request, body) + refresh_token, mode = read_refresh(request, body) if not refresh_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -101,19 +90,18 @@ async def refresh_access_token( detail="User not found", ) - # Generate new access token strategy = get_jwt_strategy() access_token = await strategy.write_token(user) - if request.cookies.get(config.REFRESH_COOKIE_NAME) and rotation.refresh_token: - write_session(response, access_token, rotation.refresh_token, request) - logger.info(f"Refreshed token for user {user.id}") - return RefreshTokenResponse( - access_token=access_token, - refresh_token=rotation.refresh_token, - access_expires_at=_access_expires_at(access_token), + return issue( + response, + mode, + access=access_token, + refresh=rotation.refresh_token, + access_expires_at=access_expires_at(access_token), + request=request, ) @@ -127,7 +115,7 @@ async def revoke_token( Logout current device by revoking the provided refresh token. Does not require authentication - just the refresh token. """ - refresh_token = read_refresh(request, body) + refresh_token, _mode = read_refresh(request, body) revoked = await revoke_refresh_token(refresh_token) if refresh_token else False clear_session(response, request) if revoked: @@ -158,7 +146,7 @@ async def logout_all_devices( user = None if user is None: - refresh_token = read_refresh(request, body) + refresh_token, _mode = read_refresh(request, body) token_record = await validate_refresh_token(refresh_token) if refresh_token else None if token_record: user = await _load_user(token_record.user_id) @@ -178,12 +166,55 @@ async def logout_all_devices( @session_router.get("/session", response_model=SessionResponse) async def get_session( request: Request, - _auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(get_auth_context), ): - access_token = _request_access_token(request) - if not access_token: + if auth.method == "pat": + return SessionResponse(access_expires_at=None) + + access_token = request.cookies.get(config.SESSION_COOKIE_NAME) + if access_token is None: + auth_header = request.headers.get("Authorization") + if auth_header: + scheme, _, token = auth_header.partition(" ") + if scheme.lower() == "bearer" and token: + access_token = token + + if access_token is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") - return SessionResponse(access_expires_at=_access_expires_at(access_token)) + return SessionResponse(access_expires_at=access_expires_at(access_token)) + + +@session_router.post("/desktop/login", response_model=RefreshTokenResponse) +@limiter.limit("5/minute") +async def desktop_password_login( + request: Request, + body: DesktopLoginRequest, + user_manager: UserManager = Depends(get_user_manager), +): + if config.AUTH_TYPE == "GOOGLE": + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") + if not config.REGISTRATION_ENABLED: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Registration is disabled", + ) + + credentials = SimpleNamespace(username=body.email, password=body.password) + user = await user_manager.authenticate(credentials) + if user is None or not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="LOGIN_BAD_CREDENTIALS", + ) + + app_access_token = await get_jwt_strategy().write_token(user) + app_refresh_token = await create_refresh_token(user.id) + await user_manager.on_after_login(user, request, None) + return RefreshTokenResponse( + access_token=app_access_token, + refresh_token=app_refresh_token, + access_expires_at=access_expires_at(app_access_token), + ) @session_router.post("/desktop/session", response_model=RefreshTokenResponse) @@ -193,7 +224,17 @@ async def create_desktop_session( body: DesktopSessionRequest, user_manager: UserManager = Depends(get_user_manager), ): - if not body.redirect_uri.startswith("http://127.0.0.1:"): + parsed_redirect = urlparse(body.redirect_uri) + try: + redirect_port = parsed_redirect.port + except ValueError: + redirect_port = None + if not ( + parsed_redirect.scheme == "http" + and parsed_redirect.hostname in {"127.0.0.1", "::1"} + and redirect_port is not None + and parsed_redirect.path == "/callback" + ): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI") if not config.GOOGLE_DESKTOP_CLIENT_ID: raise HTTPException( @@ -238,6 +279,7 @@ async def create_desktop_session( if not claims.get("sub") or not claims.get("email"): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token") + email_verified = bool(claims.get("email_verified")) user = await user_manager.oauth_callback( "google", access_token, @@ -250,13 +292,13 @@ async def create_desktop_session( ), refresh_token=token_data.get("refresh_token"), request=request, - associate_by_email=True, - is_verified_by_default=True, + associate_by_email=email_verified, + is_verified_by_default=email_verified, ) app_access_token = await get_jwt_strategy().write_token(user) app_refresh_token = await create_refresh_token(user.id) return RefreshTokenResponse( access_token=app_access_token, refresh_token=app_refresh_token, - access_expires_at=_access_expires_at(app_access_token), + access_expires_at=access_expires_at(app_access_token), ) diff --git a/surfsense_backend/app/schemas/auth.py b/surfsense_backend/app/schemas/auth.py index af8940d01..bdc009109 100644 --- a/surfsense_backend/app/schemas/auth.py +++ b/surfsense_backend/app/schemas/auth.py @@ -38,10 +38,15 @@ class LogoutAllResponse(BaseModel): class SessionResponse(BaseModel): authenticated: bool = True - access_expires_at: int + access_expires_at: int | None = None class DesktopSessionRequest(BaseModel): code: str code_verifier: str redirect_uri: str + + +class DesktopLoginRequest(BaseModel): + email: str + password: str diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index 19db79b3a..524904ad7 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -3,6 +3,7 @@ import uuid from datetime import UTC, datetime import httpx +import jwt from fastapi import Depends, HTTPException, Request, Response, status from fastapi.responses import JSONResponse, RedirectResponse from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models @@ -12,12 +13,12 @@ from fastapi_users.authentication import ( JWTStrategy, ) from fastapi_users.db import SQLAlchemyUserDatabase -from pydantic import BaseModel +from fastapi_users.jwt import generate_jwt from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession from app.auth.context import AuthContext -from app.auth.session_cookies import write_session +from app.auth.session_cookies import access_expires_at, write_session from app.config import config from app.db import ( Prompt, @@ -37,13 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token logger = logging.getLogger(__name__) -class BearerResponse(BaseModel): - access_token: str - refresh_token: str - token_type: str - access_expires_at: int - - SECRET = config.SECRET_KEY @@ -232,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db yield UserManager(user_db) +class IatJWTStrategy(JWTStrategy[models.UP, models.ID]): + async def write_token(self, user: models.UP) -> str: + data = { + "sub": str(user.id), + "aud": self.token_audience, + "iat": int(datetime.now(UTC).timestamp()), + } + return generate_jwt( + data, + self.encode_key, + self.lifetime_seconds, + algorithm=self.algorithm, + ) + + def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: - return JWTStrategy( + return IatJWTStrategy( secret=SECRET, lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS, ) @@ -262,48 +271,34 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: # BEARER AUTH CODE. class CustomBearerTransport(BearerTransport): async def get_login_response(self, token: str) -> Response: - import jwt - - # Decode JWT to get user_id for refresh token creation - access_expires_at = 0 try: payload = jwt.decode( token, SECRET, algorithms=["HS256"], options={"verify_aud": False} ) - access_expires_at = int(payload["exp"]) user_id = uuid.UUID(payload.get("sub")) refresh_token = await create_refresh_token(user_id) except Exception as e: logger.error(f"Failed to create refresh token: {e}") - # Fall back to response without refresh token - refresh_token = "" - - bearer_response = BearerResponse( - access_token=token, - refresh_token=refresh_token, - token_type="bearer", - access_expires_at=access_expires_at, - ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create session", + ) from e if config.AUTH_TYPE == "GOOGLE": response = RedirectResponse( - f"{config.NEXT_FRONTEND_URL}/auth/callback", + f"{config.NEXT_FRONTEND_URL}/dashboard", status_code=302, ) - write_session( - response, - bearer_response.access_token, - bearer_response.refresh_token, - ) - return response else: - response = JSONResponse(bearer_response.model_dump()) - write_session( - response, - bearer_response.access_token, - bearer_response.refresh_token, + response = JSONResponse( + { + "authenticated": True, + "access_expires_at": access_expires_at(token), + } ) - return response + + write_session(response, token, refresh_token) + return response bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login") @@ -318,6 +313,22 @@ auth_backend = AuthenticationBackend( fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) +def _token_meets_epoch(token: str) -> bool: + min_issued_at = config.MIN_ISSUED_AT + if min_issued_at <= 0: + return True + + try: + payload = jwt.decode( + token, SECRET, algorithms=["HS256"], options={"verify_aud": False} + ) + except jwt.PyJWTError: + return False + + issued_at = payload.get("iat") + return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at + + async def get_auth_context( request: Request, session: AsyncSession = Depends(get_async_session), @@ -341,7 +352,7 @@ async def get_auth_context( maybe_touch_last_used(pat) return AuthContext.pat_auth(pat.user, pat) - if is_bearer: + if is_bearer and _token_meets_epoch(token): try: user = await get_jwt_strategy().read_token(token, user_manager) except Exception: @@ -352,7 +363,7 @@ async def get_auth_context( return AuthContext.session(user) cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME) - if cookie_token: + if cookie_token and _token_meets_epoch(cookie_token): try: user = await get_jwt_strategy().read_token(cookie_token, user_manager) except Exception: diff --git a/surfsense_backend/scripts/revoke_refresh_tokens_cutover.py b/surfsense_backend/scripts/revoke_refresh_tokens_cutover.py new file mode 100644 index 000000000..449d4a3e9 --- /dev/null +++ b/surfsense_backend/scripts/revoke_refresh_tokens_cutover.py @@ -0,0 +1,69 @@ +"""One-shot cutover helper to revoke every refresh token. + +Run with --yes during the auth-hardening cutover, alongside setting +MIN_ISSUED_AT to the deploy epoch. +""" + +from __future__ import annotations + +import argparse +import asyncio + +from sqlalchemy import text + +from app.db import async_session_maker + + +async def _count_active_tokens() -> int: + async with async_session_maker() as session: + result = await session.execute( + text( + """ + SELECT count(*) + FROM refresh_tokens + WHERE revoked_at IS NULL + AND expires_at > NOW() + """ + ) + ) + return int(result.scalar_one()) + + +async def _revoke_all_tokens() -> int: + async with async_session_maker() as session: + result = await session.execute( + text( + """ + UPDATE refresh_tokens + SET revoked_at = NOW(), + expires_at = NOW() + WHERE revoked_at IS NULL + OR expires_at > NOW() + """ + ) + ) + await session.commit() + return int(result.rowcount or 0) + + +async def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--yes", + action="store_true", + help="Actually revoke tokens. Without this flag the command is a dry run.", + ) + args = parser.parse_args() + + active_count = await _count_active_tokens() + if not args.yes: + print(f"Dry run: {active_count} active refresh token(s) would be revoked.") + print("Re-run with --yes during the auth-hardening cutover to revoke them.") + return + + updated_count = await _revoke_all_tokens() + print(f"Revoked {updated_count} refresh token row(s).") + + +if __name__ == "__main__": + asyncio.run(main())