SurfSense/surfsense_backend/app/utils/oauth_security.py
Anish Sarkar e814540727 refactor: move PKCE pair generation for airtable
- Removed the `generate_pkce_pair` function from `airtable_add_connector_route.py` and relocated it to `oauth_security.py` for better organization.
- Updated imports in `airtable_add_connector_route.py` to reflect the new location of the PKCE generation function.
2026-04-04 03:36:54 +05:30

231 lines
7.1 KiB
Python

"""
OAuth Security Utilities.
Provides secure state parameter generation/validation and token encryption
for OAuth 2.0 flows.
"""
import base64
import hashlib
import hmac
import json
import logging
import time
from random import SystemRandom
from string import ascii_letters, digits
from uuid import UUID
from cryptography.fernet import Fernet
from fastapi import HTTPException
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Alphabet used for PKCE code_verifier characters: ASCII letters, digits and
# the four punctuation characters "-", ".", "_" and "~".
_PKCE_CHARS = ascii_letters + digits + "-._~"
# SystemRandom draws from the operating system's randomness source rather than
# the seedable default generator, which is what a security token needs.
_PKCE_RNG = SystemRandom()
def generate_code_verifier(length: int = 128) -> str:
    """
    Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars).

    Args:
        length: Number of characters to generate; must be between 43 and 128
            inclusive.

    Returns:
        A random string of `length` characters drawn from `_PKCE_CHARS`.

    Raises:
        ValueError: If `length` is outside the 43-128 range required by RFC 7636.
    """
    # Fail fast: a verifier outside this range is not a valid PKCE verifier,
    # and a short one also weakens the protection PKCE is meant to provide.
    if not 43 <= length <= 128:
        raise ValueError(
            f"PKCE code_verifier length must be between 43 and 128, got {length}"
        )
    return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
def generate_pkce_pair(length: int = 128) -> tuple[str, str]:
    """
    Create a PKCE code_verifier together with its S256 code_challenge.

    Args:
        length: Number of characters requested for the code_verifier.

    Returns:
        A `(code_verifier, code_challenge)` tuple, where the challenge is the
        URL-safe base64 encoding of the verifier's SHA-256 digest with the
        trailing "=" padding removed.
    """
    code_verifier = generate_code_verifier(length)
    digest = hashlib.sha256(code_verifier.encode()).digest()
    # Strip the padding on the bytes first, then turn the result into text.
    code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return code_verifier, code_challenge
class OAuthStateManager:
    """Manages secure OAuth state parameters with HMAC signatures."""

    def __init__(self, secret_key: str, max_age_seconds: int = 600):
        """
        Initialize OAuth state manager.

        Args:
            secret_key: Secret key for HMAC signing (should be SECRET_KEY from config)
            max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes)

        Raises:
            ValueError: If secret_key is empty.
        """
        if not secret_key:
            raise ValueError("secret_key is required for OAuth state management")
        self.secret_key = secret_key
        self.max_age_seconds = max_age_seconds

    def _compute_signature(self, payload: dict) -> str:
        """Return the hex HMAC-SHA256 of the payload's key-sorted JSON form."""
        # sort_keys makes the serialization independent of dict ordering, so
        # signing and verification always hash the same bytes.
        payload_str = json.dumps(payload, sort_keys=True)
        return hmac.new(
            self.secret_key.encode(),
            payload_str.encode(),
            hashlib.sha256,
        ).hexdigest()

    def generate_secure_state(
        self, space_id: int, user_id: UUID, **extra_fields
    ) -> str:
        """
        Generate cryptographically signed state parameter.

        Args:
            space_id: The search space ID
            user_id: The user ID
            **extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE)

        Returns:
            Base64-encoded state parameter with HMAC signature

        Raises:
            ValueError: If extra_fields contains the reserved key "signature".
        """
        # An extra "signature" field would be signed and then overwritten by
        # the real signature, producing a state that can never validate.
        if "signature" in extra_fields:
            raise ValueError(
                "'signature' is a reserved state field and cannot be an extra field"
            )

        state_payload = {
            "space_id": space_id,
            "user_id": str(user_id),
            "timestamp": int(time.time()),
        }
        # Add any extra fields (e.g., code_verifier for PKCE)
        state_payload.update(extra_fields)

        # Sign the payload first, then ship the signature alongside it.
        state_payload["signature"] = self._compute_signature(state_payload)
        return base64.urlsafe_b64encode(
            json.dumps(state_payload).encode()
        ).decode()

    def validate_state(self, state: str) -> dict:
        """
        Validate and decode state parameter with signature verification.

        Args:
            state: The state parameter from OAuth callback

        Returns:
            Decoded state data (space_id, user_id, timestamp and any extra
            fields), with the signature removed

        Raises:
            HTTPException: If state is invalid, expired, or tampered with
        """
        try:
            # Restore any trailing "=" padding that was stripped in transit;
            # the signature check below still guards the content.
            padded = state + "=" * (-len(state) % 4)
            decoded = base64.urlsafe_b64decode(padded.encode()).decode()
            data = json.loads(decoded)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid state format: {e!s}"
            ) from e

        # The state is untrusted input: valid JSON that is not an object
        # (e.g. a list or number) must yield a 400, not an unhandled error.
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400,
                detail="Invalid state format: expected a JSON object",
            )

        # Verify signature exists
        signature = data.pop("signature", None)
        if not signature:
            raise HTTPException(status_code=400, detail="Missing state signature")

        # Verify signature in constant time. compare_digest raises TypeError
        # for non-ASCII strings, so treat that (and non-strings) as a mismatch.
        expected_signature = self._compute_signature(data)
        try:
            signature_ok = isinstance(signature, str) and hmac.compare_digest(
                signature, expected_signature
            )
        except TypeError:
            signature_ok = False
        if not signature_ok:
            raise HTTPException(
                status_code=400, detail="Invalid state signature - possible tampering"
            )

        # Verify timestamp (bounds how long a captured state can be replayed)
        timestamp = data.get("timestamp", 0)
        if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
            raise HTTPException(status_code=400, detail="Invalid state timestamp")
        age = time.time() - timestamp
        if age < 0:
            raise HTTPException(status_code=400, detail="Invalid state timestamp")
        if age > self.max_age_seconds:
            raise HTTPException(
                status_code=400,
                detail="State parameter expired. Please try again.",
            )
        return data
class TokenEncryption:
    """Encrypt/decrypt sensitive OAuth tokens for storage."""

    # Fernet token layout (before base64): 1-byte version, 8-byte timestamp,
    # 16-byte IV, ciphertext (a non-empty multiple of 16 bytes), 32-byte HMAC.
    _FERNET_VERSION = 0x80
    _FERNET_OVERHEAD = 1 + 8 + 16 + 32
    _FERNET_BLOCK = 16

    def __init__(self, secret_key: str):
        """
        Initialize token encryption.

        Args:
            secret_key: Secret key for encryption (should be SECRET_KEY from config)

        Raises:
            ValueError: If secret_key is empty or the cipher cannot be created.
        """
        if not secret_key:
            raise ValueError("secret_key is required for token encryption")
        # Derive Fernet key from secret using SHA256
        # Note: In production, consider using HKDF for key derivation
        key = base64.urlsafe_b64encode(hashlib.sha256(secret_key.encode()).digest())
        try:
            self.cipher = Fernet(key)
        except Exception as e:
            raise ValueError(f"Failed to initialize encryption cipher: {e!s}") from e

    def encrypt_token(self, token: str) -> str:
        """
        Encrypt a token for storage.

        Args:
            token: Plaintext token to encrypt

        Returns:
            Encrypted token string; an empty/None token is returned unchanged

        Raises:
            ValueError: If encryption fails.
        """
        if not token:
            return token
        try:
            return self.cipher.encrypt(token.encode()).decode()
        except Exception as e:
            logger.error("Failed to encrypt token: %s", e)
            raise ValueError(f"Token encryption failed: {e!s}") from e

    def decrypt_token(self, encrypted_token: str) -> str:
        """
        Decrypt a stored token.

        Args:
            encrypted_token: Encrypted token string

        Returns:
            Decrypted plaintext token; an empty/None token is returned unchanged

        Raises:
            ValueError: If decryption fails.
        """
        if not encrypted_token:
            return encrypted_token
        try:
            return self.cipher.decrypt(encrypted_token.encode()).decode()
        except Exception as e:
            logger.error("Failed to decrypt token: %s", e)
            raise ValueError(f"Token decryption failed: {e!s}") from e

    def is_encrypted(self, token: str) -> bool:
        """
        Check if a token appears to be encrypted.

        This is a structural check against the Fernet token layout; it does
        not verify the HMAC, so it cannot prove the token was produced with
        this instance's key.

        Args:
            token: Token string to check

        Returns:
            True if token appears encrypted, False otherwise
        """
        if not token:
            return False
        try:
            raw = base64.urlsafe_b64decode(token.encode())
        except (AttributeError, TypeError, ValueError):
            # Not a string, or not decodable as base64.
            return False
        # "Decodes as base64 and is long" matches most plaintext tokens, so
        # check the layout instead: version byte, fixed overhead, and at
        # least one whole ciphertext block.
        ciphertext_len = len(raw) - self._FERNET_OVERHEAD
        return (
            ciphertext_len >= self._FERNET_BLOCK
            and ciphertext_len % self._FERNET_BLOCK == 0
            and raw[0] == self._FERNET_VERSION
        )