mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
feat: implement PKCE support in native Google OAuth flows
- Added `generate_code_verifier` function to create a PKCE code verifier for enhanced security. - Updated Google Calendar, Drive, and Gmail connector routes to utilize the PKCE code verifier during OAuth authorization. - Modified state management to include the code verifier for secure state generation and validation.
This commit is contained in:
parent
09008c8f1a
commit
8e6b1c77ea
4 changed files with 55 additions and 15 deletions
|
|
@ -28,7 +28,7 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -96,9 +96,14 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -146,8 +151,11 @@ async def reauth_calendar(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -225,6 +233,7 @@ async def calendar_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_CALENDAR_REDIRECT_URI:
|
||||
|
|
@ -233,6 +242,7 @@ async def calendar_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier
|
||||
|
||||
# Relax token scope validation for Google OAuth
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
|
@ -127,14 +127,19 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
# Generate authorization URL
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline", # Get refresh token
|
||||
prompt="consent", # Force consent screen to get refresh token
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
|
@ -193,8 +198,11 @@ async def reauth_drive(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -285,6 +293,7 @@ async def drive_callback(
|
|||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
logger.info(
|
||||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
|
|
@ -296,8 +305,9 @@ async def drive_callback(
|
|||
status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Exchange authorization code for tokens (restore PKCE code_verifier from state)
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -109,9 +109,14 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -164,8 +169,11 @@ async def reauth_gmail(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -256,6 +264,7 @@ async def gmail_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_GMAIL_REDIRECT_URI:
|
||||
|
|
@ -264,6 +273,7 @@ async def gmail_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import hmac
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from random import SystemRandom
|
||||
from string import ascii_letters, digits
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
|
@ -18,6 +20,14 @@ from fastapi import HTTPException
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PKCE_CHARS = ascii_letters + digits + "-._~"
|
||||
_PKCE_RNG = SystemRandom()
|
||||
|
||||
|
||||
def generate_code_verifier(length: int = 128) -> str:
|
||||
"""Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars)."""
|
||||
return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages secure OAuth state parameters with HMAC signatures."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue