mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 16:56:22 +02:00
- Removed the `generate_pkce_pair` function from `airtable_add_connector_route.py` and relocated it to `oauth_security.py` for better organization. - Updated imports in `airtable_add_connector_route.py` to reflect the new location of the PKCE generation function.
231 lines
7.1 KiB
Python
231 lines
7.1 KiB
Python
"""
|
|
OAuth Security Utilities.
|
|
|
|
Provides secure state parameter generation/validation and token encryption
|
|
for OAuth 2.0 flows.
|
|
"""
|
|
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import logging
|
|
import time
|
|
from random import SystemRandom
|
|
from string import ascii_letters, digits
|
|
from uuid import UUID
|
|
|
|
from cryptography.fernet import Fernet
|
|
from fastapi import HTTPException
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Alphabet for PKCE code verifiers: ASCII letters, digits and "-", ".", "_",
# "~" -- the "unreserved" characters RFC 7636 section 4.1 allows.
_PKCE_CHARS = ascii_letters + digits + "-._~"
# SystemRandom draws from the operating system's randomness source
# (os.urandom) rather than the seedable default generator.
_PKCE_RNG = SystemRandom()

# Inclusive length bounds for a code_verifier (RFC 7636 section 4.1).
_PKCE_MIN_LENGTH = 43
_PKCE_MAX_LENGTH = 128


def generate_code_verifier(length: int = 128) -> str:
    """Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars).

    Args:
        length: Number of characters to generate. Must lie within the
            43-128 range that RFC 7636 permits.

    Returns:
        A random string of ``length`` characters drawn from ``_PKCE_CHARS``.

    Raises:
        ValueError: If ``length`` is outside the 43-128 range. Without this
            check a too-small value (e.g. 0) would silently yield a verifier
            that is weak and that a compliant server must reject anyway.
    """
    if not _PKCE_MIN_LENGTH <= length <= _PKCE_MAX_LENGTH:
        raise ValueError(
            f"PKCE code_verifier length must be between {_PKCE_MIN_LENGTH} "
            f"and {_PKCE_MAX_LENGTH}, got {length}"
        )
    return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
|
|
|
|
|
|
def generate_pkce_pair(length: int = 128) -> tuple[str, str]:
    """Create a PKCE code_verifier together with its S256 code_challenge.

    Args:
        length: Number of characters for the verifier; passed straight
            through to ``generate_code_verifier``.

    Returns:
        A ``(code_verifier, code_challenge)`` tuple. The challenge is the
        URL-safe base64 encoding of the verifier's SHA-256 digest with the
        trailing ``=`` padding removed.
    """
    code_verifier = generate_code_verifier(length)
    digest = hashlib.sha256(code_verifier.encode()).digest()
    encoded_digest = base64.urlsafe_b64encode(digest).decode()
    # The challenge is sent without base64 padding.
    code_challenge = encoded_digest.rstrip("=")
    return code_verifier, code_challenge
|
|
|
|
|
|
class OAuthStateManager:
    """Manages secure OAuth state parameters with HMAC signatures.

    A state value is a URL-safe base64 encoding of a JSON object holding the
    payload fields plus a ``signature`` key: the hex HMAC-SHA256 of the
    payload (serialized with sorted keys) under ``secret_key``.
    """

    def __init__(self, secret_key: str, max_age_seconds: int = 600):
        """
        Initialize OAuth state manager.

        Args:
            secret_key: Secret key for HMAC signing (should be SECRET_KEY from config)
            max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes)

        Raises:
            ValueError: If ``secret_key`` is empty.
        """
        if not secret_key:
            raise ValueError("secret_key is required for OAuth state management")
        self.secret_key = secret_key
        self.max_age_seconds = max_age_seconds

    def _sign(self, payload: dict) -> str:
        """Return the hex HMAC-SHA256 of ``payload`` serialized with sorted keys."""
        payload_str = json.dumps(payload, sort_keys=True)
        return hmac.new(
            self.secret_key.encode(),
            payload_str.encode(),
            hashlib.sha256,
        ).hexdigest()

    def generate_secure_state(
        self, space_id: int, user_id: UUID, **extra_fields
    ) -> str:
        """
        Generate cryptographically signed state parameter.

        Args:
            space_id: The search space ID
            user_id: The user ID
            **extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE).
                They are merged after the built-in fields, so a key named
                ``space_id``, ``user_id`` or ``timestamp`` overrides the
                built-in value. Values must be JSON-serializable.

        Returns:
            Base64-encoded state parameter with HMAC signature
        """
        state_payload = {
            "space_id": space_id,
            "user_id": str(user_id),
            "timestamp": int(time.time()),
        }

        # Add any extra fields (e.g., code_verifier for PKCE)
        state_payload.update(extra_fields)

        # Sign the payload first, then embed the signature alongside it.
        state_payload["signature"] = self._sign(state_payload)

        return base64.urlsafe_b64encode(
            json.dumps(state_payload).encode()
        ).decode()

    def validate_state(self, state: str) -> dict:
        """
        Validate and decode state parameter with signature verification.

        Args:
            state: The state parameter from OAuth callback

        Returns:
            Decoded state data with the ``signature`` key removed
            (space_id, user_id, timestamp, plus any extra fields that were
            supplied when the state was generated)

        Raises:
            HTTPException: If state is invalid, expired, or tampered with
        """
        try:
            decoded = base64.urlsafe_b64decode(state.encode()).decode()
            data = json.loads(decoded)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid state format: {e!s}"
            ) from e

        # The callback value is attacker-controlled: valid JSON that is not
        # an object (list, string, number) must be rejected here instead of
        # crashing on the dict operations below.
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400,
                detail="Invalid state format: expected a JSON object",
            )

        # Verify signature exists
        signature = data.pop("signature", None)
        if not signature:
            raise HTTPException(status_code=400, detail="Missing state signature")

        # Verify signature. hmac.compare_digest raises TypeError for non-str
        # input and for str containing non-ASCII characters, so those shapes
        # are rejected up front; a genuine signature is always ASCII hex.
        expected_signature = self._sign(data)
        signature_ok = (
            isinstance(signature, str)
            and signature.isascii()
            and hmac.compare_digest(signature, expected_signature)
        )
        if not signature_ok:
            raise HTTPException(
                status_code=400, detail="Invalid state signature - possible tampering"
            )

        # Verify timestamp (limits how long a captured state stays usable)
        timestamp = data.get("timestamp", 0)
        # bool is a subclass of int, so it is excluded explicitly.
        if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
            raise HTTPException(status_code=400, detail="Invalid state timestamp")

        age = time.time() - timestamp

        # A timestamp in the future yields a negative age.
        if age < 0:
            raise HTTPException(status_code=400, detail="Invalid state timestamp")

        if age > self.max_age_seconds:
            raise HTTPException(
                status_code=400,
                detail="State parameter expired. Please try again.",
            )

        return data
|
|
|
|
|
|
class TokenEncryption:
    """Encrypt/decrypt sensitive OAuth tokens for storage."""

    def __init__(self, secret_key: str):
        """
        Initialize token encryption.

        Args:
            secret_key: Secret key for encryption (should be SECRET_KEY from config)

        Raises:
            ValueError: If ``secret_key`` is empty or the cipher cannot be built.
        """
        if not secret_key:
            raise ValueError("secret_key is required for token encryption")
        # Derive Fernet key from secret using SHA256: the 32-byte digest,
        # URL-safe base64 encoded, is the key format Fernet expects.
        # Note: In production, consider using HKDF for key derivation
        key = base64.urlsafe_b64encode(hashlib.sha256(secret_key.encode()).digest())
        try:
            self.cipher = Fernet(key)
        except Exception as e:
            raise ValueError(f"Failed to initialize encryption cipher: {e!s}") from e

    def encrypt_token(self, token: str) -> str:
        """
        Encrypt a token for storage.

        Args:
            token: Plaintext token to encrypt

        Returns:
            Encrypted token string. A falsy ``token`` (empty string or None)
            is returned unchanged.

        Raises:
            ValueError: If encryption fails.
        """
        if not token:
            return token
        try:
            return self.cipher.encrypt(token.encode()).decode()
        except Exception as e:
            logger.error(f"Failed to encrypt token: {e!s}")
            raise ValueError(f"Token encryption failed: {e!s}") from e

    def decrypt_token(self, encrypted_token: str) -> str:
        """
        Decrypt a stored token.

        Args:
            encrypted_token: Encrypted token string

        Returns:
            Decrypted plaintext token. A falsy ``encrypted_token`` (empty
            string or None) is returned unchanged.

        Raises:
            ValueError: If decryption fails.
        """
        if not encrypted_token:
            return encrypted_token
        try:
            return self.cipher.decrypt(encrypted_token.encode()).decode()
        except Exception as e:
            logger.error(f"Failed to decrypt token: {e!s}")
            raise ValueError(f"Token decryption failed: {e!s}") from e

    def is_encrypted(self, token: str) -> bool:
        """
        Check if a token appears to be encrypted.

        This is a structural check against the Fernet token layout; it does
        not verify the HMAC, so it cannot tell which key produced the token.

        Args:
            token: Token string to check

        Returns:
            True if token has the shape of a Fernet token, False otherwise
        """
        if not token:
            return False

        # Decoded Fernet layout:
        #   version (1 byte, 0x80) | timestamp (8) | IV (16)
        #   | ciphertext (one or more 16-byte AES blocks) | HMAC (32)
        fernet_version = b"\x80"
        fixed_overhead = 1 + 8 + 16 + 32
        aes_block = 16

        try:
            raw = base64.urlsafe_b64decode(token.encode())
        except (ValueError, TypeError, AttributeError):
            # Not base64 (or not a str at all): cannot be one of our tokens.
            return False

        # "Decodes as base64 and is longer than 20 chars" also matches most
        # plaintext OAuth tokens, so the Fernet framing is checked instead.
        ciphertext_len = len(raw) - fixed_overhead
        return (
            raw[:1] == fernet_version
            and ciphertext_len >= aes_block
            and ciphertext_len % aes_block == 0
        )
|