mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
fix(auth):enforce session auth cutover
This commit is contained in:
parent
fbecbb98b5
commit
62c7efb216
4 changed files with 214 additions and 87 deletions
|
|
@ -3,6 +3,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
|
|
@ -12,12 +13,12 @@ from fastapi_users.authentication import (
|
|||
JWTStrategy,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from pydantic import BaseModel
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.session_cookies import write_session
|
||||
from app.auth.session_cookies import access_expires_at, write_session
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Prompt,
|
||||
|
|
@ -37,13 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BearerResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
access_expires_at: int
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
||||
|
||||
|
|
@ -232,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
|||
yield UserManager(user_db)
|
||||
|
||||
|
||||
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
|
||||
async def write_token(self, user: models.UP) -> str:
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"iat": int(datetime.now(UTC).timestamp()),
|
||||
}
|
||||
return generate_jwt(
|
||||
data,
|
||||
self.encode_key,
|
||||
self.lifetime_seconds,
|
||||
algorithm=self.algorithm,
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
return JWTStrategy(
|
||||
return IatJWTStrategy(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
)
|
||||
|
|
@ -262,48 +271,34 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
|||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
import jwt
|
||||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
access_expires_at = 0
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
access_expires_at = int(payload["exp"])
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create refresh token: {e}")
|
||||
# Fall back to response without refresh token
|
||||
refresh_token = ""
|
||||
|
||||
bearer_response = BearerResponse(
|
||||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
access_expires_at=access_expires_at,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create session",
|
||||
) from e
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
response = RedirectResponse(
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback",
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard",
|
||||
status_code=302,
|
||||
)
|
||||
write_session(
|
||||
response,
|
||||
bearer_response.access_token,
|
||||
bearer_response.refresh_token,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
response = JSONResponse(bearer_response.model_dump())
|
||||
write_session(
|
||||
response,
|
||||
bearer_response.access_token,
|
||||
bearer_response.refresh_token,
|
||||
response = JSONResponse(
|
||||
{
|
||||
"authenticated": True,
|
||||
"access_expires_at": access_expires_at(token),
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
write_session(response, token, refresh_token)
|
||||
return response
|
||||
|
||||
|
||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
|
@ -318,6 +313,22 @@ auth_backend = AuthenticationBackend(
|
|||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
|
||||
def _token_meets_epoch(token: str) -> bool:
|
||||
min_issued_at = config.MIN_ISSUED_AT
|
||||
if min_issued_at <= 0:
|
||||
return True
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
return False
|
||||
|
||||
issued_at = payload.get("iat")
|
||||
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
|
||||
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -341,7 +352,7 @@ async def get_auth_context(
|
|||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
|
||||
if is_bearer:
|
||||
if is_bearer and _token_meets_epoch(token):
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
|
|
@ -352,7 +363,7 @@ async def get_auth_context(
|
|||
return AuthContext.session(user)
|
||||
|
||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
if cookie_token and _token_meets_epoch(cookie_token):
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
|
||||
except Exception:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue