diff --git a/surfsense_backend/tests/unit/utils/test_oauth_security.py b/surfsense_backend/tests/unit/utils/test_oauth_security.py new file mode 100644 index 000000000..a00df62ee --- /dev/null +++ b/surfsense_backend/tests/unit/utils/test_oauth_security.py @@ -0,0 +1,93 @@ +import base64 +import hashlib +import hmac +import json +import time +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from app.utils.oauth_security import OAuthStateManager + +SECRET = "unit-test-secret" + + +def _encode_state(payload: dict, *, signature: str | None = None) -> str: + """Build an OAuth state payload compatible with OAuthStateManager.""" + signature_payload = payload.copy() + payload_str = json.dumps(signature_payload, sort_keys=True) + computed_signature = hmac.new( + SECRET.encode(), + payload_str.encode(), + hashlib.sha256, + ).hexdigest() + encoded_payload = { + **signature_payload, + "signature": signature if signature is not None else computed_signature, + } + return base64.urlsafe_b64encode(json.dumps(encoded_payload).encode()).decode() + + +def test_validate_state_accepts_fresh_signed_state(): + mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600) + user_id = uuid4() + + state = mgr.generate_secure_state( + space_id=1, + user_id=user_id, + toolkit_id="googledrive", + ) + + decoded = mgr.validate_state(state) + + assert decoded["space_id"] == 1 + assert decoded["user_id"] == str(user_id) + assert decoded["toolkit_id"] == "googledrive" + + +def test_validate_state_rejects_expired_state(): + mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600) + expired_state = _encode_state( + { + "space_id": 1, + "user_id": str(uuid4()), + "timestamp": int(time.time()) - 3600, + "toolkit_id": "googledrive", + } + ) + + with pytest.raises(HTTPException) as exc: + mgr.validate_state(expired_state) + + assert exc.value.status_code == 400 + assert "expired" in exc.value.detail.lower() + + +def test_validate_state_rejects_tampered_signature(): + mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600) + tampered_state = _encode_state( + { + "space_id": 1, + "user_id": str(uuid4()), + "timestamp": int(time.time()), + "toolkit_id": "googledrive", + }, + signature="deadbeef" * 8, + ) + + with pytest.raises(HTTPException) as exc: + mgr.validate_state(tampered_state) + + assert exc.value.status_code == 400 + assert "tampering" in exc.value.detail.lower() + + +def test_validate_state_rejects_malformed_state(): + mgr = OAuthStateManager(secret_key=SECRET) + + with pytest.raises(HTTPException) as exc: + mgr.validate_state("not-base64-and-not-json") + + assert exc.value.status_code == 400 + assert "invalid state format" in exc.value.detail.lower()