mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
fix(users):expose current user session routes
This commit is contained in:
parent
a547cfe3c3
commit
2b6bf504ec
3 changed files with 90 additions and 43 deletions
|
|
@ -28,6 +28,7 @@ from app.agents.chat.runtime.checkpointer import (
|
|||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.csrf import CsrfOriginMiddleware
|
||||
from app.config import (
|
||||
config,
|
||||
initialize_image_gen_router,
|
||||
|
|
@ -54,7 +55,10 @@ from app.observability.bootstrap import init_otel, shutdown_otel
|
|||
from app.rate_limiter import get_real_client_ip, limiter
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.routes.auth_routes import session_router
|
||||
from app.routes.users_routes import router as users_router
|
||||
from app.routes.zero_context_routes import router as zero_context_router
|
||||
from app.schemas import UserCreate, UserRead
|
||||
from app.session_events import register_session_hooks
|
||||
from app.users import SECRET, allow_any_principal, auth_backend, fastapi_users
|
||||
from app.utils.perf import log_system_snapshot
|
||||
|
|
@ -817,6 +821,7 @@ app.add_middleware(
|
|||
# FRONTEND_URL to BACKEND_URL.
|
||||
max_age=86400,
|
||||
)
|
||||
app.add_middleware(CsrfOriginMiddleware)
|
||||
|
||||
# Password / email-based auth routers are only mounted when not running in
|
||||
# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
|
||||
|
|
@ -855,16 +860,14 @@ if config.AUTH_TYPE != "GOOGLE":
|
|||
tags=["auth"],
|
||||
)
|
||||
|
||||
# /users/me (read/update profile) is needed in every auth mode, so it stays
|
||||
# mounted unconditionally.
|
||||
app.include_router(
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
# /users/me uses the unified auth resolver so web cookie sessions, desktop bearer
|
||||
# sessions, and PAT principals all resolve through the same authority.
|
||||
app.include_router(users_router)
|
||||
|
||||
# Include custom auth routes (refresh token, logout)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(session_router)
|
||||
app.include_router(zero_context_router)
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
|
|
|||
27
surfsense_backend/app/routes/users_routes.py
Normal file
27
surfsense_backend/app/routes/users_routes.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Cookie-aware user profile routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.schemas import UserRead, UserUpdate
|
||||
from app.users import UserManager, get_auth_context, get_user_manager, require_session_context
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserRead)
|
||||
async def get_current_user_profile(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
return auth.user
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserRead)
|
||||
async def update_current_user_profile(
|
||||
update: UserUpdate,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
updated_user = await user_manager.update(update, auth.user, safe=True, request=request)
|
||||
return updated_user
|
||||
|
|
@ -17,6 +17,7 @@ from sqlalchemy import update
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.session_cookies import write_session
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Prompt,
|
||||
|
|
@ -40,6 +41,7 @@ class BearerResponse(BaseModel):
|
|||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
access_expires_at: int
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
|
@ -263,10 +265,12 @@ class CustomBearerTransport(BearerTransport):
|
|||
import jwt
|
||||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
access_expires_at = 0
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
access_expires_at = int(payload["exp"])
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
|
|
@ -278,17 +282,28 @@ class CustomBearerTransport(BearerTransport):
|
|||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
access_expires_at=access_expires_at,
|
||||
)
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
redirect_url = (
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback"
|
||||
f"?token={bearer_response.access_token}"
|
||||
f"&refresh_token={bearer_response.refresh_token}"
|
||||
response = RedirectResponse(
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback",
|
||||
status_code=302,
|
||||
)
|
||||
return RedirectResponse(redirect_url, status_code=302)
|
||||
write_session(
|
||||
response,
|
||||
bearer_response.access_token,
|
||||
bearer_response.refresh_token,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
return JSONResponse(bearer_response.model_dump())
|
||||
response = JSONResponse(bearer_response.model_dump())
|
||||
write_session(
|
||||
response,
|
||||
bearer_response.access_token,
|
||||
bearer_response.refresh_token,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
|
@ -315,38 +330,42 @@ async def get_auth_context(
|
|||
receives the full SurfSense principal instead of a bare User.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
if auth_header:
|
||||
scheme, _, credential = auth_header.partition(" ")
|
||||
is_bearer = scheme.lower() == "bearer" and bool(credential)
|
||||
token = credential if is_bearer else auth_header.strip()
|
||||
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
if token.startswith(PAT_PREFIX):
|
||||
pat = await resolve_pat(session, token)
|
||||
if pat and pat.user and pat.user.is_active:
|
||||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
|
||||
if token.startswith(PAT_PREFIX):
|
||||
pat = await resolve_pat(session, token)
|
||||
if pat and pat.user and pat.user.is_active:
|
||||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
if is_bearer:
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read bearer access token")
|
||||
user = None
|
||||
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read access token")
|
||||
user = None
|
||||
if user and user.is_active:
|
||||
return AuthContext.session(user)
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read session cookie access token")
|
||||
user = None
|
||||
|
||||
return AuthContext.session(user)
|
||||
if user and user.is_active:
|
||||
return AuthContext.session(user)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
|
||||
async def allow_any_principal(
|
||||
|
|
@ -372,5 +391,3 @@ async def require_session_context(
|
|||
)
|
||||
return auth
|
||||
|
||||
|
||||
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue