mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 17:22:38 +02:00
94 lines
2.6 KiB
Python
94 lines
2.6 KiB
Python
|
|
import base64
|
||
|
|
import hashlib
|
||
|
|
import hmac
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from fastapi import HTTPException
|
||
|
|
|
||
|
|
from app.utils.oauth_security import OAuthStateManager
|
||
|
|
|
||
|
|
SECRET = "unit-test-secret"
|
||
|
|
|
||
|
|
|
||
|
|
def _encode_state(payload: dict, *, signature: str | None = None) -> str:
|
||
|
|
"""Build an OAuth state payload compatible with OAuthStateManager."""
|
||
|
|
signature_payload = payload.copy()
|
||
|
|
payload_str = json.dumps(signature_payload, sort_keys=True)
|
||
|
|
computed_signature = hmac.new(
|
||
|
|
SECRET.encode(),
|
||
|
|
payload_str.encode(),
|
||
|
|
hashlib.sha256,
|
||
|
|
).hexdigest()
|
||
|
|
encoded_payload = {
|
||
|
|
**signature_payload,
|
||
|
|
"signature": signature if signature is not None else computed_signature,
|
||
|
|
}
|
||
|
|
return base64.urlsafe_b64encode(json.dumps(encoded_payload).encode()).decode()
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_state_accepts_fresh_signed_state():
    """A state freshly issued by the manager validates and round-trips its fields."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    owner = uuid4()

    token = manager.generate_secure_state(
        space_id=1, user_id=owner, toolkit_id="googledrive"
    )
    result = manager.validate_state(token)

    # The user id comes back in string form.
    expected = {"space_id": 1, "user_id": str(owner), "toolkit_id": "googledrive"}
    assert {key: result[key] for key in expected} == expected
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_state_rejects_expired_state():
    """A correctly signed state older than max_age_seconds is rejected as expired."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    # One hour old: well past the 600-second window configured above.
    one_hour_ago = int(time.time()) - 3600
    stale = _encode_state(
        {
            "space_id": 1,
            "user_id": str(uuid4()),
            "timestamp": one_hour_ago,
            "toolkit_id": "googledrive",
        }
    )

    with pytest.raises(HTTPException) as excinfo:
        manager.validate_state(stale)

    error = excinfo.value
    assert error.status_code == 400
    assert "expired" in error.detail.lower()
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_state_rejects_tampered_signature():
    """A fresh state whose signature does not match its payload is rejected."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    # 64 hex characters: the shape of a SHA-256 hexdigest, but not the right value.
    forged_signature = "deadbeef" * 8
    forged = _encode_state(
        {
            "space_id": 1,
            "user_id": str(uuid4()),
            "timestamp": int(time.time()),
            "toolkit_id": "googledrive",
        },
        signature=forged_signature,
    )

    with pytest.raises(HTTPException) as excinfo:
        manager.validate_state(forged)

    error = excinfo.value
    assert error.status_code == 400
    assert "tampering" in error.detail.lower()
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_state_rejects_malformed_state():
    """An arbitrary non-state string yields a 400 'invalid state format' error."""
    manager = OAuthStateManager(secret_key=SECRET)
    garbage = "not-base64-and-not-json"

    with pytest.raises(HTTPException) as excinfo:
        manager.validate_state(garbage)

    error = excinfo.value
    assert error.status_code == 400
    assert "invalid state format" in error.detail.lower()
|