""" Tests for gateway/auth.py — IamAuth, JWT verification, API key resolution cache. JWTs are signed with real Ed25519 keypairs generated per-test, so the crypto path is exercised end-to-end without mocks. API-key resolution is tested against a stubbed IamClient since the real one requires pub/sub. """ import base64 import json import time from unittest.mock import AsyncMock, Mock, patch import pytest from aiohttp import web from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 from trustgraph.gateway.auth import ( IamAuth, Identity, _b64url_decode, _verify_jwt_eddsa, API_KEY_CACHE_TTL, ) # -- helpers --------------------------------------------------------------- def _b64url(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") def make_keypair(): priv = ed25519.Ed25519PrivateKey.generate() public_pem = priv.public_key().public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, ).decode("ascii") return priv, public_pem def sign_jwt(priv, claims, alg="EdDSA"): header = {"alg": alg, "typ": "JWT", "kid": "kid-test"} h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode()) p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode()) signing_input = f"{h}.{p}".encode("ascii") if alg == "EdDSA": sig = priv.sign(signing_input) else: raise ValueError(f"test helper doesn't sign {alg}") return f"{h}.{p}.{_b64url(sig)}" def make_request(auth_header): """Minimal stand-in for an aiohttp request — IamAuth only reads ``request.headers["Authorization"]``.""" req = Mock() req.headers = {} if auth_header is not None: req.headers["Authorization"] = auth_header return req # -- pure helpers ---------------------------------------------------------- class TestB64UrlDecode: def test_round_trip_without_padding(self): data = b"hello" encoded = _b64url(data) assert _b64url_decode(encoded) == data def test_handles_various_lengths(self): for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"): assert _b64url_decode(_b64url(s)) == s # -- JWT verification ----------------------------------------------------- class TestVerifyJwtEddsa: def test_valid_jwt_passes(self): priv, pub = make_keypair() claims = { "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv, claims) got = _verify_jwt_eddsa(token, pub) assert got["sub"] == "user-1" assert got["workspace"] == "default" def test_expired_jwt_rejected(self): priv, pub = make_keypair() claims = { "sub": "user-1", "workspace": "default", "iat": int(time.time()) - 3600, "exp": int(time.time()) - 1, } token = sign_jwt(priv, claims) with pytest.raises(ValueError, match="expired"): _verify_jwt_eddsa(token, pub) def test_bad_signature_rejected(self): priv_a, _ = make_keypair() _, pub_b = make_keypair() claims = { "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv_a, claims) # pub_b never signed this token. with pytest.raises(Exception): _verify_jwt_eddsa(token, pub_b) def test_malformed_jwt_rejected(self): _, pub = make_keypair() with pytest.raises(ValueError, match="malformed"): _verify_jwt_eddsa("not-a-jwt", pub) def test_unsupported_algorithm_rejected(self): priv, pub = make_keypair() # Manually build an "alg":"HS256" header — no signer needed # since we expect it to bail before verifying. header = {"alg": "HS256", "typ": "JWT", "kid": "x"} payload = { "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } h = _b64url(json.dumps(header, separators=(",", ":")).encode()) p = _b64url(json.dumps(payload, separators=(",", ":")).encode()) sig = _b64url(b"not-a-real-sig") token = f"{h}.{p}.{sig}" with pytest.raises(ValueError, match="unsupported alg"): _verify_jwt_eddsa(token, pub) # -- Identity -------------------------------------------------------------- class TestIdentity: def test_fields(self): i = Identity( handle="u", workspace="w", principal_id="u", source="api-key", ) assert i.handle == "u" assert i.workspace == "w" assert i.principal_id == "u" assert i.source == "api-key" # -- IamAuth.authenticate -------------------------------------------------- class TestIamAuthDispatch: """``authenticate()`` chooses between the JWT and API-key paths by shape of the bearer.""" @pytest.mark.asyncio async def test_no_authorization_header_raises_401(self): auth = IamAuth(backend=Mock()) with pytest.raises(web.HTTPUnauthorized): await auth.authenticate(make_request(None)) @pytest.mark.asyncio async def test_non_bearer_header_raises_401(self): auth = IamAuth(backend=Mock()) with pytest.raises(web.HTTPUnauthorized): await auth.authenticate(make_request("Basic whatever")) @pytest.mark.asyncio async def test_empty_bearer_raises_401(self): auth = IamAuth(backend=Mock()) with pytest.raises(web.HTTPUnauthorized): await auth.authenticate(make_request("Bearer ")) @pytest.mark.asyncio async def test_unknown_format_raises_401(self): # Not tg_... and not dotted-JWT shape. auth = IamAuth(backend=Mock()) with pytest.raises(web.HTTPUnauthorized): await auth.authenticate(make_request("Bearer garbage")) @pytest.mark.asyncio async def test_valid_jwt_resolves_to_identity(self): priv, pub = make_keypair() claims = { "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv, claims) auth = IamAuth(backend=Mock()) auth._signing_public_pem = pub ident = await auth.authenticate( make_request(f"Bearer {token}") ) assert ident.handle == "user-1" assert ident.workspace == "default" assert ident.principal_id == "user-1" assert ident.source == "jwt" @pytest.mark.asyncio async def test_jwt_without_public_key_fails(self): # If the gateway hasn't fetched IAM's public key yet, JWTs # must not validate — even ones that would otherwise pass. priv, _ = make_keypair() claims = { "sub": "user-1", "workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv, claims) auth = IamAuth(backend=Mock()) # _signing_public_pem defaults to None with pytest.raises(web.HTTPUnauthorized): await auth.authenticate(make_request(f"Bearer {token}")) @pytest.mark.asyncio async def test_api_key_path(self): auth = IamAuth(backend=Mock()) async def fake_resolve(api_key): assert api_key == "tg_testkey" # Roles are returned by the regime as a hint but the # gateway ignores them — kept here so the resolve # protocol shape is exercised. return ("user-xyz", "default", ["admin"]) async def fake_with_client(op): return await op(Mock(resolve_api_key=fake_resolve)) with patch.object(auth, "_with_client", side_effect=fake_with_client): ident = await auth.authenticate( make_request("Bearer tg_testkey") ) assert ident.handle == "user-xyz" assert ident.workspace == "default" assert ident.principal_id == "user-xyz" assert ident.source == "api-key" @pytest.mark.asyncio async def test_api_key_rejection_masked_as_401(self): auth = IamAuth(backend=Mock()) async def fake_with_client(op): raise RuntimeError("auth-failed: unknown api key") with patch.object(auth, "_with_client", side_effect=fake_with_client): with pytest.raises(web.HTTPUnauthorized): await auth.authenticate( make_request("Bearer tg_bogus") ) # -- API key cache --------------------------------------------------------- class TestApiKeyCache: @pytest.mark.asyncio async def test_cache_hit_skips_iam(self): auth = IamAuth(backend=Mock()) calls = {"n": 0} async def fake_with_client(op): calls["n"] += 1 return await op(Mock( resolve_api_key=AsyncMock( return_value=("u", "default", ["reader"]), ) )) with patch.object(auth, "_with_client", side_effect=fake_with_client): await auth.authenticate(make_request("Bearer tg_k1")) await auth.authenticate(make_request("Bearer tg_k1")) await auth.authenticate(make_request("Bearer tg_k1")) # Only the first lookup reaches IAM; the rest are cache hits. assert calls["n"] == 1 @pytest.mark.asyncio async def test_different_keys_are_separately_cached(self): auth = IamAuth(backend=Mock()) seen = [] async def fake_with_client(op): async def resolve(plaintext): seen.append(plaintext) return ("u-" + plaintext, "default", ["reader"]) return await op(Mock(resolve_api_key=resolve)) with patch.object(auth, "_with_client", side_effect=fake_with_client): a = await auth.authenticate(make_request("Bearer tg_a")) b = await auth.authenticate(make_request("Bearer tg_b")) assert a.handle == "u-tg_a" assert b.handle == "u-tg_b" assert seen == ["tg_a", "tg_b"] @pytest.mark.asyncio async def test_cache_has_ttl_constant_set(self): # Not a behaviour test — just ensures we don't accidentally # set TTL to 0 (which would defeat the cache) or to a week. assert 10 <= API_KEY_CACHE_TTL <= 3600 # -- IamAuth.authorise ----------------------------------------------------- class TestAuthorise: """``authorise()`` is the gateway's only authorisation entry point under the IAM contract. It calls iam-svc, caches the decision for the regime's TTL (clamped above), and raises 403 on deny / 401 on regime error (fail closed).""" def _make_identity(self, handle="u-1", workspace="default"): return Identity( handle=handle, workspace=workspace, principal_id=handle, source="api-key", ) @pytest.mark.asyncio async def test_allow_returns_no_exception(self): auth = IamAuth(backend=Mock()) async def fake_with_client(op): return await op(Mock( authorise=AsyncMock(return_value=(True, 30)), )) with patch.object(auth, "_with_client", side_effect=fake_with_client): await auth.authorise( self._make_identity(), "graph:read", {"workspace": "default"}, {}, ) @pytest.mark.asyncio async def test_deny_raises_403(self): auth = IamAuth(backend=Mock()) async def fake_with_client(op): return await op(Mock( authorise=AsyncMock(return_value=(False, 30)), )) with patch.object(auth, "_with_client", side_effect=fake_with_client): with pytest.raises(web.HTTPForbidden): await auth.authorise( self._make_identity(), "users:admin", {}, {"workspace": "acme"}, ) @pytest.mark.asyncio async def test_regime_error_fails_closed_as_401(self): # If iam-svc errors, the gateway must NOT silently allow. auth = IamAuth(backend=Mock()) async def fake_with_client(op): raise RuntimeError("iam-svc down") with patch.object(auth, "_with_client", side_effect=fake_with_client): with pytest.raises(web.HTTPUnauthorized): await auth.authorise( self._make_identity(), "graph:read", {"workspace": "default"}, {}, ) @pytest.mark.asyncio async def test_allow_decision_is_cached(self): auth = IamAuth(backend=Mock()) calls = {"n": 0} async def fake_with_client(op): calls["n"] += 1 return await op(Mock( authorise=AsyncMock(return_value=(True, 30)), )) with patch.object(auth, "_with_client", side_effect=fake_with_client): ident = self._make_identity() for _ in range(5): await auth.authorise( ident, "graph:read", {"workspace": "default"}, {}, ) assert calls["n"] == 1 @pytest.mark.asyncio async def test_deny_decision_is_cached(self): auth = IamAuth(backend=Mock()) calls = {"n": 0} async def fake_with_client(op): calls["n"] += 1 return await op(Mock( authorise=AsyncMock(return_value=(False, 30)), )) with patch.object(auth, "_with_client", side_effect=fake_with_client): ident = self._make_identity() for _ in range(5): with pytest.raises(web.HTTPForbidden): await auth.authorise( ident, "users:admin", {}, {"workspace": "acme"}, ) # Denies are cached too — repeated attempts don't re-hit IAM. assert calls["n"] == 1 @pytest.mark.asyncio async def test_different_resources_cached_separately(self): auth = IamAuth(backend=Mock()) calls = {"n": 0} async def fake_with_client(op): calls["n"] += 1 return await op(Mock( authorise=AsyncMock(return_value=(True, 30)), )) with patch.object(auth, "_with_client", side_effect=fake_with_client): ident = self._make_identity() await auth.authorise( ident, "graph:read", {"workspace": "a"}, {}, ) await auth.authorise( ident, "graph:read", {"workspace": "b"}, {}, ) # Different resource → different cache key → two IAM calls. assert calls["n"] == 2