mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
feat(oauth): migrate Google OAuth account IDs to use 'sub' and enhance user resolution logic
This commit is contained in:
parent
eb76c02d43
commit
e5aded5a65
3 changed files with 298 additions and 78 deletions
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""migrate Google OAuth account IDs to sub
|
||||||
|
|
||||||
|
Revision ID: 169
|
||||||
|
Revises: 168
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "169"
|
||||||
|
down_revision: str | None = "168"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
UPDATE oauth_account AS legacy
|
||||||
|
SET account_id = regexp_replace(legacy.account_id, '^people/', '')
|
||||||
|
WHERE legacy.oauth_name = 'google'
|
||||||
|
AND legacy.account_id LIKE 'people/%'
|
||||||
|
AND NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM oauth_account AS canonical
|
||||||
|
WHERE canonical.oauth_name = 'google'
|
||||||
|
AND canonical.account_id = regexp_replace(legacy.account_id, '^people/', '')
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
UPDATE oauth_account AS canonical
|
||||||
|
SET account_id = 'people/' || canonical.account_id
|
||||||
|
WHERE canonical.oauth_name = 'google'
|
||||||
|
AND canonical.account_id NOT LIKE 'people/%'
|
||||||
|
AND NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM oauth_account AS legacy
|
||||||
|
WHERE legacy.oauth_name = 'google'
|
||||||
|
AND legacy.account_id = 'people/' || canonical.account_id
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
@ -10,7 +10,7 @@ from datetime import UTC, datetime
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
@ -54,8 +54,11 @@ from app.observability import metrics as ot_metrics
|
||||||
from app.observability.bootstrap import init_otel, shutdown_otel
|
from app.observability.bootstrap import init_otel, shutdown_otel
|
||||||
from app.rate_limiter import get_real_client_ip, limiter
|
from app.rate_limiter import get_real_client_ip, limiter
|
||||||
from app.routes import router as crud_router
|
from app.routes import router as crud_router
|
||||||
from app.routes.auth_routes import router as auth_router
|
from app.routes.auth_routes import (
|
||||||
from app.routes.auth_routes import session_router
|
resolve_google_user,
|
||||||
|
router as auth_router,
|
||||||
|
session_router,
|
||||||
|
)
|
||||||
from app.routes.users_routes import router as users_router
|
from app.routes.users_routes import router as users_router
|
||||||
from app.routes.zero_context_routes import router as zero_context_router
|
from app.routes.zero_context_routes import router as zero_context_router
|
||||||
from app.schemas import UserCreate, UserRead
|
from app.schemas import UserCreate, UserRead
|
||||||
|
|
@ -893,36 +896,183 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
parsed_url = urlparse(config.BACKEND_URL)
|
parsed_url = urlparse(config.BACKEND_URL)
|
||||||
csrf_cookie_domain = parsed_url.hostname
|
csrf_cookie_domain = parsed_url.hostname
|
||||||
|
|
||||||
app.include_router(
|
from fastapi_users.jwt import decode_jwt
|
||||||
fastapi_users.get_oauth_router(
|
from fastapi_users.router.oauth import (
|
||||||
google_oauth_client,
|
CSRF_TOKEN_COOKIE_NAME,
|
||||||
auth_backend,
|
CSRF_TOKEN_KEY,
|
||||||
SECRET,
|
STATE_TOKEN_AUDIENCE,
|
||||||
is_verified_by_default=True,
|
generate_state_token,
|
||||||
csrf_token_cookie_secure=is_secure_context,
|
)
|
||||||
csrf_token_cookie_samesite=csrf_cookie_samesite,
|
from google.auth.transport import requests as google_requests
|
||||||
csrf_token_cookie_httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
from google.oauth2 import id_token as google_id_token
|
||||||
|
|
||||||
|
from app.users import get_user_manager
|
||||||
|
|
||||||
|
def _google_callback_url(request: Request) -> str:
|
||||||
|
if config.BACKEND_URL:
|
||||||
|
return f"{config.BACKEND_URL}/auth/google/callback"
|
||||||
|
return str(request.url_for("google_oauth_callback"))
|
||||||
|
|
||||||
|
def _set_google_oauth_csrf_cookie(response: Response, csrf_token: str) -> None:
|
||||||
|
response.set_cookie(
|
||||||
|
key=CSRF_TOKEN_COOKIE_NAME,
|
||||||
|
value=csrf_token,
|
||||||
|
max_age=3600,
|
||||||
|
path="/",
|
||||||
|
domain=csrf_cookie_domain,
|
||||||
|
secure=is_secure_context,
|
||||||
|
httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
||||||
|
samesite=csrf_cookie_samesite,
|
||||||
)
|
)
|
||||||
if not config.BACKEND_URL
|
|
||||||
else fastapi_users.get_oauth_router(
|
async def _google_authorization_url(request: Request, response: Response) -> str:
|
||||||
google_oauth_client,
|
import secrets
|
||||||
auth_backend,
|
|
||||||
|
csrf_token = secrets.token_urlsafe(32)
|
||||||
|
state = generate_state_token(
|
||||||
|
{CSRF_TOKEN_KEY: csrf_token},
|
||||||
SECRET,
|
SECRET,
|
||||||
is_verified_by_default=True,
|
lifetime_seconds=3600,
|
||||||
redirect_url=f"{config.BACKEND_URL}/auth/google/callback",
|
)
|
||||||
csrf_token_cookie_secure=is_secure_context,
|
authorization_url = await google_oauth_client.get_authorization_url(
|
||||||
csrf_token_cookie_samesite=csrf_cookie_samesite,
|
_google_callback_url(request),
|
||||||
csrf_token_cookie_httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
state,
|
||||||
csrf_token_cookie_domain=csrf_cookie_domain, # Explicitly set cookie domain
|
scope=["openid", "email", "profile"],
|
||||||
),
|
)
|
||||||
prefix="/auth/google",
|
_set_google_oauth_csrf_cookie(response, csrf_token)
|
||||||
|
return authorization_url
|
||||||
|
|
||||||
|
@app.get(
|
||||||
|
"/auth/google/authorize",
|
||||||
tags=["auth"],
|
tags=["auth"],
|
||||||
# REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
|
|
||||||
# it blocks BOTH new OAuth signups AND login of existing OAuth users
|
|
||||||
# (the fastapi-users OAuth router shares one callback for create+login,
|
|
||||||
# so this dependency closes both paths together).
|
|
||||||
dependencies=[Depends(registration_allowed)],
|
dependencies=[Depends(registration_allowed)],
|
||||||
)
|
)
|
||||||
|
async def google_authorize(request: Request, response: Response):
|
||||||
|
"""Return Google's authorization URL, matching fastapi-users' shape."""
|
||||||
|
return {
|
||||||
|
"authorization_url": await _google_authorization_url(request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get(
|
||||||
|
"/auth/google/callback",
|
||||||
|
name="google_oauth_callback",
|
||||||
|
tags=["auth"],
|
||||||
|
dependencies=[Depends(registration_allowed)],
|
||||||
|
)
|
||||||
|
async def google_oauth_callback(
|
||||||
|
request: Request,
|
||||||
|
user_manager=Depends(get_user_manager),
|
||||||
|
):
|
||||||
|
"""Handle web Google OAuth with the same verified-email policy as desktop."""
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
state = request.query_params.get("state")
|
||||||
|
code = request.query_params.get("code")
|
||||||
|
if not state or not code:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAuth callback missing code or state",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
state_data = decode_jwt(state, SECRET, [STATE_TOKEN_AUDIENCE])
|
||||||
|
except pyjwt.DecodeError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="ACCESS_TOKEN_DECODE_ERROR",
|
||||||
|
) from exc
|
||||||
|
except pyjwt.ExpiredSignatureError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
cookie_csrf_token = request.cookies.get(CSRF_TOKEN_COOKIE_NAME)
|
||||||
|
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||||
|
if (
|
||||||
|
not cookie_csrf_token
|
||||||
|
or not state_csrf_token
|
||||||
|
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAUTH_INVALID_STATE",
|
||||||
|
)
|
||||||
|
|
||||||
|
token_payload = {
|
||||||
|
"client_id": config.GOOGLE_OAUTH_CLIENT_ID,
|
||||||
|
"client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": _google_callback_url(request),
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
data=token_payload,
|
||||||
|
)
|
||||||
|
if token_response.status_code >= 400:
|
||||||
|
_error_logger.warning("Web Google OAuth exchange failed: %s", token_response.text)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="OAuth exchange failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
token_data = token_response.json()
|
||||||
|
google_access_token = token_data.get("access_token")
|
||||||
|
google_id_token_value = token_data.get("id_token")
|
||||||
|
if not google_access_token or not google_id_token_value:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="OAuth exchange failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = google_id_token.verify_oauth2_token(
|
||||||
|
google_id_token_value,
|
||||||
|
google_requests.Request(),
|
||||||
|
config.GOOGLE_OAUTH_CLIENT_ID,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
_error_logger.warning("Web Google id_token verification failed: %s", exc)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid Google identity token",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
expires_at = (
|
||||||
|
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
|
||||||
|
if token_data.get("expires_in")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
user = await resolve_google_user(
|
||||||
|
user_manager=user_manager,
|
||||||
|
request=request,
|
||||||
|
google_access_token=google_access_token,
|
||||||
|
claims=claims,
|
||||||
|
expires_at=expires_at,
|
||||||
|
google_refresh_token=token_data.get("refresh_token"),
|
||||||
|
)
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="LOGIN_BAD_CREDENTIALS",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await auth_backend.login(auth_backend.get_strategy(), user)
|
||||||
|
await user_manager.on_after_login(user, request, response)
|
||||||
|
response.delete_cookie(
|
||||||
|
key=CSRF_TOKEN_COOKIE_NAME,
|
||||||
|
path="/",
|
||||||
|
domain=csrf_cookie_domain,
|
||||||
|
secure=is_secure_context,
|
||||||
|
samesite=csrf_cookie_samesite,
|
||||||
|
httponly=False,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
|
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
|
||||||
# This endpoint performs a server-side redirect instead of returning JSON
|
# This endpoint performs a server-side redirect instead of returning JSON
|
||||||
|
|
@ -947,43 +1097,9 @@ if config.AUTH_TYPE == "GOOGLE":
|
||||||
This fixes CSRF cookie issues in Firefox and Safari where cookies set
|
This fixes CSRF cookie issues in Firefox and Safari where cookies set
|
||||||
via cross-origin fetch requests are not sent on subsequent redirects.
|
via cross-origin fetch requests are not sent on subsequent redirects.
|
||||||
"""
|
"""
|
||||||
import secrets
|
response = RedirectResponse(url="", status_code=302)
|
||||||
|
authorization_url = await _google_authorization_url(request, response)
|
||||||
from fastapi_users.router.oauth import generate_state_token
|
response.headers["location"] = authorization_url
|
||||||
|
|
||||||
# Generate CSRF token
|
|
||||||
csrf_token = secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
# Build state token
|
|
||||||
state_data = {"csrftoken": csrf_token}
|
|
||||||
state = generate_state_token(state_data, SECRET, lifetime_seconds=3600)
|
|
||||||
|
|
||||||
# Get the callback URL
|
|
||||||
if config.BACKEND_URL:
|
|
||||||
redirect_url = f"{config.BACKEND_URL}/auth/google/callback"
|
|
||||||
else:
|
|
||||||
redirect_url = str(request.url_for("oauth:google.jwt.callback"))
|
|
||||||
|
|
||||||
# Get authorization URL from Google
|
|
||||||
authorization_url = await google_oauth_client.get_authorization_url(
|
|
||||||
redirect_url,
|
|
||||||
state,
|
|
||||||
scope=["openid", "email", "profile"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create redirect response and set CSRF cookie
|
|
||||||
response = RedirectResponse(url=authorization_url, status_code=302)
|
|
||||||
response.set_cookie(
|
|
||||||
key="fastapiusersoauthcsrf",
|
|
||||||
value=csrf_token,
|
|
||||||
max_age=3600,
|
|
||||||
path="/",
|
|
||||||
domain=csrf_cookie_domain,
|
|
||||||
secure=is_secure_context,
|
|
||||||
httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
|
||||||
samesite=csrf_cookie_samesite,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
from fastapi_users import exceptions as fastapi_users_exceptions
|
||||||
from google.auth.transport import requests as google_requests
|
from google.auth.transport import requests as google_requests
|
||||||
from google.oauth2 import id_token as google_id_token
|
from google.oauth2 import id_token as google_id_token
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -58,6 +59,68 @@ async def _load_user(user_id) -> User | None:
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_google_user(
|
||||||
|
*,
|
||||||
|
user_manager: UserManager,
|
||||||
|
request: Request,
|
||||||
|
google_access_token: str,
|
||||||
|
claims: dict,
|
||||||
|
expires_at: int | None = None,
|
||||||
|
google_refresh_token: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
"""Resolve a Google identity with one policy for web and desktop OAuth.
|
||||||
|
|
||||||
|
Email-based account linking is only allowed when Google asserts that the
|
||||||
|
email is verified. Existing OAuth accounts continue to resolve by provider
|
||||||
|
account id regardless of the current email claim.
|
||||||
|
"""
|
||||||
|
if not claims.get("sub") or not claims.get("email"):
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
|
||||||
|
|
||||||
|
sub = claims["sub"]
|
||||||
|
email_verified = bool(claims.get("email_verified"))
|
||||||
|
|
||||||
|
canonical_user = await user_manager.user_db.get_by_oauth_account("google", sub)
|
||||||
|
if canonical_user is None:
|
||||||
|
legacy_account_id = f"people/{sub}"
|
||||||
|
legacy_user = await user_manager.user_db.get_by_oauth_account(
|
||||||
|
"google", legacy_account_id
|
||||||
|
)
|
||||||
|
if legacy_user is not None:
|
||||||
|
# Fallback for pre-sub Google OAuth rows created by the old web flow.
|
||||||
|
# TODO: Remove after oauth_account is fully backfilled to bare Google
|
||||||
|
# sub and production has zero google rows with account_id LIKE 'people/%'.
|
||||||
|
for oauth_account in legacy_user.oauth_accounts:
|
||||||
|
if (
|
||||||
|
oauth_account.oauth_name == "google"
|
||||||
|
and oauth_account.account_id == legacy_account_id
|
||||||
|
):
|
||||||
|
await user_manager.user_db.update_oauth_account(
|
||||||
|
legacy_user,
|
||||||
|
oauth_account,
|
||||||
|
{"account_id": sub},
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await user_manager.oauth_callback(
|
||||||
|
"google",
|
||||||
|
google_access_token,
|
||||||
|
sub,
|
||||||
|
claims["email"],
|
||||||
|
expires_at=expires_at,
|
||||||
|
refresh_token=google_refresh_token,
|
||||||
|
request=request,
|
||||||
|
associate_by_email=email_verified,
|
||||||
|
is_verified_by_default=email_verified,
|
||||||
|
)
|
||||||
|
except fastapi_users_exceptions.UserAlreadyExists as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAUTH_USER_ALREADY_EXISTS",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=None)
|
@router.post("/refresh", response_model=None)
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def refresh_access_token(
|
async def refresh_access_token(
|
||||||
|
|
@ -276,24 +339,17 @@ async def create_desktop_session(
|
||||||
detail="Invalid Google identity token",
|
detail="Invalid Google identity token",
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
if not claims.get("sub") or not claims.get("email"):
|
user = await resolve_google_user(
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Google identity token")
|
user_manager=user_manager,
|
||||||
|
request=request,
|
||||||
email_verified = bool(claims.get("email_verified"))
|
google_access_token=access_token,
|
||||||
user = await user_manager.oauth_callback(
|
claims=claims,
|
||||||
"google",
|
|
||||||
access_token,
|
|
||||||
claims["sub"],
|
|
||||||
claims["email"],
|
|
||||||
expires_at=(
|
expires_at=(
|
||||||
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
|
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
|
||||||
if token_data.get("expires_in")
|
if token_data.get("expires_in")
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
refresh_token=token_data.get("refresh_token"),
|
google_refresh_token=token_data.get("refresh_token"),
|
||||||
request=request,
|
|
||||||
associate_by_email=email_verified,
|
|
||||||
is_verified_by_default=email_verified,
|
|
||||||
)
|
)
|
||||||
app_access_token = await get_jwt_strategy().write_token(user)
|
app_access_token = await get_jwt_strategy().write_token(user)
|
||||||
app_refresh_token = await create_refresh_token(user.id)
|
app_refresh_token = await create_refresh_token(user.id)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue