mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 18:36:22 +02:00
feat: IAM service, gateway auth middleware, capability model, and CLIs (#849)
Replaces the legacy GATEWAY_SECRET shared-token gate with an IAM-backed
identity and authorisation model. The gateway no longer has an
"allow-all" or "no auth" mode; every request is authenticated via the
IAM service, authorised against a capability model that encodes both
the operation and the workspace it targets, and rejected with a
deliberately-uninformative 401 / 403 on any failure.
IAM service (trustgraph-flow/trustgraph/iam, trustgraph-base/schema/iam)
-----------------------------------------------------------------------
* New backend service (iam-svc) owning users, workspaces, API keys,
passwords and JWT signing keys in Cassandra. Reached over the
standard pub/sub request/response pattern; gateway is the only
caller.
* Operations: bootstrap, resolve-api-key, login, get-signing-key-public,
rotate-signing-key, create/list/get/update/disable/delete/enable-user,
change-password, reset-password, create/list/get/update/disable-
workspace, create/list/revoke-api-key.
* Ed25519 JWT signing (alg=EdDSA). Key rotation writes a new kid and
retires the previous one; validation is grace-period friendly.
* Passwords: PBKDF2-HMAC-SHA-256, 600k iterations, per-user salt.
* API keys: 128-bit random, SHA-256 hashed. Plaintext returned once.
* Bootstrap is explicit: --bootstrap-mode {token,bootstrap} is a
required startup argument with no permissive default. Masked
"auth failure" errors hide whether a refused bootstrap request was
due to mode, state, or authorisation.
Gateway authentication (trustgraph-flow/trustgraph/gateway/auth.py)
-------------------------------------------------------------------
* IamAuth replaces the legacy Authenticator. Distinguishes JWTs
(three-segment dotted) from API keys by shape; verifies JWTs
locally using the cached IAM public key; resolves API keys via
IAM with a short-TTL hash-keyed cache. Every failure path
surfaces the same 401 body ("auth failure") so callers cannot
enumerate credential state.
* Public key is fetched at gateway startup with a bounded retry loop;
traffic does not begin flowing until auth has started.
Capability model (trustgraph-flow/trustgraph/gateway/capabilities.py)
---------------------------------------------------------------------
* Roles have two dimensions: a capability set and a workspace scope.
OSS ships reader / writer / admin; the first two are workspace-
assigned, admin is cross-workspace ("*"). No "cross-workspace"
pseudo-capability — workspace permission is a property of the role.
* check(identity, capability, target_workspace=None) is the single
authorisation test: some role must grant the capability *and* be
active in the target workspace.
* enforce_workspace validates a request-body workspace against the
caller's role scopes and injects the resolved value. Cross-
workspace admin is permitted by role scope, not by a bypass.
* Gateway endpoints declare a required capability explicitly — no
permissive default. Construction fails fast if omitted. Enterprise
editions can replace the role table without changing the wire
protocol.
WebSocket first-frame auth (dispatch/mux.py, endpoint/socket.py)
----------------------------------------------------------------
* /api/v1/socket handshake unconditionally accepts; authentication
runs on the first WebSocket frame ({"type":"auth","token":"..."})
with {"type":"auth-ok","workspace":"..."} / {"type":"auth-failed"}.
The socket stays open on failure so the client can re-authenticate
— browsers treat a handshake-time 401 as terminal, breaking
reconnection.
* Mux.receive rejects every non-auth frame before auth succeeds,
enforces the caller's workspace (envelope + inner payload) using
the role-scope resolver, and supports mid-session re-auth.
* Flow import/export streaming endpoints keep the legacy ?token=
handshake (URL-scoped short-lived transfers; no re-auth need).
Auth surface
------------
* POST /api/v1/auth/login — public, returns a JWT.
* POST /api/v1/auth/bootstrap — public; forwards to IAM's bootstrap
op which itself enforces mode + tables-empty.
* POST /api/v1/auth/change-password — any authenticated user.
* POST /api/v1/iam — admin-only generic forwarder for the rest of
the IAM API (per-op REST endpoints to follow in a later change).
Removed / breaking
------------------
* GATEWAY_SECRET / --api-token / default_api_token and the legacy
Authenticator.permitted contract. The gateway cannot run without
IAM.
* ?token= on /api/v1/socket.
* DispatcherManager and Mux both raise on auth=None — no silent
downgrade path.
CLI tools (trustgraph-cli)
--------------------------
tg-bootstrap-iam, tg-login, tg-create-user, tg-list-users,
tg-disable-user, tg-enable-user, tg-delete-user, tg-change-password,
tg-reset-password, tg-create-api-key, tg-list-api-keys,
tg-revoke-api-key, tg-create-workspace, tg-list-workspaces. Passwords
read via getpass; tokens / one-time secrets written to stdout with
operator context on stderr so shell composition works cleanly.
AsyncSocketClient / SocketClient updated to the first-frame auth
protocol.
Specifications
--------------
* docs/tech-specs/iam.md updated with the error policy, workspace
resolver extension point, and OSS role-scope model.
* docs/tech-specs/iam-protocol.md (new) — transport, dataclasses,
operation table, error taxonomy, bootstrap modes.
* docs/tech-specs/capabilities.md (new) — capability vocabulary, OSS
role bundles, agent-as-composition note, enforcement-boundary
policy, enterprise extensibility.
Tests
-----
* test_auth.py (rewritten) — IamAuth + JWT round-trip with real
Ed25519 keypairs + API-key cache behaviour.
* test_capabilities.py (new) — role table sanity, check across
role x workspace combinations, enforce_workspace paths,
unknown-cap / unknown-role fail-closed.
* Every endpoint test construction now names its capability
explicitly (no permissive defaults relied upon). New tests pin
the fail-closed invariants: DispatcherManager / Mux refuse
auth=None; i18n path-traversal defense is exercised.
* test_socket_graceful_shutdown rewritten against IamAuth.
This commit is contained in:
parent
ae9936c9cc
commit
67b2fc448f
61 changed files with 6474 additions and 792 deletions
|
|
@ -1,69 +1,312 @@
|
|||
"""
|
||||
Tests for Gateway Authentication
|
||||
Tests for gateway/auth.py — IamAuth, JWT verification, API key
|
||||
resolution cache.
|
||||
|
||||
JWTs are signed with real Ed25519 keypairs generated per-test, so
|
||||
the crypto path is exercised end-to-end without mocks. API-key
|
||||
resolution is tested against a stubbed IamClient since the real
|
||||
one requires pub/sub.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from trustgraph.gateway.auth import Authenticator
|
||||
from trustgraph.gateway.auth import (
|
||||
IamAuth, Identity,
|
||||
_b64url_decode, _verify_jwt_eddsa,
|
||||
API_KEY_CACHE_TTL,
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticator:
|
||||
"""Test cases for Authenticator class"""
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
def test_authenticator_initialization_with_token(self):
|
||||
"""Test Authenticator initialization with valid token"""
|
||||
auth = Authenticator(token="test-token-123")
|
||||
|
||||
assert auth.token == "test-token-123"
|
||||
assert auth.allow_all is False
|
||||
|
||||
def test_authenticator_initialization_with_allow_all(self):
|
||||
"""Test Authenticator initialization with allow_all=True"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
assert auth.token is None
|
||||
assert auth.allow_all is True
|
||||
def _b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
def test_authenticator_initialization_without_token_raises_error(self):
|
||||
"""Test Authenticator initialization without token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator()
|
||||
|
||||
def test_authenticator_initialization_with_empty_token_raises_error(self):
|
||||
"""Test Authenticator initialization with empty token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator(token="")
|
||||
def make_keypair():
|
||||
priv = ed25519.Ed25519PrivateKey.generate()
|
||||
public_pem = priv.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode("ascii")
|
||||
return priv, public_pem
|
||||
|
||||
def test_permitted_with_allow_all_returns_true(self):
|
||||
"""Test permitted method returns True when allow_all is enabled"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
# Should return True regardless of token or roles
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("different-token", ["admin"]) is True
|
||||
assert auth.permitted(None, ["user"]) is True
|
||||
|
||||
def test_permitted_with_matching_token_returns_true(self):
|
||||
"""Test permitted method returns True with matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return True when tokens match
|
||||
assert auth.permitted("secret-token", []) is True
|
||||
assert auth.permitted("secret-token", ["admin", "user"]) is True
|
||||
def sign_jwt(priv, claims, alg="EdDSA"):
|
||||
header = {"alg": alg, "typ": "JWT", "kid": "kid-test"}
|
||||
h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode())
|
||||
p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode())
|
||||
signing_input = f"{h}.{p}".encode("ascii")
|
||||
if alg == "EdDSA":
|
||||
sig = priv.sign(signing_input)
|
||||
else:
|
||||
raise ValueError(f"test helper doesn't sign {alg}")
|
||||
return f"{h}.{p}.{_b64url(sig)}"
|
||||
|
||||
def test_permitted_with_non_matching_token_returns_false(self):
|
||||
"""Test permitted method returns False with non-matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return False when tokens don't match
|
||||
assert auth.permitted("wrong-token", []) is False
|
||||
assert auth.permitted("different-token", ["admin"]) is False
|
||||
assert auth.permitted(None, ["user"]) is False
|
||||
|
||||
def test_permitted_with_token_and_allow_all_returns_true(self):
|
||||
"""Test permitted method with both token and allow_all set"""
|
||||
auth = Authenticator(token="test-token", allow_all=True)
|
||||
|
||||
# allow_all should take precedence
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("wrong-token", ["admin"]) is True
|
||||
def make_request(auth_header):
|
||||
"""Minimal stand-in for an aiohttp request — IamAuth only reads
|
||||
``request.headers["Authorization"]``."""
|
||||
req = Mock()
|
||||
req.headers = {}
|
||||
if auth_header is not None:
|
||||
req.headers["Authorization"] = auth_header
|
||||
return req
|
||||
|
||||
|
||||
# -- pure helpers ----------------------------------------------------------
|
||||
|
||||
|
||||
class TestB64UrlDecode:
|
||||
|
||||
def test_round_trip_without_padding(self):
|
||||
data = b"hello"
|
||||
encoded = _b64url(data)
|
||||
assert _b64url_decode(encoded) == data
|
||||
|
||||
def test_handles_various_lengths(self):
|
||||
for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"):
|
||||
assert _b64url_decode(_b64url(s)) == s
|
||||
|
||||
|
||||
# -- JWT verification -----------------------------------------------------
|
||||
|
||||
|
||||
class TestVerifyJwtEddsa:
|
||||
|
||||
def test_valid_jwt_passes(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"roles": ["reader"],
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
got = _verify_jwt_eddsa(token, pub)
|
||||
assert got["sub"] == "user-1"
|
||||
assert got["workspace"] == "default"
|
||||
|
||||
def test_expired_jwt_rejected(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default", "roles": [],
|
||||
"iat": int(time.time()) - 3600,
|
||||
"exp": int(time.time()) - 1,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
with pytest.raises(ValueError, match="expired"):
|
||||
_verify_jwt_eddsa(token, pub)
|
||||
|
||||
def test_bad_signature_rejected(self):
|
||||
priv_a, _ = make_keypair()
|
||||
_, pub_b = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default", "roles": [],
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv_a, claims)
|
||||
# pub_b never signed this token.
|
||||
with pytest.raises(Exception):
|
||||
_verify_jwt_eddsa(token, pub_b)
|
||||
|
||||
def test_malformed_jwt_rejected(self):
|
||||
_, pub = make_keypair()
|
||||
with pytest.raises(ValueError, match="malformed"):
|
||||
_verify_jwt_eddsa("not-a-jwt", pub)
|
||||
|
||||
def test_unsupported_algorithm_rejected(self):
|
||||
priv, pub = make_keypair()
|
||||
# Manually build an "alg":"HS256" header — no signer needed
|
||||
# since we expect it to bail before verifying.
|
||||
header = {"alg": "HS256", "typ": "JWT", "kid": "x"}
|
||||
payload = {
|
||||
"sub": "user-1", "workspace": "default", "roles": [],
|
||||
"iat": int(time.time()), "exp": int(time.time()) + 60,
|
||||
}
|
||||
h = _b64url(json.dumps(header, separators=(",", ":")).encode())
|
||||
p = _b64url(json.dumps(payload, separators=(",", ":")).encode())
|
||||
sig = _b64url(b"not-a-real-sig")
|
||||
token = f"{h}.{p}.{sig}"
|
||||
with pytest.raises(ValueError, match="unsupported alg"):
|
||||
_verify_jwt_eddsa(token, pub)
|
||||
|
||||
|
||||
# -- Identity --------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIdentity:
|
||||
|
||||
def test_fields(self):
|
||||
i = Identity(
|
||||
user_id="u", workspace="w", roles=["reader"], source="api-key",
|
||||
)
|
||||
assert i.user_id == "u"
|
||||
assert i.workspace == "w"
|
||||
assert i.roles == ["reader"]
|
||||
assert i.source == "api-key"
|
||||
|
||||
|
||||
# -- IamAuth.authenticate --------------------------------------------------
|
||||
|
||||
|
||||
class TestIamAuthDispatch:
|
||||
"""``authenticate()`` chooses between the JWT and API-key paths
|
||||
by shape of the bearer."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_authorization_header_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(None))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_bearer_header_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Basic whatever"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_bearer_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer "))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_format_raises_401(self):
|
||||
# Not tg_... and not dotted-JWT shape.
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer garbage"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_jwt_resolves_to_identity(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"roles": ["writer"],
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
|
||||
auth = IamAuth(backend=Mock())
|
||||
auth._signing_public_pem = pub
|
||||
|
||||
ident = await auth.authenticate(
|
||||
make_request(f"Bearer {token}")
|
||||
)
|
||||
assert ident.user_id == "user-1"
|
||||
assert ident.workspace == "default"
|
||||
assert ident.roles == ["writer"]
|
||||
assert ident.source == "jwt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_without_public_key_fails(self):
|
||||
# If the gateway hasn't fetched IAM's public key yet, JWTs
|
||||
# must not validate — even ones that would otherwise pass.
|
||||
priv, _ = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default", "roles": [],
|
||||
"iat": int(time.time()), "exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
auth = IamAuth(backend=Mock())
|
||||
# _signing_public_pem defaults to None
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(f"Bearer {token}"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_path(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_resolve(api_key):
|
||||
assert api_key == "tg_testkey"
|
||||
return ("user-xyz", "default", ["admin"])
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(resolve_api_key=fake_resolve))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = await auth.authenticate(
|
||||
make_request("Bearer tg_testkey")
|
||||
)
|
||||
assert ident.user_id == "user-xyz"
|
||||
assert ident.workspace == "default"
|
||||
assert ident.roles == ["admin"]
|
||||
assert ident.source == "api-key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_rejection_masked_as_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: unknown api key")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(
|
||||
make_request("Bearer tg_bogus")
|
||||
)
|
||||
|
||||
|
||||
# -- API key cache ---------------------------------------------------------
|
||||
|
||||
|
||||
class TestApiKeyCache:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_skips_iam(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_with_client(op):
|
||||
calls["n"] += 1
|
||||
return await op(Mock(
|
||||
resolve_api_key=AsyncMock(
|
||||
return_value=("u", "default", ["reader"]),
|
||||
)
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
await auth.authenticate(make_request("Bearer tg_k1"))
|
||||
await auth.authenticate(make_request("Bearer tg_k1"))
|
||||
await auth.authenticate(make_request("Bearer tg_k1"))
|
||||
|
||||
# Only the first lookup reaches IAM; the rest are cache hits.
|
||||
assert calls["n"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_keys_are_separately_cached(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
seen = []
|
||||
|
||||
async def fake_with_client(op):
|
||||
async def resolve(plaintext):
|
||||
seen.append(plaintext)
|
||||
return ("u-" + plaintext, "default", ["reader"])
|
||||
return await op(Mock(resolve_api_key=resolve))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
a = await auth.authenticate(make_request("Bearer tg_a"))
|
||||
b = await auth.authenticate(make_request("Bearer tg_b"))
|
||||
|
||||
assert a.user_id == "u-tg_a"
|
||||
assert b.user_id == "u-tg_b"
|
||||
assert seen == ["tg_a", "tg_b"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_has_ttl_constant_set(self):
|
||||
# Not a behaviour test — just ensures we don't accidentally
|
||||
# set TTL to 0 (which would defeat the cache) or to a week.
|
||||
assert 10 <= API_KEY_CACHE_TTL <= 3600
|
||||
|
|
|
|||
203
tests/unit/test_gateway/test_capabilities.py
Normal file
203
tests/unit/test_gateway/test_capabilities.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""
|
||||
Tests for gateway/capabilities.py — the capability + role + workspace
|
||||
model that underpins all gateway authorisation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.capabilities import (
|
||||
PUBLIC, AUTHENTICATED,
|
||||
KNOWN_CAPABILITIES, ROLE_DEFINITIONS,
|
||||
check, enforce_workspace, access_denied, auth_failure,
|
||||
)
|
||||
|
||||
|
||||
# -- test fixtures ---------------------------------------------------------
|
||||
|
||||
|
||||
class _Identity:
|
||||
"""Minimal stand-in for auth.Identity — the capability module
|
||||
accesses ``.workspace`` and ``.roles``."""
|
||||
def __init__(self, workspace, roles):
|
||||
self.user_id = "user-1"
|
||||
self.workspace = workspace
|
||||
self.roles = list(roles)
|
||||
|
||||
|
||||
def reader_in(ws):
|
||||
return _Identity(ws, ["reader"])
|
||||
|
||||
|
||||
def writer_in(ws):
|
||||
return _Identity(ws, ["writer"])
|
||||
|
||||
|
||||
def admin_in(ws):
|
||||
return _Identity(ws, ["admin"])
|
||||
|
||||
|
||||
# -- role table sanity -----------------------------------------------------
|
||||
|
||||
|
||||
class TestRoleTable:
|
||||
|
||||
def test_oss_roles_present(self):
|
||||
assert set(ROLE_DEFINITIONS.keys()) == {"reader", "writer", "admin"}
|
||||
|
||||
def test_admin_is_cross_workspace(self):
|
||||
assert ROLE_DEFINITIONS["admin"]["workspace_scope"] == "*"
|
||||
|
||||
def test_reader_writer_are_assigned_scope(self):
|
||||
assert ROLE_DEFINITIONS["reader"]["workspace_scope"] == "assigned"
|
||||
assert ROLE_DEFINITIONS["writer"]["workspace_scope"] == "assigned"
|
||||
|
||||
def test_admin_superset_of_writer(self):
|
||||
admin = ROLE_DEFINITIONS["admin"]["capabilities"]
|
||||
writer = ROLE_DEFINITIONS["writer"]["capabilities"]
|
||||
assert writer.issubset(admin)
|
||||
|
||||
def test_writer_superset_of_reader(self):
|
||||
writer = ROLE_DEFINITIONS["writer"]["capabilities"]
|
||||
reader = ROLE_DEFINITIONS["reader"]["capabilities"]
|
||||
assert reader.issubset(writer)
|
||||
|
||||
def test_admin_has_users_admin(self):
|
||||
assert "users:admin" in ROLE_DEFINITIONS["admin"]["capabilities"]
|
||||
|
||||
def test_writer_does_not_have_users_admin(self):
|
||||
assert "users:admin" not in ROLE_DEFINITIONS["writer"]["capabilities"]
|
||||
|
||||
def test_every_bundled_capability_is_known(self):
|
||||
for role in ROLE_DEFINITIONS.values():
|
||||
for cap in role["capabilities"]:
|
||||
assert cap in KNOWN_CAPABILITIES
|
||||
|
||||
|
||||
# -- check() ---------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheck:
|
||||
|
||||
def test_reader_has_reader_cap_in_own_workspace(self):
|
||||
assert check(reader_in("default"), "graph:read", "default")
|
||||
|
||||
def test_reader_does_not_have_writer_cap(self):
|
||||
assert not check(reader_in("default"), "graph:write", "default")
|
||||
|
||||
def test_reader_cannot_act_in_other_workspace(self):
|
||||
assert not check(reader_in("default"), "graph:read", "acme")
|
||||
|
||||
def test_writer_has_write_in_own_workspace(self):
|
||||
assert check(writer_in("default"), "graph:write", "default")
|
||||
|
||||
def test_writer_cannot_act_in_other_workspace(self):
|
||||
assert not check(writer_in("default"), "graph:write", "acme")
|
||||
|
||||
def test_admin_has_everything_everywhere(self):
|
||||
for cap in ("graph:read", "graph:write", "config:write",
|
||||
"users:admin", "metrics:read"):
|
||||
assert check(admin_in("default"), cap, "acme"), (
|
||||
f"admin should have {cap} in acme"
|
||||
)
|
||||
|
||||
def test_admin_has_caps_without_explicit_workspace(self):
|
||||
assert check(admin_in("default"), "users:admin")
|
||||
|
||||
def test_default_target_is_identity_workspace(self):
|
||||
# Reader with no target workspace → should check against own
|
||||
assert check(reader_in("default"), "graph:read")
|
||||
|
||||
def test_unknown_capability_returns_false(self):
|
||||
assert not check(admin_in("default"), "nonsense:cap", "default")
|
||||
|
||||
def test_unknown_role_contributes_nothing(self):
|
||||
ident = _Identity("default", ["made-up-role"])
|
||||
assert not check(ident, "graph:read", "default")
|
||||
|
||||
def test_multi_role_union(self):
|
||||
# If a user is both reader and admin, they inherit admin's
|
||||
# cross-workspace powers.
|
||||
ident = _Identity("default", ["reader", "admin"])
|
||||
assert check(ident, "users:admin", "acme")
|
||||
|
||||
|
||||
# -- enforce_workspace() ---------------------------------------------------
|
||||
|
||||
|
||||
class TestEnforceWorkspace:
|
||||
|
||||
def test_reader_in_own_workspace_allowed(self):
|
||||
data = {"workspace": "default", "operation": "x"}
|
||||
enforce_workspace(data, reader_in("default"))
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
def test_reader_no_workspace_injects_assigned(self):
|
||||
data = {"operation": "x"}
|
||||
enforce_workspace(data, reader_in("default"))
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
def test_reader_mismatched_workspace_denied(self):
|
||||
data = {"workspace": "acme", "operation": "x"}
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
enforce_workspace(data, reader_in("default"))
|
||||
|
||||
def test_admin_can_target_any_workspace(self):
|
||||
data = {"workspace": "acme", "operation": "x"}
|
||||
enforce_workspace(data, admin_in("default"))
|
||||
assert data["workspace"] == "acme"
|
||||
|
||||
def test_admin_no_workspace_defaults_to_assigned(self):
|
||||
data = {"operation": "x"}
|
||||
enforce_workspace(data, admin_in("default"))
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
def test_writer_same_workspace_specified_allowed(self):
|
||||
data = {"workspace": "default"}
|
||||
enforce_workspace(data, writer_in("default"))
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
def test_non_dict_passthrough(self):
|
||||
# Non-dict bodies are returned unchanged (e.g. streaming).
|
||||
result = enforce_workspace("not-a-dict", reader_in("default"))
|
||||
assert result == "not-a-dict"
|
||||
|
||||
def test_with_capability_tightens_check(self):
|
||||
# Reader lacks graph:write; workspace-only check would pass
|
||||
# (scope is fine), but combined check must reject.
|
||||
data = {"workspace": "default"}
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
enforce_workspace(
|
||||
data, reader_in("default"), capability="graph:write",
|
||||
)
|
||||
|
||||
def test_with_capability_passes_when_granted(self):
|
||||
data = {"workspace": "default"}
|
||||
enforce_workspace(
|
||||
data, reader_in("default"), capability="graph:read",
|
||||
)
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseHelpers:
|
||||
|
||||
def test_auth_failure_is_401(self):
|
||||
exc = auth_failure()
|
||||
assert exc.status == 401
|
||||
assert "auth failure" in exc.text
|
||||
|
||||
def test_access_denied_is_403(self):
|
||||
exc = access_denied()
|
||||
assert exc.status == 403
|
||||
assert "access denied" in exc.text
|
||||
|
||||
|
||||
class TestSentinels:
|
||||
|
||||
def test_public_and_authenticated_are_distinct(self):
|
||||
assert PUBLIC != AUTHENTICATED
|
||||
assert PUBLIC not in KNOWN_CAPABILITIES
|
||||
assert AUTHENTICATED not in KNOWN_CAPABILITIES
|
||||
|
|
@ -42,7 +42,7 @@ class TestDispatcherManager:
|
|||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
assert manager.backend == mock_backend
|
||||
assert manager.config_receiver == mock_config_receiver
|
||||
|
|
@ -59,7 +59,10 @@ class TestDispatcherManager:
|
|||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")
|
||||
manager = DispatcherManager(
|
||||
mock_backend, mock_config_receiver,
|
||||
auth=Mock(), prefix="custom-prefix",
|
||||
)
|
||||
|
||||
assert manager.prefix == "custom-prefix"
|
||||
|
||||
|
|
@ -68,7 +71,7 @@ class TestDispatcherManager:
|
|||
"""Test start_flow method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
|
@ -82,7 +85,7 @@ class TestDispatcherManager:
|
|||
"""Test stop_flow method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
|
@ -96,7 +99,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_global_service()
|
||||
|
||||
|
|
@ -107,7 +110,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_core_export returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_core_export()
|
||||
|
||||
|
|
@ -118,7 +121,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_core_import returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_core_import()
|
||||
|
||||
|
|
@ -130,7 +133,7 @@ class TestDispatcherManager:
|
|||
"""Test process_core_import method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
|
||||
mock_importer = Mock()
|
||||
|
|
@ -148,7 +151,7 @@ class TestDispatcherManager:
|
|||
"""Test process_core_export method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
|
||||
mock_exporter = Mock()
|
||||
|
|
@ -166,7 +169,7 @@ class TestDispatcherManager:
|
|||
"""Test process_global_service method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
manager.invoke_global_service = AsyncMock(return_value="global_result")
|
||||
|
||||
|
|
@ -181,7 +184,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_global_service with existing dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
|
|
@ -198,7 +201,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_global_service creates new dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
|
|
@ -230,7 +233,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_flow_import returns correct method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
result = manager.dispatch_flow_import()
|
||||
|
||||
|
|
@ -240,7 +243,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_flow_export returns correct method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
result = manager.dispatch_flow_export()
|
||||
|
||||
|
|
@ -250,7 +253,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_socket returns correct method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
result = manager.dispatch_socket()
|
||||
|
||||
|
|
@ -260,7 +263,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_flow_service returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_flow_service()
|
||||
|
||||
|
|
@ -272,7 +275,7 @@ class TestDispatcherManager:
|
|||
"""Test process_flow_import with valid flow and kind"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -308,7 +311,7 @@ class TestDispatcherManager:
|
|||
"""Test process_flow_import with invalid flow"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
params = {"flow": "invalid_flow", "kind": "triples"}
|
||||
|
||||
|
|
@ -323,7 +326,7 @@ class TestDispatcherManager:
|
|||
warnings.simplefilter("ignore", RuntimeWarning)
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -345,7 +348,7 @@ class TestDispatcherManager:
|
|||
"""Test process_flow_export with valid flow and kind"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -378,26 +381,47 @@ class TestDispatcherManager:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_socket(self):
|
||||
"""Test process_socket method"""
|
||||
"""process_socket constructs a Mux with the manager's auth
|
||||
instance passed through — this is the gateway's trust path
|
||||
for first-frame WebSocket authentication. A Mux cannot be
|
||||
built without auth (tested separately); this test pins that
|
||||
the dispatcher-manager threads the correct auth value into
|
||||
the Mux constructor call."""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
mock_auth = Mock()
|
||||
manager = DispatcherManager(
|
||||
mock_backend, mock_config_receiver, auth=mock_auth,
|
||||
)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
|
||||
mock_mux_instance = Mock()
|
||||
mock_mux.return_value = mock_mux_instance
|
||||
|
||||
|
||||
result = await manager.process_socket("ws", "running", {})
|
||||
|
||||
mock_mux.assert_called_once_with(manager, "ws", "running")
|
||||
|
||||
mock_mux.assert_called_once_with(
|
||||
manager, "ws", "running", auth=mock_auth,
|
||||
)
|
||||
assert result == mock_mux_instance
|
||||
|
||||
def test_dispatcher_manager_requires_auth(self):
|
||||
"""Constructing a DispatcherManager without an auth argument
|
||||
must fail — a no-auth DispatcherManager would produce a
|
||||
Mux without authentication, silently downgrading the socket
|
||||
auth path."""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="auth"):
|
||||
DispatcherManager(mock_backend, mock_config_receiver, auth=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_service(self):
|
||||
"""Test process_flow_service method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
manager.invoke_flow_service = AsyncMock(return_value="flow_result")
|
||||
|
||||
|
|
@ -412,7 +436,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with existing dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
|
||||
|
|
@ -432,7 +456,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service creates request-response dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -476,7 +500,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service creates sender dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -516,7 +540,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with invalid flow"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
|
||||
|
|
@ -526,7 +550,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with kind not supported by flow"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -543,7 +567,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with invalid kind"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -570,7 +594,7 @@ class TestDispatcherManager:
|
|||
"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
async def slow_start():
|
||||
# Yield to the event loop so other coroutines get a chance to run,
|
||||
|
|
@ -606,7 +630,7 @@ class TestDispatcherManager:
|
|||
"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,19 @@ from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE
|
|||
class TestMux:
|
||||
"""Test cases for Mux class"""
|
||||
|
||||
def test_mux_requires_auth(self):
|
||||
"""Constructing a Mux without an ``auth`` argument must
|
||||
fail. The Mux implements the first-frame auth protocol and
|
||||
there is no no-auth mode — a no-auth Mux would silently
|
||||
accept every frame without authenticating it."""
|
||||
with pytest.raises(ValueError, match="auth"):
|
||||
Mux(
|
||||
dispatcher_manager=MagicMock(),
|
||||
ws=MagicMock(),
|
||||
running=MagicMock(),
|
||||
auth=None,
|
||||
)
|
||||
|
||||
def test_mux_initialization(self):
|
||||
"""Test Mux initialization"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
|
|
@ -21,7 +34,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
assert mux.dispatcher_manager == mock_dispatcher_manager
|
||||
|
|
@ -40,7 +54,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
|
|
@ -61,7 +76,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=None,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
|
|
@ -81,7 +97,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message with valid JSON
|
||||
|
|
@ -108,7 +125,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message without request field
|
||||
|
|
@ -137,7 +155,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message without id field
|
||||
|
|
@ -164,7 +183,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message with invalid JSON
|
||||
|
|
|
|||
|
|
@ -13,29 +13,36 @@ class TestConstantEndpoint:
|
|||
"""Test cases for ConstantEndpoint class"""
|
||||
|
||||
def test_constant_endpoint_initialization(self):
|
||||
"""Test ConstantEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. The capability is a required argument — no
|
||||
permissive default — and the test passes an explicit
|
||||
value to demonstrate the contract."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
endpoint_path="/api/test",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability="config:read",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/test"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "config:read"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constant_endpoint_start_method(self):
|
||||
"""Test ConstantEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
"/api/test", mock_auth, mock_dispatcher,
|
||||
capability="config:read",
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -44,10 +51,13 @@ class TestConstantEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
"/api/test", mock_auth, mock_dispatcher,
|
||||
capability="config:read",
|
||||
)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.post with the path and handler
|
||||
|
|
|
|||
|
|
@ -1,4 +1,12 @@
|
|||
"""Tests for Gateway i18n pack endpoint."""
|
||||
"""Tests for Gateway i18n pack endpoint.
|
||||
|
||||
Production registers this endpoint with ``capability=PUBLIC``: the
|
||||
login UI needs to render its own i18n strings before any user has
|
||||
authenticated, so the endpoint is deliberately pre-auth. These
|
||||
tests exercise the PUBLIC configuration — that is the production
|
||||
contract. Behaviour of authenticated endpoints is covered by the
|
||||
IamAuth tests in ``test_auth.py``.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
|
@ -7,6 +15,7 @@ import pytest
|
|||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
|
||||
from trustgraph.gateway.capabilities import PUBLIC
|
||||
|
||||
|
||||
class TestI18nPackEndpoint:
|
||||
|
|
@ -17,23 +26,28 @@ class TestI18nPackEndpoint:
|
|||
endpoint = I18nPackEndpoint(
|
||||
endpoint_path="/api/v1/i18n/packs/{lang}",
|
||||
auth=mock_auth,
|
||||
capability=PUBLIC,
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == PUBLIC
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_i18n_endpoint_start_method(self):
|
||||
mock_auth = MagicMock()
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_get_handler(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
mock_app.add_routes.assert_called_once()
|
||||
|
|
@ -41,35 +55,55 @@ class TestI18nPackEndpoint:
|
|||
assert len(call_args) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
|
||||
async def test_handle_returns_pack_without_authenticating(self):
|
||||
"""The PUBLIC endpoint serves the language pack without
|
||||
invoking the auth handler at all — pre-login UI must be
|
||||
reachable. The test uses an auth mock that raises if
|
||||
touched, so any auth attempt by the endpoint is caught."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
def _should_not_be_called(*args, **kwargs):
|
||||
raise AssertionError(
|
||||
"PUBLIC endpoint must not invoke auth.authenticate"
|
||||
)
|
||||
mock_auth.authenticate = _should_not_be_called
|
||||
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.path = "/api/v1/i18n/packs/en"
|
||||
# A caller-supplied Authorization header of any form should
|
||||
# be ignored — PUBLIC means we don't look at it.
|
||||
request.headers = {"Authorization": "Token abc"}
|
||||
request.match_info = {"lang": "en"}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
assert isinstance(resp, web.HTTPUnauthorized)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_returns_pack_when_permitted(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
|
||||
request = MagicMock()
|
||||
request.path = "/api/v1/i18n/packs/en"
|
||||
request.headers = {}
|
||||
request.match_info = {"lang": "en"}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
|
||||
assert resp.status == 200
|
||||
payload = json.loads(resp.body.decode("utf-8"))
|
||||
assert isinstance(payload, dict)
|
||||
assert "cli.verify_system_status.title" in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_rejects_path_traversal(self):
|
||||
"""The ``lang`` path parameter is reflected through to the
|
||||
filesystem-backed pack loader. The endpoint contains an
|
||||
explicit defense against ``/`` and ``..`` in the value; this
|
||||
test pins that defense in place."""
|
||||
mock_auth = MagicMock()
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
|
||||
for bad in ("../../etc/passwd", "en/../fr", "a/b"):
|
||||
request = MagicMock()
|
||||
request.path = f"/api/v1/i18n/packs/{bad}"
|
||||
request.headers = {}
|
||||
request.match_info = {"lang": bad}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
assert isinstance(resp, web.HTTPBadRequest), (
|
||||
f"path-traversal defense did not reject lang={bad!r}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,30 +12,24 @@ class TestEndpointManager:
|
|||
"""Test cases for EndpointManager class"""
|
||||
|
||||
def test_endpoint_manager_initialization(self):
|
||||
"""Test EndpointManager initialization creates all endpoints"""
|
||||
"""EndpointManager wires up the full endpoint set and
|
||||
records dispatcher_manager / timeout on the instance."""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
|
||||
# The dispatcher_manager exposes a small set of factory
|
||||
# methods — MagicMock auto-creates them, returning fresh
|
||||
# MagicMocks on each call.
|
||||
manager = EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://prometheus:9090",
|
||||
timeout=300
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
|
||||
assert manager.dispatcher_manager == mock_dispatcher_manager
|
||||
assert manager.timeout == 300
|
||||
assert manager.services == {}
|
||||
assert len(manager.endpoints) > 0 # Should have multiple endpoints
|
||||
assert len(manager.endpoints) > 0
|
||||
|
||||
def test_endpoint_manager_with_default_timeout(self):
|
||||
"""Test EndpointManager with default timeout value"""
|
||||
|
|
@ -79,9 +73,15 @@ class TestEndpointManager:
|
|||
prometheus_url="http://test:9090"
|
||||
)
|
||||
|
||||
# Verify all dispatcher methods were called during initialization
|
||||
# Each dispatcher factory is invoked exactly once during
|
||||
# construction — one per endpoint that needs a dedicated
|
||||
# wire. dispatch_auth_iam is the dedicated factory for the
|
||||
# AuthEndpoints forwarder (login / bootstrap /
|
||||
# change-password), distinct from dispatch_global_service
|
||||
# (the generic /api/v1/{kind} route).
|
||||
mock_dispatcher_manager.dispatch_global_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice
|
||||
mock_dispatcher_manager.dispatch_auth_iam.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_socket.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_import.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_export.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -12,31 +12,35 @@ class TestMetricsEndpoint:
|
|||
"""Test cases for MetricsEndpoint class"""
|
||||
|
||||
def test_metrics_endpoint_initialization(self):
|
||||
"""Test MetricsEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. In production MetricsEndpoint is gated by
|
||||
'metrics:read' so that's the natural value to pass."""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
auth=mock_auth,
|
||||
capability="metrics:read",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.prometheus_url == "http://prometheus:9090"
|
||||
assert endpoint.path == "/metrics"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "metrics:read"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint_start_method(self):
|
||||
"""Test MetricsEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://localhost:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
auth=mock_auth,
|
||||
capability="metrics:read",
|
||||
)
|
||||
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -44,15 +48,16 @@ class TestMetricsEndpoint:
|
|||
"""Test add_routes method registers GET route with wildcard path"""
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
auth=mock_auth,
|
||||
capability="metrics:read",
|
||||
)
|
||||
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.get with wildcard path pattern
|
||||
|
|
|
|||
|
|
@ -1,5 +1,12 @@
|
|||
"""
|
||||
Tests for Gateway Socket Endpoint
|
||||
Tests for Gateway Socket Endpoint.
|
||||
|
||||
In production the only SocketEndpoint registered with HTTP-layer
|
||||
auth is ``/api/v1/socket`` using ``capability=AUTHENTICATED`` with
|
||||
``in_band_auth=True`` (first-frame auth over the websocket frames,
|
||||
not at the handshake). The tests below use AUTHENTICATED as the
|
||||
representative capability; construction / worker / listener
|
||||
behaviour is independent of which capability is configured.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -7,41 +14,47 @@ from unittest.mock import MagicMock, AsyncMock
|
|||
from aiohttp import WSMsgType
|
||||
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
from trustgraph.gateway.capabilities import AUTHENTICATED
|
||||
|
||||
|
||||
class TestSocketEndpoint:
|
||||
"""Test cases for SocketEndpoint class"""
|
||||
|
||||
def test_socket_endpoint_initialization(self):
|
||||
"""Test SocketEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. No permissive default is applied."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
endpoint_path="/api/socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/socket"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "socket"
|
||||
assert endpoint.capability == AUTHENTICATED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_method(self):
|
||||
"""Test SocketEndpoint worker method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
mock_ws = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
|
||||
# Call worker method
|
||||
await endpoint.worker(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
|
||||
# Verify dispatcher.run was called
|
||||
mock_dispatcher.run.assert_called_once()
|
||||
|
||||
|
|
@ -50,8 +63,11 @@ class TestSocketEndpoint:
|
|||
"""Test SocketEndpoint listener method with text message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
# Mock websocket with text message
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -80,8 +96,11 @@ class TestSocketEndpoint:
|
|||
"""Test SocketEndpoint listener method with binary message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
# Mock websocket with binary message
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -110,8 +129,11 @@ class TestSocketEndpoint:
|
|||
"""Test SocketEndpoint listener method with close message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
# Mock websocket with close message
|
||||
mock_msg = MagicMock()
|
||||
|
|
|
|||
|
|
@ -12,48 +12,57 @@ class TestStreamEndpoint:
|
|||
"""Test cases for StreamEndpoint class"""
|
||||
|
||||
def test_stream_endpoint_initialization_with_post(self):
|
||||
"""Test StreamEndpoint initialization with POST method"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. StreamEndpoint is used in production for the
|
||||
core-import / core-export / document-stream routes; a
|
||||
document-write capability is a realistic value for a POST
|
||||
stream (e.g. core-import)."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
capability="documents:write",
|
||||
method="POST",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/stream"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "documents:write"
|
||||
assert endpoint.method == "POST"
|
||||
|
||||
def test_stream_endpoint_initialization_with_get(self):
|
||||
"""Test StreamEndpoint initialization with GET method"""
|
||||
"""GET stream — export-style endpoint, read capability."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
capability="documents:read",
|
||||
method="GET",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.method == "GET"
|
||||
|
||||
def test_stream_endpoint_initialization_default_method(self):
|
||||
"""Test StreamEndpoint initialization with default POST method"""
|
||||
"""Test StreamEndpoint initialization with default POST method.
|
||||
The method default is cosmetic; the capability is not
|
||||
defaulted — it is always required."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability="documents:write",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.method == "POST" # Default value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -61,9 +70,12 @@ class TestStreamEndpoint:
|
|||
"""Test StreamEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
"/api/stream", mock_auth, mock_dispatcher,
|
||||
capability="documents:write",
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -72,16 +84,17 @@ class TestStreamEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
capability="documents:write",
|
||||
method="POST",
|
||||
)
|
||||
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
|
|
@ -92,16 +105,17 @@ class TestStreamEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
capability="documents:read",
|
||||
method="GET",
|
||||
)
|
||||
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
|
|
@ -112,13 +126,14 @@ class TestStreamEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="INVALID"
|
||||
capability="documents:write",
|
||||
method="INVALID",
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError, match="Bad method"):
|
||||
endpoint.add_routes(mock_app)
|
||||
|
|
@ -12,29 +12,36 @@ class TestVariableEndpoint:
|
|||
"""Test cases for VariableEndpoint class"""
|
||||
|
||||
def test_variable_endpoint_initialization(self):
|
||||
"""Test VariableEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. VariableEndpoint is used in production for the
|
||||
/api/v1/{kind} admin-scoped global service routes, so a
|
||||
write-side capability is a realistic value for the test."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
endpoint_path="/api/variable",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability="config:write",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/variable"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "config:write"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_endpoint_start_method(self):
|
||||
"""Test VariableEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
"/api/var", mock_auth, mock_dispatcher,
|
||||
capability="config:write",
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -43,10 +50,13 @@ class TestVariableEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
"/api/variable", mock_auth, mock_dispatcher,
|
||||
capability="config:write",
|
||||
)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
|
|
|
|||
|
|
@ -1,355 +1,179 @@
|
|||
"""
|
||||
Tests for Gateway Service API
|
||||
Tests for gateway/service.py — the Api class that wires together
|
||||
the pub/sub backend, IAM auth, config receiver, dispatcher manager,
|
||||
and endpoint manager.
|
||||
|
||||
The legacy ``GATEWAY_SECRET`` / ``default_api_token`` / allow-all
|
||||
surface is gone, so the tests here focus on the Api's construction
|
||||
and composition rather than the removed auth behaviour. IamAuth's
|
||||
own behaviour is covered in test_auth.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from aiohttp import web
|
||||
import pulsar
|
||||
|
||||
from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token
|
||||
|
||||
# Tests for Gateway Service API
|
||||
from trustgraph.gateway.service import (
|
||||
Api,
|
||||
default_pulsar_host, default_prometheus_url,
|
||||
default_timeout, default_port,
|
||||
)
|
||||
from trustgraph.gateway.auth import IamAuth
|
||||
|
||||
|
||||
class TestApi:
|
||||
"""Test cases for Api class"""
|
||||
|
||||
# -- constants -------------------------------------------------------------
|
||||
|
||||
def test_api_initialization_with_defaults(self):
|
||||
"""Test Api initialization with default values"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_backend = Mock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
api = Api()
|
||||
class TestDefaults:
|
||||
|
||||
assert api.port == default_port
|
||||
assert api.timeout == default_timeout
|
||||
assert api.pulsar_host == default_pulsar_host
|
||||
assert api.pulsar_api_key is None
|
||||
assert api.prometheus_url == default_prometheus_url + "/"
|
||||
assert api.auth.allow_all is True
|
||||
def test_exports_default_constants(self):
|
||||
# These are consumed by CLIs / tests / docs. Sanity-check
|
||||
# that they're the expected shape.
|
||||
assert default_port == 8088
|
||||
assert default_timeout == 600
|
||||
assert default_pulsar_host.startswith("pulsar://")
|
||||
assert default_prometheus_url.startswith("http")
|
||||
|
||||
# Verify get_pubsub was called
|
||||
mock_get_pubsub.assert_called_once()
|
||||
|
||||
def test_api_initialization_with_custom_config(self):
|
||||
"""Test Api initialization with custom configuration"""
|
||||
# -- Api construction ------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api(mock_backend):
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
yield Api()
|
||||
|
||||
|
||||
class TestApiConstruction:
|
||||
|
||||
def test_defaults(self, api):
|
||||
assert api.port == default_port
|
||||
assert api.timeout == default_timeout
|
||||
assert api.pulsar_host == default_pulsar_host
|
||||
assert api.pulsar_api_key is None
|
||||
# prometheus_url gets normalised with a trailing slash
|
||||
assert api.prometheus_url == default_prometheus_url + "/"
|
||||
|
||||
def test_auth_is_iam_backed(self, api):
|
||||
# Any Api always gets an IamAuth. There is no "no auth" mode
|
||||
# (GATEWAY_SECRET / allow_all has been removed — see IAM spec).
|
||||
assert isinstance(api.auth, IamAuth)
|
||||
|
||||
def test_components_wired(self, api):
|
||||
assert api.config_receiver is not None
|
||||
assert api.dispatcher_manager is not None
|
||||
assert api.endpoint_manager is not None
|
||||
|
||||
def test_dispatcher_manager_has_auth(self, api):
|
||||
# The Mux uses this handle for first-frame socket auth.
|
||||
assert api.dispatcher_manager.auth is api.auth
|
||||
|
||||
def test_custom_config(self, mock_backend):
|
||||
config = {
|
||||
"port": 9000,
|
||||
"timeout": 300,
|
||||
"pulsar_host": "pulsar://custom-host:6650",
|
||||
"pulsar_api_key": "test-api-key",
|
||||
"pulsar_listener": "custom-listener",
|
||||
"pulsar_api_key": "custom-key",
|
||||
"prometheus_url": "http://custom-prometheus:9090",
|
||||
"api_token": "secret-token"
|
||||
}
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
a = Api(**config)
|
||||
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_backend = Mock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
assert a.port == 9000
|
||||
assert a.timeout == 300
|
||||
assert a.pulsar_host == "pulsar://custom-host:6650"
|
||||
assert a.pulsar_api_key == "custom-key"
|
||||
# Trailing slash added.
|
||||
assert a.prometheus_url == "http://custom-prometheus:9090/"
|
||||
|
||||
api = Api(**config)
|
||||
def test_prometheus_url_already_has_trailing_slash(self, mock_backend):
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
a = Api(prometheus_url="http://p:9090/")
|
||||
assert a.prometheus_url == "http://p:9090/"
|
||||
|
||||
assert api.port == 9000
|
||||
assert api.timeout == 300
|
||||
assert api.pulsar_host == "pulsar://custom-host:6650"
|
||||
assert api.pulsar_api_key == "test-api-key"
|
||||
assert api.prometheus_url == "http://custom-prometheus:9090/"
|
||||
assert api.auth.token == "secret-token"
|
||||
assert api.auth.allow_all is False
|
||||
def test_queue_overrides_parsed_for_config(self, mock_backend):
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
a = Api(
|
||||
config_request_queue="alt-config-req",
|
||||
config_response_queue="alt-config-resp",
|
||||
)
|
||||
overrides = a.dispatcher_manager.queue_overrides
|
||||
assert overrides.get("config", {}).get("request") == "alt-config-req"
|
||||
assert overrides.get("config", {}).get("response") == "alt-config-resp"
|
||||
|
||||
# Verify get_pubsub was called with config
|
||||
mock_get_pubsub.assert_called_once_with(**config)
|
||||
|
||||
def test_api_initialization_with_pulsar_api_key(self):
|
||||
"""Test Api initialization with Pulsar API key authentication"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
# -- app_factory -----------------------------------------------------------
|
||||
|
||||
api = Api(pulsar_api_key="test-key")
|
||||
|
||||
# Verify api key was stored
|
||||
assert api.pulsar_api_key == "test-key"
|
||||
mock_get_pubsub.assert_called_once()
|
||||
|
||||
def test_api_initialization_prometheus_url_normalization(self):
|
||||
"""Test that prometheus_url gets normalized with trailing slash"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
# Test URL without trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
# Test URL with trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090/")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
def test_api_initialization_empty_api_token_means_no_auth(self):
|
||||
"""Test that empty API token results in allow_all authentication"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api(api_token="")
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
def test_api_initialization_none_api_token_means_no_auth(self):
|
||||
"""Test that None API token results in allow_all authentication"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api(api_token=None)
|
||||
assert api.auth.allow_all is True
|
||||
class TestAppFactory:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_creates_application(self):
|
||||
"""Test that app_factory creates aiohttp application"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
assert isinstance(app, web.Application)
|
||||
assert app._client_max_size == 256 * 1024 * 1024
|
||||
|
||||
# Verify that config receiver was started
|
||||
api.config_receiver.start.assert_called_once()
|
||||
|
||||
# Verify that endpoint manager was configured
|
||||
api.endpoint_manager.add_routes.assert_called_once_with(app)
|
||||
api.endpoint_manager.start.assert_called_once()
|
||||
async def test_creates_aiohttp_app(self, api):
|
||||
# Stub out the long-tail dependencies that reach out to IAM /
|
||||
# pub/sub so we can exercise the factory in isolation.
|
||||
api.auth.start = AsyncMock()
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
api.endpoints = []
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
assert isinstance(app, web.Application)
|
||||
assert app._client_max_size == 256 * 1024 * 1024
|
||||
api.auth.start.assert_called_once()
|
||||
api.config_receiver.start.assert_called_once()
|
||||
api.endpoint_manager.add_routes.assert_called_once_with(app)
|
||||
api.endpoint_manager.start.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_with_custom_endpoints(self):
|
||||
"""Test app_factory with custom endpoints"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock custom endpoints
|
||||
mock_endpoint1 = Mock()
|
||||
mock_endpoint1.add_routes = Mock()
|
||||
mock_endpoint1.start = AsyncMock()
|
||||
|
||||
mock_endpoint2 = Mock()
|
||||
mock_endpoint2.add_routes = Mock()
|
||||
mock_endpoint2.start = AsyncMock()
|
||||
|
||||
api.endpoints = [mock_endpoint1, mock_endpoint2]
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
# Verify custom endpoints were configured
|
||||
mock_endpoint1.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint1.start.assert_called_once()
|
||||
mock_endpoint2.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint2.start.assert_called_once()
|
||||
async def test_auth_start_runs_before_accepting_traffic(self, api):
|
||||
"""``auth.start()`` fetches the IAM signing key, and must
|
||||
complete (or time out) before the gateway begins accepting
|
||||
requests. It's the first await in app_factory."""
|
||||
order = []
|
||||
|
||||
def test_run_method_calls_web_run_app(self):
|
||||
"""Test that run method calls web.run_app"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
|
||||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
# AsyncMock.side_effect expects a sync callable (its return
|
||||
# value becomes the coroutine's return); a plain list.append
|
||||
# avoids the "coroutine was never awaited" trap of an async
|
||||
# side_effect.
|
||||
api.auth.start = AsyncMock(
|
||||
side_effect=lambda: order.append("auth"),
|
||||
)
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock(
|
||||
side_effect=lambda: order.append("config"),
|
||||
)
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock(
|
||||
side_effect=lambda: order.append("endpoints"),
|
||||
)
|
||||
api.endpoints = []
|
||||
|
||||
# Api.run() passes self.app_factory() — a coroutine — to
|
||||
# web.run_app, which would normally consume it inside its own
|
||||
# event loop. Since we mock run_app, close the coroutine here
|
||||
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
|
||||
def _consume_coro(coro, **kwargs):
|
||||
coro.close()
|
||||
mock_run_app.side_effect = _consume_coro
|
||||
await api.app_factory()
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
# Verify run_app was called once with the correct port
|
||||
mock_run_app.assert_called_once()
|
||||
args, kwargs = mock_run_app.call_args
|
||||
assert len(args) == 1 # Should have one positional arg (the coroutine)
|
||||
assert kwargs == {'port': 8080} # Should have port keyword arg
|
||||
|
||||
def test_api_components_initialization(self):
|
||||
"""Test that all API components are properly initialized"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Verify all components are initialized
|
||||
assert api.config_receiver is not None
|
||||
assert api.dispatcher_manager is not None
|
||||
assert api.endpoint_manager is not None
|
||||
assert api.endpoints == []
|
||||
|
||||
# Verify component relationships
|
||||
assert api.dispatcher_manager.backend == api.pubsub_backend
|
||||
assert api.dispatcher_manager.config_receiver == api.config_receiver
|
||||
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
|
||||
# EndpointManager doesn't store auth directly, it passes it to individual endpoints
|
||||
|
||||
|
||||
class TestRunFunction:
|
||||
"""Test cases for the run() function"""
|
||||
|
||||
def test_run_function_with_metrics_enabled(self):
|
||||
"""Test run function with metrics enabled"""
|
||||
import warnings
|
||||
# Suppress the specific async warning with a broader pattern
|
||||
warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
|
||||
|
||||
with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
|
||||
patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
|
||||
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = True
|
||||
mock_args.metrics_port = 8000
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Create a mock Api class without importing the real one
|
||||
mock_api = Mock(return_value=mock_api_instance)
|
||||
|
||||
# Patch using context manager to avoid importing the real Api class
|
||||
with patch('trustgraph.gateway.service.Api', mock_api):
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': True,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was started
|
||||
mock_start_http_server.assert_called_once_with(8000)
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.service.start_http_server')
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
|
||||
"""Test run function with metrics disabled"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': False,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was NOT started
|
||||
mock_start_http_server.assert_not_called()
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_argument_parsing(self, mock_parse_args):
|
||||
"""Test that run function properly parses command line arguments"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Mock vars() to return a dict with all expected arguments
|
||||
expected_args = {
|
||||
'pulsar_host': 'pulsar://test:6650',
|
||||
'pulsar_api_key': 'test-key',
|
||||
'pulsar_listener': 'test-listener',
|
||||
'prometheus_url': 'http://test-prometheus:9090',
|
||||
'port': 9000,
|
||||
'timeout': 300,
|
||||
'api_token': 'secret',
|
||||
'log_level': 'INFO',
|
||||
'metrics': False,
|
||||
'metrics_port': 8001
|
||||
}
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = expected_args
|
||||
|
||||
run()
|
||||
|
||||
# Verify Api was created with the parsed arguments
|
||||
mock_api.assert_called_once_with(**expected_args)
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
def test_run_function_creates_argument_parser(self):
|
||||
"""Test that run function creates argument parser with correct arguments"""
|
||||
with patch('argparse.ArgumentParser') as mock_parser_class:
|
||||
mock_parser = Mock()
|
||||
mock_parser_class.return_value = mock_parser
|
||||
mock_parser.parse_args.return_value = Mock(metrics=False)
|
||||
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api, \
|
||||
patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {'metrics': False}
|
||||
mock_api.return_value = Mock()
|
||||
|
||||
run()
|
||||
|
||||
# Verify ArgumentParser was created
|
||||
mock_parser_class.assert_called_once()
|
||||
|
||||
# Verify add_argument was called for each expected argument
|
||||
expected_arguments = [
|
||||
'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
|
||||
'prometheus-url', 'port', 'timeout', 'api-token',
|
||||
'log-level', 'metrics', 'metrics-port'
|
||||
]
|
||||
|
||||
# Check that add_argument was called multiple times (once for each arg)
|
||||
assert mock_parser.add_argument.call_count >= len(expected_arguments)
|
||||
# auth.start must be first (before config receiver, before
|
||||
# any endpoint starts).
|
||||
assert order[0] == "auth"
|
||||
# All three must have run.
|
||||
assert set(order) == {"auth", "config", "endpoints"}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,15 @@
|
|||
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
|
||||
"""Unit tests for SocketEndpoint graceful shutdown functionality.
|
||||
|
||||
These tests exercise SocketEndpoint in its handshake-auth
|
||||
configuration (``in_band_auth=False``) — the mode used in production
|
||||
for the flow import/export streaming endpoints. The mux socket at
|
||||
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
|
||||
handshake always accepts and authentication runs on the first
|
||||
WebSocket frame; that path is covered by the Mux tests.
|
||||
|
||||
Every endpoint constructor here passes an explicit capability — no
|
||||
permissive default is relied upon.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
|
|
@ -6,13 +17,31 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from aiohttp import web, WSMsgType
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
from trustgraph.gateway.running import Running
|
||||
from trustgraph.gateway.auth import Identity
|
||||
|
||||
|
||||
# Representative capability used across these tests — corresponds to
|
||||
# the flow-import streaming endpoint pattern that uses this class.
|
||||
TEST_CAP = "graph:write"
|
||||
|
||||
|
||||
def _valid_identity(roles=("admin",)):
|
||||
return Identity(
|
||||
user_id="test-user",
|
||||
workspace="default",
|
||||
roles=list(roles),
|
||||
source="api-key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
"""Mock authentication service."""
|
||||
"""Mock IAM-backed authenticator. Successful by default —
|
||||
``authenticate`` returns a valid admin identity. Tests that
|
||||
need the auth failure path override the ``authenticate``
|
||||
attribute locally."""
|
||||
auth = MagicMock()
|
||||
auth.permitted.return_value = True
|
||||
auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
return auth
|
||||
|
||||
|
||||
|
|
@ -25,7 +54,7 @@ def mock_dispatcher_factory():
|
|||
dispatcher.receive = AsyncMock()
|
||||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
|
||||
return dispatcher_factory
|
||||
|
||||
|
||||
|
|
@ -35,7 +64,8 @@ def socket_endpoint(mock_auth, mock_dispatcher_factory):
|
|||
return SocketEndpoint(
|
||||
endpoint_path="/test-socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher_factory
|
||||
dispatcher=mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -61,7 +91,10 @@ def mock_request():
|
|||
@pytest.mark.asyncio
|
||||
async def test_listener_graceful_shutdown_on_close():
|
||||
"""Test listener handles websocket close gracefully."""
|
||||
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", MagicMock(), AsyncMock(),
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
# Mock websocket that closes after one message
|
||||
ws = AsyncMock()
|
||||
|
|
@ -99,9 +132,9 @@ async def test_listener_graceful_shutdown_on_close():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_normal_flow():
|
||||
"""Test normal websocket handling flow."""
|
||||
"""Valid bearer → handshake accepted, dispatcher created."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
|
||||
dispatcher_created = False
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
|
|
@ -111,7 +144,10 @@ async def test_handle_normal_flow():
|
|||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
@ -155,7 +191,7 @@ async def test_handle_normal_flow():
|
|||
async def test_handle_exception_group_cleanup():
|
||||
"""Test exception group triggers dispatcher cleanup."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
|
@ -163,7 +199,10 @@ async def test_handle_exception_group_cleanup():
|
|||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
@ -222,7 +261,7 @@ async def test_handle_exception_group_cleanup():
|
|||
async def test_handle_dispatcher_cleanup_timeout():
|
||||
"""Test dispatcher cleanup with timeout."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
|
||||
# Mock dispatcher that takes long to destroy
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
|
@ -231,7 +270,10 @@ async def test_handle_dispatcher_cleanup_timeout():
|
|||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
@ -285,49 +327,67 @@ async def test_handle_dispatcher_cleanup_timeout():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unauthorized_request():
|
||||
"""Test handling of unauthorized requests."""
|
||||
"""A bearer that the IAM layer rejects causes the handshake to
|
||||
fail with 401. IamAuth surfaces an HTTPUnauthorized; the
|
||||
endpoint propagates it. Note that the endpoint intentionally
|
||||
does NOT distinguish 'bad token', 'expired', 'revoked', etc. —
|
||||
that's the IAM error-masking policy."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = False # Unauthorized
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||
|
||||
mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
|
||||
text='{"error":"auth failure"}',
|
||||
content_type="application/json",
|
||||
))
|
||||
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, AsyncMock(),
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "invalid-token"}
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should return HTTP 401
|
||||
|
||||
assert isinstance(result, web.HTTPUnauthorized)
|
||||
|
||||
# Should have checked permission
|
||||
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
|
||||
# authenticate must have been invoked with a synthetic request
|
||||
# carrying Bearer <the-token>. The endpoint wraps the query-
|
||||
# string token into an Authorization header for a uniform auth
|
||||
# path — the IAM layer does not look at query strings directly.
|
||||
mock_auth.authenticate.assert_called_once()
|
||||
passed_req = mock_auth.authenticate.call_args.args[0]
|
||||
assert passed_req.headers["Authorization"] == "Bearer invalid-token"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_missing_token():
|
||||
"""Test handling of requests with missing token."""
|
||||
"""Request with no ``token`` query param → 401 before any
|
||||
IAM call is made (cheap short-circuit)."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = False
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||
|
||||
mock_auth.authenticate = AsyncMock(
|
||||
side_effect=AssertionError(
|
||||
"authenticate must not be invoked when no token is present"
|
||||
),
|
||||
)
|
||||
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, AsyncMock(),
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {} # No token
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should return HTTP 401
|
||||
|
||||
assert isinstance(result, web.HTTPUnauthorized)
|
||||
|
||||
# Should have checked permission with empty token
|
||||
mock_auth.permitted.assert_called_once_with("", "socket")
|
||||
mock_auth.authenticate.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_websocket_already_closed():
|
||||
"""Test handling when websocket is already closed."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
|
@ -335,7 +395,10 @@ async def test_handle_websocket_already_closed():
|
|||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue