SurfSense/surfsense_backend/tests/unit/utils/test_oauth_security.py

94 lines
2.6 KiB
Python
Raw Normal View History

import base64
import hashlib
import hmac
import json
import time
from uuid import uuid4
import pytest
from fastapi import HTTPException
from app.utils.oauth_security import OAuthStateManager
SECRET = "unit-test-secret"
def _encode_state(payload: dict, *, signature: str | None = None) -> str:
"""Build an OAuth state payload compatible with OAuthStateManager."""
signature_payload = payload.copy()
payload_str = json.dumps(signature_payload, sort_keys=True)
computed_signature = hmac.new(
SECRET.encode(),
payload_str.encode(),
hashlib.sha256,
).hexdigest()
encoded_payload = {
**signature_payload,
"signature": signature if signature is not None else computed_signature,
}
return base64.urlsafe_b64encode(json.dumps(encoded_payload).encode()).decode()
def test_validate_state_accepts_fresh_signed_state():
mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
user_id = uuid4()
state = mgr.generate_secure_state(
space_id=1,
user_id=user_id,
toolkit_id="googledrive",
)
decoded = mgr.validate_state(state)
assert decoded["space_id"] == 1
assert decoded["user_id"] == str(user_id)
assert decoded["toolkit_id"] == "googledrive"
def test_validate_state_rejects_expired_state():
mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
expired_state = _encode_state(
{
"space_id": 1,
"user_id": str(uuid4()),
"timestamp": int(time.time()) - 3600,
"toolkit_id": "googledrive",
}
)
with pytest.raises(HTTPException) as exc:
mgr.validate_state(expired_state)
assert exc.value.status_code == 400
assert "expired" in exc.value.detail.lower()
def test_validate_state_rejects_tampered_signature():
mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
tampered_state = _encode_state(
{
"space_id": 1,
"user_id": str(uuid4()),
"timestamp": int(time.time()),
"toolkit_id": "googledrive",
},
signature="deadbeef" * 8,
)
with pytest.raises(HTTPException) as exc:
mgr.validate_state(tampered_state)
assert exc.value.status_code == 400
assert "tampering" in exc.value.detail.lower()
def test_validate_state_rejects_malformed_state():
mgr = OAuthStateManager(secret_key=SECRET)
with pytest.raises(HTTPException) as exc:
mgr.validate_state("not-base64-and-not-json")
assert exc.value.status_code == 400
assert "invalid state format" in exc.value.detail.lower()