mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 01:06:23 +02:00
feat: implement token encryption and state management for OAuth connectors
- Added encryption for sensitive tokens (access token, refresh token, client secret) in Google Calendar, Google Drive, Gmail, Linear, and Notion connectors to enhance security. - Introduced OAuthStateManager for secure state parameter generation and validation, improving the integrity of OAuth flows. - Updated callback routes to handle state validation and error management, ensuring robust handling of authorization processes. - Enhanced indexers to support decryption of tokens for backward compatibility, maintaining functionality with existing encrypted credentials. - Improved validation for date parameters in connector routes to ensure proper input handling.
This commit is contained in:
parent
ec7599362d
commit
45489423d1
17 changed files with 1140 additions and 152 deletions
216
surfsense_backend/app/utils/oauth_security.py
Normal file
216
surfsense_backend/app/utils/oauth_security.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
OAuth Security Utilities.
|
||||
|
||||
Provides secure state parameter generation/validation and token encryption
|
||||
for OAuth 2.0 flows.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages secure OAuth state parameters with HMAC signatures."""
|
||||
|
||||
def __init__(self, secret_key: str, max_age_seconds: int = 600):
|
||||
"""
|
||||
Initialize OAuth state manager.
|
||||
|
||||
Args:
|
||||
secret_key: Secret key for HMAC signing (should be SECRET_KEY from config)
|
||||
max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes)
|
||||
"""
|
||||
if not secret_key:
|
||||
raise ValueError("secret_key is required for OAuth state management")
|
||||
self.secret_key = secret_key
|
||||
self.max_age_seconds = max_age_seconds
|
||||
|
||||
def generate_secure_state(self, space_id: int, user_id: UUID, **extra_fields) -> str:
|
||||
"""
|
||||
Generate cryptographically signed state parameter.
|
||||
|
||||
Args:
|
||||
space_id: The search space ID
|
||||
user_id: The user ID
|
||||
**extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE)
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter with HMAC signature
|
||||
"""
|
||||
timestamp = int(time.time())
|
||||
state_payload = {
|
||||
"space_id": space_id,
|
||||
"user_id": str(user_id),
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
# Add any extra fields (e.g., code_verifier for PKCE)
|
||||
state_payload.update(extra_fields)
|
||||
|
||||
# Create signature
|
||||
payload_str = json.dumps(state_payload, sort_keys=True)
|
||||
signature = hmac.new(
|
||||
self.secret_key.encode(),
|
||||
payload_str.encode(),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
# Include signature in state
|
||||
state_payload["signature"] = signature
|
||||
state_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(state_payload).encode()
|
||||
).decode()
|
||||
|
||||
return state_encoded
|
||||
|
||||
def validate_state(self, state: str) -> dict:
|
||||
"""
|
||||
Validate and decode state parameter with signature verification.
|
||||
|
||||
Args:
|
||||
state: The state parameter from OAuth callback
|
||||
|
||||
Returns:
|
||||
Decoded state data (space_id, user_id, timestamp)
|
||||
|
||||
Raises:
|
||||
HTTPException: If state is invalid, expired, or tampered with
|
||||
"""
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state format: {e!s}"
|
||||
) from e
|
||||
|
||||
# Verify signature exists
|
||||
signature = data.pop("signature", None)
|
||||
if not signature:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state signature"
|
||||
)
|
||||
|
||||
# Verify signature
|
||||
payload_str = json.dumps(data, sort_keys=True)
|
||||
expected_signature = hmac.new(
|
||||
self.secret_key.encode(),
|
||||
payload_str.encode(),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid state signature - possible tampering"
|
||||
)
|
||||
|
||||
# Verify timestamp (prevent replay attacks)
|
||||
timestamp = data.get("timestamp", 0)
|
||||
current_time = time.time()
|
||||
age = current_time - timestamp
|
||||
|
||||
if age < 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid state timestamp"
|
||||
)
|
||||
|
||||
if age > self.max_age_seconds:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="State parameter expired. Please try again.",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class TokenEncryption:
|
||||
"""Encrypt/decrypt sensitive OAuth tokens for storage."""
|
||||
|
||||
def __init__(self, secret_key: str):
|
||||
"""
|
||||
Initialize token encryption.
|
||||
|
||||
Args:
|
||||
secret_key: Secret key for encryption (should be SECRET_KEY from config)
|
||||
"""
|
||||
if not secret_key:
|
||||
raise ValueError("secret_key is required for token encryption")
|
||||
# Derive Fernet key from secret using SHA256
|
||||
# Note: In production, consider using HKDF for key derivation
|
||||
key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(secret_key.encode()).digest()
|
||||
)
|
||||
try:
|
||||
self.cipher = Fernet(key)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to initialize encryption cipher: {e!s}"
|
||||
) from e
|
||||
|
||||
def encrypt_token(self, token: str) -> str:
|
||||
"""
|
||||
Encrypt a token for storage.
|
||||
|
||||
Args:
|
||||
token: Plaintext token to encrypt
|
||||
|
||||
Returns:
|
||||
Encrypted token string
|
||||
"""
|
||||
if not token:
|
||||
return token
|
||||
try:
|
||||
return self.cipher.encrypt(token.encode()).decode()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt token: {e!s}")
|
||||
raise ValueError(f"Token encryption failed: {e!s}") from e
|
||||
|
||||
def decrypt_token(self, encrypted_token: str) -> str:
|
||||
"""
|
||||
Decrypt a stored token.
|
||||
|
||||
Args:
|
||||
encrypted_token: Encrypted token string
|
||||
|
||||
Returns:
|
||||
Decrypted plaintext token
|
||||
"""
|
||||
if not encrypted_token:
|
||||
return encrypted_token
|
||||
try:
|
||||
return self.cipher.decrypt(encrypted_token.encode()).decode()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt token: {e!s}")
|
||||
raise ValueError(f"Token decryption failed: {e!s}") from e
|
||||
|
||||
def is_encrypted(self, token: str) -> bool:
|
||||
"""
|
||||
Check if a token appears to be encrypted.
|
||||
|
||||
Args:
|
||||
token: Token string to check
|
||||
|
||||
Returns:
|
||||
True if token appears encrypted, False otherwise
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
# Encrypted tokens are base64-encoded and have specific format
|
||||
# This is a heuristic check - encrypted tokens are longer and base64-like
|
||||
try:
|
||||
# Try to decode as base64
|
||||
base64.urlsafe_b64decode(token.encode())
|
||||
# If it's base64 and reasonably long, likely encrypted
|
||||
return len(token) > 20
|
||||
except Exception:
|
||||
return False
|
||||
Loading…
Add table
Add a link
Reference in a new issue