mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
fix(auth):return session based auth responses
This commit is contained in:
parent
2966f41a5a
commit
a547cfe3c3
2 changed files with 199 additions and 22 deletions
|
|
@ -1,21 +1,33 @@
|
|||
"""Authentication routes for refresh token management."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from google.auth.transport import requests as google_requests
|
||||
from google.oauth2 import id_token as google_id_token
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, async_session_maker
|
||||
from app.auth.session_cookies import clear_session, read_refresh, write_session
|
||||
from app.config import config
|
||||
from app.db import User, async_session_maker, get_async_session
|
||||
from app.rate_limiter import limiter
|
||||
from app.schemas.auth import (
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
DesktopSessionRequest,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
SessionResponse,
|
||||
)
|
||||
from app.users import get_jwt_strategy, require_session_context
|
||||
from app.users import SECRET, UserManager, get_auth_context, get_jwt_strategy, get_user_manager
|
||||
from app.utils.refresh_tokens import (
|
||||
create_refresh_token,
|
||||
revoke_all_user_tokens,
|
||||
revoke_refresh_token,
|
||||
rotate_refresh_token,
|
||||
|
|
@ -25,29 +37,64 @@ from app.utils.refresh_tokens import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
||||
session_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _access_expires_at(access_token: str) -> int:
|
||||
payload = jwt.decode(
|
||||
access_token,
|
||||
SECRET,
|
||||
algorithms=["HS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
return int(payload["exp"])
|
||||
|
||||
|
||||
def _request_access_token(request: Request) -> str | None:
|
||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
return cookie_token
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return None
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() == "bearer" and token:
|
||||
return token
|
||||
return None
|
||||
|
||||
|
||||
async def _load_user(user_id) -> User | None:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_access_token(request: RefreshTokenRequest):
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_access_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: RefreshTokenRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Exchange a valid refresh token for a new access token and refresh token.
|
||||
Implements token rotation for security.
|
||||
"""
|
||||
token_record = await validate_refresh_token(request.refresh_token)
|
||||
|
||||
if not token_record:
|
||||
refresh_token = read_refresh(request, body)
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
|
||||
# Get user from token record
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == token_record.user_id)
|
||||
rotation = await rotate_refresh_token(refresh_token)
|
||||
if not rotation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
user = await _load_user(rotation.user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -58,24 +105,31 @@ async def refresh_access_token(request: RefreshTokenRequest):
|
|||
strategy = get_jwt_strategy()
|
||||
access_token = await strategy.write_token(user)
|
||||
|
||||
# Rotate refresh token
|
||||
new_refresh_token = await rotate_refresh_token(token_record)
|
||||
if request.cookies.get(config.REFRESH_COOKIE_NAME) and rotation.refresh_token:
|
||||
write_session(response, access_token, rotation.refresh_token, request)
|
||||
|
||||
logger.info(f"Refreshed token for user {user.id}")
|
||||
|
||||
return RefreshTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
refresh_token=rotation.refresh_token,
|
||||
access_expires_at=_access_expires_at(access_token),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/revoke", response_model=LogoutResponse)
|
||||
async def revoke_token(request: LogoutRequest):
|
||||
async def revoke_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: LogoutRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Logout current device by revoking the provided refresh token.
|
||||
Does not require authentication - just the refresh token.
|
||||
"""
|
||||
revoked = await revoke_refresh_token(request.refresh_token)
|
||||
refresh_token = read_refresh(request, body)
|
||||
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
|
||||
clear_session(response, request)
|
||||
if revoked:
|
||||
logger.info("User logged out from current device - token revoked")
|
||||
else:
|
||||
|
|
@ -85,13 +139,124 @@ async def revoke_token(request: LogoutRequest):
|
|||
|
||||
@router.post("/logout-all", response_model=LogoutAllResponse)
|
||||
async def logout_all_devices(
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: LogoutRequest | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
"""
|
||||
Logout from all devices by revoking all refresh tokens for the user.
|
||||
Requires valid access token.
|
||||
"""
|
||||
user = auth.user
|
||||
user: User | None = None
|
||||
try:
|
||||
auth = await get_auth_context(request, session=session, user_manager=user_manager)
|
||||
if auth.is_session:
|
||||
user = auth.user
|
||||
except HTTPException:
|
||||
user = None
|
||||
|
||||
if user is None:
|
||||
refresh_token = read_refresh(request, body)
|
||||
token_record = await validate_refresh_token(refresh_token) if refresh_token else None
|
||||
if token_record:
|
||||
user = await _load_user(token_record.user_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
|
||||
await revoke_all_user_tokens(user.id)
|
||||
clear_session(response, request)
|
||||
logger.info(f"User {user.id} logged out from all devices")
|
||||
return LogoutAllResponse()
|
||||
|
||||
|
||||
@session_router.get("/session", response_model=SessionResponse)
|
||||
async def get_session(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
access_token = _request_access_token(request)
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||||
return SessionResponse(access_expires_at=_access_expires_at(access_token))
|
||||
|
||||
|
||||
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
|
||||
@limiter.limit("20/minute")
|
||||
async def create_desktop_session(
|
||||
request: Request,
|
||||
body: DesktopSessionRequest,
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
if not body.redirect_uri.startswith("http://127.0.0.1:"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI")
|
||||
if not config.GOOGLE_DESKTOP_CLIENT_ID:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Desktop OAuth is not configured",
|
||||
)
|
||||
|
||||
token_payload = {
|
||||
"client_id": config.GOOGLE_DESKTOP_CLIENT_ID,
|
||||
"code": body.code,
|
||||
"code_verifier": body.code_verifier,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": body.redirect_uri,
|
||||
}
|
||||
if config.GOOGLE_DESKTOP_CLIENT_SECRET:
|
||||
token_payload["client_secret"] = config.GOOGLE_DESKTOP_CLIENT_SECRET
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
token_response = await client.post("https://oauth2.googleapis.com/token", data=token_payload)
|
||||
if token_response.status_code >= 400:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OAuth exchange failed")
|
||||
token_data = token_response.json()
|
||||
|
||||
id_token = token_data.get("id_token")
|
||||
access_token = token_data.get("access_token")
|
||||
if not id_token or not access_token:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OAuth exchange failed")
|
||||
|
||||
try:
|
||||
claims = google_id_token.verify_oauth2_token(
|
||||
id_token,
|
||||
google_requests.Request(),
|
||||
config.GOOGLE_DESKTOP_CLIENT_ID,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Desktop Google id_token verification failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid Google identity token",
|
||||
) from exc
|
||||
|
||||
if not claims.get("sub") or not claims.get("email"):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
|
||||
|
||||
user = await user_manager.oauth_callback(
|
||||
"google",
|
||||
access_token,
|
||||
claims["sub"],
|
||||
claims["email"],
|
||||
expires_at=(
|
||||
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
|
||||
if token_data.get("expires_in")
|
||||
else None
|
||||
),
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
request=request,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
)
|
||||
app_access_token = await get_jwt_strategy().write_token(user)
|
||||
app_refresh_token = await create_refresh_token(user.id)
|
||||
return RefreshTokenResponse(
|
||||
access_token=app_access_token,
|
||||
refresh_token=app_refresh_token,
|
||||
access_expires_at=_access_expires_at(app_access_token),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,21 +6,22 @@ from pydantic import BaseModel
|
|||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request body for token refresh endpoint."""
|
||||
|
||||
refresh_token: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""Response from token refresh endpoint."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
refresh_token: str | None = None
|
||||
token_type: str = "bearer"
|
||||
access_expires_at: int
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request body for logout endpoint (current device)."""
|
||||
|
||||
refresh_token: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
|
|
@ -33,3 +34,14 @@ class LogoutAllResponse(BaseModel):
|
|||
"""Response from logout all devices endpoint."""
|
||||
|
||||
detail: str = "Successfully logged out from all devices"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
authenticated: bool = True
|
||||
access_expires_at: int
|
||||
|
||||
|
||||
class DesktopSessionRequest(BaseModel):
|
||||
code: str
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue