feat: IAM service, gateway auth middleware, capability model, and CLIs (#849)

Replaces the legacy GATEWAY_SECRET shared-token gate with an IAM-backed
identity and authorisation model.  The gateway no longer has an
"allow-all" or "no auth" mode; every request is authenticated via the
IAM service, authorised against a capability model that encodes both
the operation and the workspace it targets, and rejected with a
deliberately-uninformative 401 / 403 on any failure.

IAM service (trustgraph-flow/trustgraph/iam, trustgraph-base/schema/iam)
-----------------------------------------------------------------------
* New backend service (iam-svc) owning users, workspaces, API keys,
  passwords and JWT signing keys in Cassandra.  Reached over the
  standard pub/sub request/response pattern; gateway is the only
  caller.
* Operations: bootstrap, resolve-api-key, login, get-signing-key-public,
  rotate-signing-key, create/list/get/update/disable/delete/enable-user,
  change-password, reset-password, create/list/get/update/disable-
  workspace, create/list/revoke-api-key.
* Ed25519 JWT signing (alg=EdDSA).  Key rotation writes a new kid and
  retires the previous one; validation is grace-period friendly.
* Passwords: PBKDF2-HMAC-SHA-256, 600k iterations, per-user salt.
* API keys: 128-bit random, SHA-256 hashed.  Plaintext returned once.
* Bootstrap is explicit: --bootstrap-mode {token,bootstrap} is a
  required startup argument with no permissive default.  Masked
  "auth failure" errors hide whether a refused bootstrap request was
  due to mode, state, or authorisation.

Gateway authentication (trustgraph-flow/trustgraph/gateway/auth.py)
-------------------------------------------------------------------
* IamAuth replaces the legacy Authenticator.  Distinguishes JWTs
  (three-segment dotted) from API keys by shape; verifies JWTs
  locally using the cached IAM public key; resolves API keys via
  IAM with a short-TTL hash-keyed cache.  Every failure path
  surfaces the same 401 body ("auth failure") so callers cannot
  enumerate credential state.
* Public key is fetched at gateway startup with a bounded retry loop;
  traffic does not begin flowing until auth has started.

Capability model (trustgraph-flow/trustgraph/gateway/capabilities.py)
---------------------------------------------------------------------
* Roles have two dimensions: a capability set and a workspace scope.
  OSS ships reader / writer / admin; the first two are workspace-
  assigned, admin is cross-workspace ("*").  No "cross-workspace"
  pseudo-capability — workspace permission is a property of the role.
* check(identity, capability, target_workspace=None) is the single
  authorisation test: some role must grant the capability *and* be
  active in the target workspace.
* enforce_workspace validates a request-body workspace against the
  caller's role scopes and injects the resolved value.  Cross-
  workspace admin is permitted by role scope, not by a bypass.
* Gateway endpoints declare a required capability explicitly — no
  permissive default.  Construction fails fast if omitted.  Enterprise
  editions can replace the role table without changing the wire
  protocol.

WebSocket first-frame auth (dispatch/mux.py, endpoint/socket.py)
----------------------------------------------------------------
* /api/v1/socket handshake unconditionally accepts; authentication
  runs on the first WebSocket frame ({"type":"auth","token":"..."})
  with {"type":"auth-ok","workspace":"..."} / {"type":"auth-failed"}.
  The socket stays open on failure so the client can re-authenticate
  — browsers treat a handshake-time 401 as terminal, breaking
  reconnection.
* Mux.receive rejects every non-auth frame before auth succeeds,
  enforces the caller's workspace (envelope + inner payload) using
  the role-scope resolver, and supports mid-session re-auth.
* Flow import/export streaming endpoints keep the legacy ?token=
  handshake (URL-scoped short-lived transfers; no re-auth need).

Auth surface
------------
* POST /api/v1/auth/login — public, returns a JWT.
* POST /api/v1/auth/bootstrap — public; forwards to IAM's bootstrap
  op which itself enforces mode + tables-empty.
* POST /api/v1/auth/change-password — any authenticated user.
* POST /api/v1/iam — admin-only generic forwarder for the rest of
  the IAM API (per-op REST endpoints to follow in a later change).

Removed / breaking
------------------
* GATEWAY_SECRET / --api-token / default_api_token and the legacy
  Authenticator.permitted contract.  The gateway cannot run without
  IAM.
* ?token= on /api/v1/socket.
* DispatcherManager and Mux both raise on auth=None — no silent
  downgrade path.

CLI tools (trustgraph-cli)
--------------------------
tg-bootstrap-iam, tg-login, tg-create-user, tg-list-users,
tg-disable-user, tg-enable-user, tg-delete-user, tg-change-password,
tg-reset-password, tg-create-api-key, tg-list-api-keys,
tg-revoke-api-key, tg-create-workspace, tg-list-workspaces.  Passwords
read via getpass; tokens / one-time secrets written to stdout with
operator context on stderr so shell composition works cleanly.
AsyncSocketClient / SocketClient updated to the first-frame auth
protocol.

Specifications
--------------
* docs/tech-specs/iam.md updated with the error policy, workspace
  resolver extension point, and OSS role-scope model.
* docs/tech-specs/iam-protocol.md (new) — transport, dataclasses,
  operation table, error taxonomy, bootstrap modes.
* docs/tech-specs/capabilities.md (new) — capability vocabulary, OSS
  role bundles, agent-as-composition note, enforcement-boundary
  policy, enterprise extensibility.

Tests
-----
* test_auth.py (rewritten) — IamAuth + JWT round-trip with real
  Ed25519 keypairs + API-key cache behaviour.
* test_capabilities.py (new) — role table sanity, check across
  role x workspace combinations, enforce_workspace paths,
  unknown-cap / unknown-role fail-closed.
* Every endpoint test construction now names its capability
  explicitly (no permissive defaults relied upon).  New tests pin
  the fail-closed invariants: DispatcherManager / Mux refuse
  auth=None; i18n path-traversal defense is exercised.
* test_socket_graceful_shutdown rewritten against IamAuth.
This commit is contained in:
cybermaggedon 2026-04-24 17:29:10 +01:00 committed by GitHub
parent ae9936c9cc
commit 67b2fc448f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
61 changed files with 6474 additions and 792 deletions

View file

@ -1,69 +1,312 @@
"""
Tests for Gateway Authentication
Tests for gateway/auth.py IamAuth, JWT verification, API key
resolution cache.
JWTs are signed with real Ed25519 keypairs generated per-test, so
the crypto path is exercised end-to-end without mocks. API-key
resolution is tested against a stubbed IamClient since the real
one requires pub/sub.
"""
import base64
import json
import time
from unittest.mock import AsyncMock, Mock, patch
import pytest
from aiohttp import web
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from trustgraph.gateway.auth import Authenticator
from trustgraph.gateway.auth import (
IamAuth, Identity,
_b64url_decode, _verify_jwt_eddsa,
API_KEY_CACHE_TTL,
)
class TestAuthenticator:
"""Test cases for Authenticator class"""
# -- helpers ---------------------------------------------------------------
def test_authenticator_initialization_with_token(self):
"""Test Authenticator initialization with valid token"""
auth = Authenticator(token="test-token-123")
assert auth.token == "test-token-123"
assert auth.allow_all is False
def test_authenticator_initialization_with_allow_all(self):
"""Test Authenticator initialization with allow_all=True"""
auth = Authenticator(allow_all=True)
assert auth.token is None
assert auth.allow_all is True
def _b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def test_authenticator_initialization_without_token_raises_error(self):
"""Test Authenticator initialization without token raises RuntimeError"""
with pytest.raises(RuntimeError, match="Need a token"):
Authenticator()
def test_authenticator_initialization_with_empty_token_raises_error(self):
"""Test Authenticator initialization with empty token raises RuntimeError"""
with pytest.raises(RuntimeError, match="Need a token"):
Authenticator(token="")
def make_keypair():
priv = ed25519.Ed25519PrivateKey.generate()
public_pem = priv.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("ascii")
return priv, public_pem
def test_permitted_with_allow_all_returns_true(self):
"""Test permitted method returns True when allow_all is enabled"""
auth = Authenticator(allow_all=True)
# Should return True regardless of token or roles
assert auth.permitted("any-token", []) is True
assert auth.permitted("different-token", ["admin"]) is True
assert auth.permitted(None, ["user"]) is True
def test_permitted_with_matching_token_returns_true(self):
"""Test permitted method returns True with matching token"""
auth = Authenticator(token="secret-token")
# Should return True when tokens match
assert auth.permitted("secret-token", []) is True
assert auth.permitted("secret-token", ["admin", "user"]) is True
def sign_jwt(priv, claims, alg="EdDSA"):
header = {"alg": alg, "typ": "JWT", "kid": "kid-test"}
h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode())
p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode())
signing_input = f"{h}.{p}".encode("ascii")
if alg == "EdDSA":
sig = priv.sign(signing_input)
else:
raise ValueError(f"test helper doesn't sign {alg}")
return f"{h}.{p}.{_b64url(sig)}"
def test_permitted_with_non_matching_token_returns_false(self):
"""Test permitted method returns False with non-matching token"""
auth = Authenticator(token="secret-token")
# Should return False when tokens don't match
assert auth.permitted("wrong-token", []) is False
assert auth.permitted("different-token", ["admin"]) is False
assert auth.permitted(None, ["user"]) is False
def test_permitted_with_token_and_allow_all_returns_true(self):
"""Test permitted method with both token and allow_all set"""
auth = Authenticator(token="test-token", allow_all=True)
# allow_all should take precedence
assert auth.permitted("any-token", []) is True
assert auth.permitted("wrong-token", ["admin"]) is True
def make_request(auth_header):
"""Minimal stand-in for an aiohttp request — IamAuth only reads
``request.headers["Authorization"]``."""
req = Mock()
req.headers = {}
if auth_header is not None:
req.headers["Authorization"] = auth_header
return req
# -- pure helpers ----------------------------------------------------------
class TestB64UrlDecode:
def test_round_trip_without_padding(self):
data = b"hello"
encoded = _b64url(data)
assert _b64url_decode(encoded) == data
def test_handles_various_lengths(self):
for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"):
assert _b64url_decode(_b64url(s)) == s
# -- JWT verification -----------------------------------------------------
class TestVerifyJwtEddsa:
def test_valid_jwt_passes(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"roles": ["reader"],
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
got = _verify_jwt_eddsa(token, pub)
assert got["sub"] == "user-1"
assert got["workspace"] == "default"
def test_expired_jwt_rejected(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default", "roles": [],
"iat": int(time.time()) - 3600,
"exp": int(time.time()) - 1,
}
token = sign_jwt(priv, claims)
with pytest.raises(ValueError, match="expired"):
_verify_jwt_eddsa(token, pub)
def test_bad_signature_rejected(self):
priv_a, _ = make_keypair()
_, pub_b = make_keypair()
claims = {
"sub": "user-1", "workspace": "default", "roles": [],
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv_a, claims)
# pub_b never signed this token.
with pytest.raises(Exception):
_verify_jwt_eddsa(token, pub_b)
def test_malformed_jwt_rejected(self):
_, pub = make_keypair()
with pytest.raises(ValueError, match="malformed"):
_verify_jwt_eddsa("not-a-jwt", pub)
def test_unsupported_algorithm_rejected(self):
priv, pub = make_keypair()
# Manually build an "alg":"HS256" header — no signer needed
# since we expect it to bail before verifying.
header = {"alg": "HS256", "typ": "JWT", "kid": "x"}
payload = {
"sub": "user-1", "workspace": "default", "roles": [],
"iat": int(time.time()), "exp": int(time.time()) + 60,
}
h = _b64url(json.dumps(header, separators=(",", ":")).encode())
p = _b64url(json.dumps(payload, separators=(",", ":")).encode())
sig = _b64url(b"not-a-real-sig")
token = f"{h}.{p}.{sig}"
with pytest.raises(ValueError, match="unsupported alg"):
_verify_jwt_eddsa(token, pub)
# -- Identity --------------------------------------------------------------
class TestIdentity:
def test_fields(self):
i = Identity(
user_id="u", workspace="w", roles=["reader"], source="api-key",
)
assert i.user_id == "u"
assert i.workspace == "w"
assert i.roles == ["reader"]
assert i.source == "api-key"
# -- IamAuth.authenticate --------------------------------------------------
class TestIamAuthDispatch:
"""``authenticate()`` chooses between the JWT and API-key paths
by shape of the bearer."""
@pytest.mark.asyncio
async def test_no_authorization_header_raises_401(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(None))
@pytest.mark.asyncio
async def test_non_bearer_header_raises_401(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Basic whatever"))
@pytest.mark.asyncio
async def test_empty_bearer_raises_401(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer "))
@pytest.mark.asyncio
async def test_unknown_format_raises_401(self):
# Not tg_... and not dotted-JWT shape.
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer garbage"))
@pytest.mark.asyncio
async def test_valid_jwt_resolves_to_identity(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"roles": ["writer"],
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
auth = IamAuth(backend=Mock())
auth._signing_public_pem = pub
ident = await auth.authenticate(
make_request(f"Bearer {token}")
)
assert ident.user_id == "user-1"
assert ident.workspace == "default"
assert ident.roles == ["writer"]
assert ident.source == "jwt"
@pytest.mark.asyncio
async def test_jwt_without_public_key_fails(self):
# If the gateway hasn't fetched IAM's public key yet, JWTs
# must not validate — even ones that would otherwise pass.
priv, _ = make_keypair()
claims = {
"sub": "user-1", "workspace": "default", "roles": [],
"iat": int(time.time()), "exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
auth = IamAuth(backend=Mock())
# _signing_public_pem defaults to None
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(f"Bearer {token}"))
@pytest.mark.asyncio
async def test_api_key_path(self):
auth = IamAuth(backend=Mock())
async def fake_resolve(api_key):
assert api_key == "tg_testkey"
return ("user-xyz", "default", ["admin"])
async def fake_with_client(op):
return await op(Mock(resolve_api_key=fake_resolve))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = await auth.authenticate(
make_request("Bearer tg_testkey")
)
assert ident.user_id == "user-xyz"
assert ident.workspace == "default"
assert ident.roles == ["admin"]
assert ident.source == "api-key"
@pytest.mark.asyncio
async def test_api_key_rejection_masked_as_401(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
raise RuntimeError("auth-failed: unknown api key")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(
make_request("Bearer tg_bogus")
)
# -- API key cache ---------------------------------------------------------
class TestApiKeyCache:
@pytest.mark.asyncio
async def test_cache_hit_skips_iam(self):
auth = IamAuth(backend=Mock())
calls = {"n": 0}
async def fake_with_client(op):
calls["n"] += 1
return await op(Mock(
resolve_api_key=AsyncMock(
return_value=("u", "default", ["reader"]),
)
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
await auth.authenticate(make_request("Bearer tg_k1"))
await auth.authenticate(make_request("Bearer tg_k1"))
await auth.authenticate(make_request("Bearer tg_k1"))
# Only the first lookup reaches IAM; the rest are cache hits.
assert calls["n"] == 1
@pytest.mark.asyncio
async def test_different_keys_are_separately_cached(self):
auth = IamAuth(backend=Mock())
seen = []
async def fake_with_client(op):
async def resolve(plaintext):
seen.append(plaintext)
return ("u-" + plaintext, "default", ["reader"])
return await op(Mock(resolve_api_key=resolve))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
a = await auth.authenticate(make_request("Bearer tg_a"))
b = await auth.authenticate(make_request("Bearer tg_b"))
assert a.user_id == "u-tg_a"
assert b.user_id == "u-tg_b"
assert seen == ["tg_a", "tg_b"]
@pytest.mark.asyncio
async def test_cache_has_ttl_constant_set(self):
# Not a behaviour test — just ensures we don't accidentally
# set TTL to 0 (which would defeat the cache) or to a week.
assert 10 <= API_KEY_CACHE_TTL <= 3600

View file

@ -0,0 +1,203 @@
"""
Tests for gateway/capabilities.py the capability + role + workspace
model that underpins all gateway authorisation.
"""
import pytest
from aiohttp import web
from trustgraph.gateway.capabilities import (
PUBLIC, AUTHENTICATED,
KNOWN_CAPABILITIES, ROLE_DEFINITIONS,
check, enforce_workspace, access_denied, auth_failure,
)
# -- test fixtures ---------------------------------------------------------
class _Identity:
"""Minimal stand-in for auth.Identity — the capability module
accesses ``.workspace`` and ``.roles``."""
def __init__(self, workspace, roles):
self.user_id = "user-1"
self.workspace = workspace
self.roles = list(roles)
def reader_in(ws):
return _Identity(ws, ["reader"])
def writer_in(ws):
return _Identity(ws, ["writer"])
def admin_in(ws):
return _Identity(ws, ["admin"])
# -- role table sanity -----------------------------------------------------
class TestRoleTable:
def test_oss_roles_present(self):
assert set(ROLE_DEFINITIONS.keys()) == {"reader", "writer", "admin"}
def test_admin_is_cross_workspace(self):
assert ROLE_DEFINITIONS["admin"]["workspace_scope"] == "*"
def test_reader_writer_are_assigned_scope(self):
assert ROLE_DEFINITIONS["reader"]["workspace_scope"] == "assigned"
assert ROLE_DEFINITIONS["writer"]["workspace_scope"] == "assigned"
def test_admin_superset_of_writer(self):
admin = ROLE_DEFINITIONS["admin"]["capabilities"]
writer = ROLE_DEFINITIONS["writer"]["capabilities"]
assert writer.issubset(admin)
def test_writer_superset_of_reader(self):
writer = ROLE_DEFINITIONS["writer"]["capabilities"]
reader = ROLE_DEFINITIONS["reader"]["capabilities"]
assert reader.issubset(writer)
def test_admin_has_users_admin(self):
assert "users:admin" in ROLE_DEFINITIONS["admin"]["capabilities"]
def test_writer_does_not_have_users_admin(self):
assert "users:admin" not in ROLE_DEFINITIONS["writer"]["capabilities"]
def test_every_bundled_capability_is_known(self):
for role in ROLE_DEFINITIONS.values():
for cap in role["capabilities"]:
assert cap in KNOWN_CAPABILITIES
# -- check() ---------------------------------------------------------------
class TestCheck:
def test_reader_has_reader_cap_in_own_workspace(self):
assert check(reader_in("default"), "graph:read", "default")
def test_reader_does_not_have_writer_cap(self):
assert not check(reader_in("default"), "graph:write", "default")
def test_reader_cannot_act_in_other_workspace(self):
assert not check(reader_in("default"), "graph:read", "acme")
def test_writer_has_write_in_own_workspace(self):
assert check(writer_in("default"), "graph:write", "default")
def test_writer_cannot_act_in_other_workspace(self):
assert not check(writer_in("default"), "graph:write", "acme")
def test_admin_has_everything_everywhere(self):
for cap in ("graph:read", "graph:write", "config:write",
"users:admin", "metrics:read"):
assert check(admin_in("default"), cap, "acme"), (
f"admin should have {cap} in acme"
)
def test_admin_has_caps_without_explicit_workspace(self):
assert check(admin_in("default"), "users:admin")
def test_default_target_is_identity_workspace(self):
# Reader with no target workspace → should check against own
assert check(reader_in("default"), "graph:read")
def test_unknown_capability_returns_false(self):
assert not check(admin_in("default"), "nonsense:cap", "default")
def test_unknown_role_contributes_nothing(self):
ident = _Identity("default", ["made-up-role"])
assert not check(ident, "graph:read", "default")
def test_multi_role_union(self):
# If a user is both reader and admin, they inherit admin's
# cross-workspace powers.
ident = _Identity("default", ["reader", "admin"])
assert check(ident, "users:admin", "acme")
# -- enforce_workspace() ---------------------------------------------------
class TestEnforceWorkspace:
def test_reader_in_own_workspace_allowed(self):
data = {"workspace": "default", "operation": "x"}
enforce_workspace(data, reader_in("default"))
assert data["workspace"] == "default"
def test_reader_no_workspace_injects_assigned(self):
data = {"operation": "x"}
enforce_workspace(data, reader_in("default"))
assert data["workspace"] == "default"
def test_reader_mismatched_workspace_denied(self):
data = {"workspace": "acme", "operation": "x"}
with pytest.raises(web.HTTPForbidden):
enforce_workspace(data, reader_in("default"))
def test_admin_can_target_any_workspace(self):
data = {"workspace": "acme", "operation": "x"}
enforce_workspace(data, admin_in("default"))
assert data["workspace"] == "acme"
def test_admin_no_workspace_defaults_to_assigned(self):
data = {"operation": "x"}
enforce_workspace(data, admin_in("default"))
assert data["workspace"] == "default"
def test_writer_same_workspace_specified_allowed(self):
data = {"workspace": "default"}
enforce_workspace(data, writer_in("default"))
assert data["workspace"] == "default"
def test_non_dict_passthrough(self):
# Non-dict bodies are returned unchanged (e.g. streaming).
result = enforce_workspace("not-a-dict", reader_in("default"))
assert result == "not-a-dict"
def test_with_capability_tightens_check(self):
# Reader lacks graph:write; workspace-only check would pass
# (scope is fine), but combined check must reject.
data = {"workspace": "default"}
with pytest.raises(web.HTTPForbidden):
enforce_workspace(
data, reader_in("default"), capability="graph:write",
)
def test_with_capability_passes_when_granted(self):
data = {"workspace": "default"}
enforce_workspace(
data, reader_in("default"), capability="graph:read",
)
assert data["workspace"] == "default"
# -- helpers ---------------------------------------------------------------
class TestResponseHelpers:
def test_auth_failure_is_401(self):
exc = auth_failure()
assert exc.status == 401
assert "auth failure" in exc.text
def test_access_denied_is_403(self):
exc = access_denied()
assert exc.status == 403
assert "access denied" in exc.text
class TestSentinels:
def test_public_and_authenticated_are_distinct(self):
assert PUBLIC != AUTHENTICATED
assert PUBLIC not in KNOWN_CAPABILITIES
assert AUTHENTICATED not in KNOWN_CAPABILITIES

View file

@ -42,7 +42,7 @@ class TestDispatcherManager:
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
assert manager.backend == mock_backend
assert manager.config_receiver == mock_config_receiver
@ -59,7 +59,10 @@ class TestDispatcherManager:
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")
manager = DispatcherManager(
mock_backend, mock_config_receiver,
auth=Mock(), prefix="custom-prefix",
)
assert manager.prefix == "custom-prefix"
@ -68,7 +71,7 @@ class TestDispatcherManager:
"""Test start_flow method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
flow_data = {"name": "test_flow", "steps": []}
@ -82,7 +85,7 @@ class TestDispatcherManager:
"""Test stop_flow method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
@ -96,7 +99,7 @@ class TestDispatcherManager:
"""Test dispatch_global_service returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_global_service()
@ -107,7 +110,7 @@ class TestDispatcherManager:
"""Test dispatch_core_export returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_core_export()
@ -118,7 +121,7 @@ class TestDispatcherManager:
"""Test dispatch_core_import returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_core_import()
@ -130,7 +133,7 @@ class TestDispatcherManager:
"""Test process_core_import method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
mock_importer = Mock()
@ -148,7 +151,7 @@ class TestDispatcherManager:
"""Test process_core_export method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
mock_exporter = Mock()
@ -166,7 +169,7 @@ class TestDispatcherManager:
"""Test process_global_service method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
manager.invoke_global_service = AsyncMock(return_value="global_result")
@ -181,7 +184,7 @@ class TestDispatcherManager:
"""Test invoke_global_service with existing dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
@ -198,7 +201,7 @@ class TestDispatcherManager:
"""Test invoke_global_service creates new dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
@ -230,7 +233,7 @@ class TestDispatcherManager:
"""Test dispatch_flow_import returns correct method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
result = manager.dispatch_flow_import()
@ -240,7 +243,7 @@ class TestDispatcherManager:
"""Test dispatch_flow_export returns correct method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
result = manager.dispatch_flow_export()
@ -250,7 +253,7 @@ class TestDispatcherManager:
"""Test dispatch_socket returns correct method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
result = manager.dispatch_socket()
@ -260,7 +263,7 @@ class TestDispatcherManager:
"""Test dispatch_flow_service returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_flow_service()
@ -272,7 +275,7 @@ class TestDispatcherManager:
"""Test process_flow_import with valid flow and kind"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -308,7 +311,7 @@ class TestDispatcherManager:
"""Test process_flow_import with invalid flow"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
params = {"flow": "invalid_flow", "kind": "triples"}
@ -323,7 +326,7 @@ class TestDispatcherManager:
warnings.simplefilter("ignore", RuntimeWarning)
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -345,7 +348,7 @@ class TestDispatcherManager:
"""Test process_flow_export with valid flow and kind"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -378,26 +381,47 @@ class TestDispatcherManager:
@pytest.mark.asyncio
async def test_process_socket(self):
"""Test process_socket method"""
"""process_socket constructs a Mux with the manager's auth
instance passed through this is the gateway's trust path
for first-frame WebSocket authentication. A Mux cannot be
built without auth (tested separately); this test pins that
the dispatcher-manager threads the correct auth value into
the Mux constructor call."""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
mock_auth = Mock()
manager = DispatcherManager(
mock_backend, mock_config_receiver, auth=mock_auth,
)
with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
mock_mux_instance = Mock()
mock_mux.return_value = mock_mux_instance
result = await manager.process_socket("ws", "running", {})
mock_mux.assert_called_once_with(manager, "ws", "running")
mock_mux.assert_called_once_with(
manager, "ws", "running", auth=mock_auth,
)
assert result == mock_mux_instance
def test_dispatcher_manager_requires_auth(self):
"""Constructing a DispatcherManager without an auth argument
must fail a no-auth DispatcherManager would produce a
Mux without authentication, silently downgrading the socket
auth path."""
mock_backend = Mock()
mock_config_receiver = Mock()
with pytest.raises(ValueError, match="auth"):
DispatcherManager(mock_backend, mock_config_receiver, auth=None)
@pytest.mark.asyncio
async def test_process_flow_service(self):
"""Test process_flow_service method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
manager.invoke_flow_service = AsyncMock(return_value="flow_result")
@ -412,7 +436,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with existing dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Add flow to the flows dictionary
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
@ -432,7 +456,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service creates request-response dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -476,7 +500,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service creates sender dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -516,7 +540,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with invalid flow"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@ -526,7 +550,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with kind not supported by flow"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow without agent interface
manager.flows[("default", "test_flow")] = {
@ -543,7 +567,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with invalid kind"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow with interface but unsupported kind
manager.flows[("default", "test_flow")] = {
@ -570,7 +594,7 @@ class TestDispatcherManager:
"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
async def slow_start():
# Yield to the event loop so other coroutines get a chance to run,
@ -606,7 +630,7 @@ class TestDispatcherManager:
"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
manager.flows[("default", "test_flow")] = {
"interfaces": {

View file

@ -12,6 +12,19 @@ from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE
class TestMux:
"""Test cases for Mux class"""
def test_mux_requires_auth(self):
"""Constructing a Mux without an ``auth`` argument must
fail. The Mux implements the first-frame auth protocol and
there is no no-auth mode a no-auth Mux would silently
accept every frame without authenticating it."""
with pytest.raises(ValueError, match="auth"):
Mux(
dispatcher_manager=MagicMock(),
ws=MagicMock(),
running=MagicMock(),
auth=None,
)
def test_mux_initialization(self):
"""Test Mux initialization"""
mock_dispatcher_manager = MagicMock()
@ -21,7 +34,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
assert mux.dispatcher_manager == mock_dispatcher_manager
@ -40,7 +54,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Call destroy
@ -61,7 +76,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=None,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Call destroy
@ -81,7 +97,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message with valid JSON
@ -108,7 +125,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message without request field
@ -137,7 +155,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message without id field
@ -164,7 +183,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message with invalid JSON

View file

@ -13,29 +13,36 @@ class TestConstantEndpoint:
"""Test cases for ConstantEndpoint class"""
def test_constant_endpoint_initialization(self):
"""Test ConstantEndpoint initialization"""
"""Construction records the configured capability on the
instance. The capability is a required argument no
permissive default and the test passes an explicit
value to demonstrate the contract."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = ConstantEndpoint(
endpoint_path="/api/test",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability="config:read",
)
assert endpoint.path == "/api/test"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "service"
assert endpoint.capability == "config:read"
@pytest.mark.asyncio
async def test_constant_endpoint_start_method(self):
"""Test ConstantEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
endpoint = ConstantEndpoint(
"/api/test", mock_auth, mock_dispatcher,
capability="config:read",
)
# start() should complete without error
await endpoint.start()
@ -44,10 +51,13 @@ class TestConstantEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
endpoint = ConstantEndpoint(
"/api/test", mock_auth, mock_dispatcher,
capability="config:read",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with POST route
mock_app.add_routes.assert_called_once()
# The call should include web.post with the path and handler

View file

@ -1,4 +1,12 @@
"""Tests for Gateway i18n pack endpoint."""
"""Tests for Gateway i18n pack endpoint.
Production registers this endpoint with ``capability=PUBLIC``: the
login UI needs to render its own i18n strings before any user has
authenticated, so the endpoint is deliberately pre-auth. These
tests exercise the PUBLIC configuration that is the production
contract. Behaviour of authenticated endpoints is covered by the
IamAuth tests in ``test_auth.py``.
"""
import json
from unittest.mock import MagicMock
@ -7,6 +15,7 @@ import pytest
from aiohttp import web
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
from trustgraph.gateway.capabilities import PUBLIC
class TestI18nPackEndpoint:
@ -17,23 +26,28 @@ class TestI18nPackEndpoint:
endpoint = I18nPackEndpoint(
endpoint_path="/api/v1/i18n/packs/{lang}",
auth=mock_auth,
capability=PUBLIC,
)
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
assert endpoint.auth == mock_auth
assert endpoint.operation == "service"
assert endpoint.capability == PUBLIC
@pytest.mark.asyncio
async def test_i18n_endpoint_start_method(self):
mock_auth = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
await endpoint.start()
def test_add_routes_registers_get_handler(self):
mock_auth = MagicMock()
mock_app = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
endpoint.add_routes(mock_app)
mock_app.add_routes.assert_called_once()
@ -41,35 +55,55 @@ class TestI18nPackEndpoint:
assert len(call_args) == 1
@pytest.mark.asyncio
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
async def test_handle_returns_pack_without_authenticating(self):
"""The PUBLIC endpoint serves the language pack without
invoking the auth handler at all pre-login UI must be
reachable. The test uses an auth mock that raises if
touched, so any auth attempt by the endpoint is caught."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
def _should_not_be_called(*args, **kwargs):
raise AssertionError(
"PUBLIC endpoint must not invoke auth.authenticate"
)
mock_auth.authenticate = _should_not_be_called
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
# A caller-supplied Authorization header of any form should
# be ignored — PUBLIC means we don't look at it.
request.headers = {"Authorization": "Token abc"}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert isinstance(resp, web.HTTPUnauthorized)
@pytest.mark.asyncio
async def test_handle_returns_pack_when_permitted(self):
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
request.headers = {}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert resp.status == 200
payload = json.loads(resp.body.decode("utf-8"))
assert isinstance(payload, dict)
assert "cli.verify_system_status.title" in payload
@pytest.mark.asyncio
async def test_handle_rejects_path_traversal(self):
"""The ``lang`` path parameter is reflected through to the
filesystem-backed pack loader. The endpoint contains an
explicit defense against ``/`` and ``..`` in the value; this
test pins that defense in place."""
mock_auth = MagicMock()
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
for bad in ("../../etc/passwd", "en/../fr", "a/b"):
request = MagicMock()
request.path = f"/api/v1/i18n/packs/{bad}"
request.headers = {}
request.match_info = {"lang": bad}
resp = await endpoint.handle(request)
assert isinstance(resp, web.HTTPBadRequest), (
f"path-traversal defense did not reject lang={bad!r}"
)

View file

@ -12,30 +12,24 @@ class TestEndpointManager:
"""Test cases for EndpointManager class"""
def test_endpoint_manager_initialization(self):
"""Test EndpointManager initialization creates all endpoints"""
"""EndpointManager wires up the full endpoint set and
records dispatcher_manager / timeout on the instance."""
mock_dispatcher_manager = MagicMock()
mock_auth = MagicMock()
# Mock dispatcher methods
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
# The dispatcher_manager exposes a small set of factory
# methods — MagicMock auto-creates them, returning fresh
# MagicMocks on each call.
manager = EndpointManager(
dispatcher_manager=mock_dispatcher_manager,
auth=mock_auth,
prometheus_url="http://prometheus:9090",
timeout=300
timeout=300,
)
assert manager.dispatcher_manager == mock_dispatcher_manager
assert manager.timeout == 300
assert manager.services == {}
assert len(manager.endpoints) > 0 # Should have multiple endpoints
assert len(manager.endpoints) > 0
def test_endpoint_manager_with_default_timeout(self):
"""Test EndpointManager with default timeout value"""
@ -79,9 +73,15 @@ class TestEndpointManager:
prometheus_url="http://test:9090"
)
# Verify all dispatcher methods were called during initialization
# Each dispatcher factory is invoked exactly once during
# construction — one per endpoint that needs a dedicated
# wire. dispatch_auth_iam is the dedicated factory for the
# AuthEndpoints forwarder (login / bootstrap /
# change-password), distinct from dispatch_global_service
# (the generic /api/v1/{kind} route).
mock_dispatcher_manager.dispatch_global_service.assert_called_once()
mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice
mock_dispatcher_manager.dispatch_auth_iam.assert_called_once()
mock_dispatcher_manager.dispatch_socket.assert_called_once()
mock_dispatcher_manager.dispatch_flow_service.assert_called_once()
mock_dispatcher_manager.dispatch_flow_import.assert_called_once()
mock_dispatcher_manager.dispatch_flow_export.assert_called_once()

View file

@ -12,31 +12,35 @@ class TestMetricsEndpoint:
"""Test cases for MetricsEndpoint class"""
def test_metrics_endpoint_initialization(self):
"""Test MetricsEndpoint initialization"""
"""Construction records the configured capability on the
instance. In production MetricsEndpoint is gated by
'metrics:read' so that's the natural value to pass."""
mock_auth = MagicMock()
endpoint = MetricsEndpoint(
prometheus_url="http://prometheus:9090",
endpoint_path="/metrics",
auth=mock_auth
auth=mock_auth,
capability="metrics:read",
)
assert endpoint.prometheus_url == "http://prometheus:9090"
assert endpoint.path == "/metrics"
assert endpoint.auth == mock_auth
assert endpoint.operation == "service"
assert endpoint.capability == "metrics:read"
@pytest.mark.asyncio
async def test_metrics_endpoint_start_method(self):
"""Test MetricsEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
endpoint = MetricsEndpoint(
prometheus_url="http://localhost:9090",
endpoint_path="/metrics",
auth=mock_auth
auth=mock_auth,
capability="metrics:read",
)
# start() should complete without error
await endpoint.start()
@ -44,15 +48,16 @@ class TestMetricsEndpoint:
"""Test add_routes method registers GET route with wildcard path"""
mock_auth = MagicMock()
mock_app = MagicMock()
endpoint = MetricsEndpoint(
prometheus_url="http://prometheus:9090",
endpoint_path="/metrics",
auth=mock_auth
auth=mock_auth,
capability="metrics:read",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with GET route
mock_app.add_routes.assert_called_once()
# The call should include web.get with wildcard path pattern

View file

@ -1,5 +1,12 @@
"""
Tests for Gateway Socket Endpoint
Tests for Gateway Socket Endpoint.
In production the only SocketEndpoint registered with HTTP-layer
auth is ``/api/v1/socket`` using ``capability=AUTHENTICATED`` with
``in_band_auth=True`` (first-frame auth over the websocket frames,
not at the handshake). The tests below use AUTHENTICATED as the
representative capability; construction / worker / listener
behaviour is independent of which capability is configured.
"""
import pytest
@ -7,41 +14,47 @@ from unittest.mock import MagicMock, AsyncMock
from aiohttp import WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.capabilities import AUTHENTICATED
class TestSocketEndpoint:
"""Test cases for SocketEndpoint class"""
def test_socket_endpoint_initialization(self):
"""Test SocketEndpoint initialization"""
"""Construction records the configured capability on the
instance. No permissive default is applied."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = SocketEndpoint(
endpoint_path="/api/socket",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability=AUTHENTICATED,
)
assert endpoint.path == "/api/socket"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "socket"
assert endpoint.capability == AUTHENTICATED
@pytest.mark.asyncio
async def test_worker_method(self):
"""Test SocketEndpoint worker method"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
mock_ws = MagicMock()
mock_running = MagicMock()
# Call worker method
await endpoint.worker(mock_ws, mock_dispatcher, mock_running)
# Verify dispatcher.run was called
mock_dispatcher.run.assert_called_once()
@ -50,8 +63,11 @@ class TestSocketEndpoint:
"""Test SocketEndpoint listener method with text message"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
# Mock websocket with text message
mock_msg = MagicMock()
@ -80,8 +96,11 @@ class TestSocketEndpoint:
"""Test SocketEndpoint listener method with binary message"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
# Mock websocket with binary message
mock_msg = MagicMock()
@ -110,8 +129,11 @@ class TestSocketEndpoint:
"""Test SocketEndpoint listener method with close message"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
# Mock websocket with close message
mock_msg = MagicMock()

View file

@ -12,48 +12,57 @@ class TestStreamEndpoint:
"""Test cases for StreamEndpoint class"""
def test_stream_endpoint_initialization_with_post(self):
"""Test StreamEndpoint initialization with POST method"""
"""Construction records the configured capability on the
instance. StreamEndpoint is used in production for the
core-import / core-export / document-stream routes; a
document-write capability is a realistic value for a POST
stream (e.g. core-import)."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="POST"
capability="documents:write",
method="POST",
)
assert endpoint.path == "/api/stream"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "service"
assert endpoint.capability == "documents:write"
assert endpoint.method == "POST"
def test_stream_endpoint_initialization_with_get(self):
"""Test StreamEndpoint initialization with GET method"""
"""GET stream — export-style endpoint, read capability."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="GET"
capability="documents:read",
method="GET",
)
assert endpoint.method == "GET"
def test_stream_endpoint_initialization_default_method(self):
"""Test StreamEndpoint initialization with default POST method"""
"""Test StreamEndpoint initialization with default POST method.
The method default is cosmetic; the capability is not
defaulted it is always required."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability="documents:write",
)
assert endpoint.method == "POST" # Default value
@pytest.mark.asyncio
@ -61,9 +70,12 @@ class TestStreamEndpoint:
"""Test StreamEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
endpoint = StreamEndpoint(
"/api/stream", mock_auth, mock_dispatcher,
capability="documents:write",
)
# start() should complete without error
await endpoint.start()
@ -72,16 +84,17 @@ class TestStreamEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="POST"
capability="documents:write",
method="POST",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with POST route
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]
@ -92,16 +105,17 @@ class TestStreamEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="GET"
capability="documents:read",
method="GET",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with GET route
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]
@ -112,13 +126,14 @@ class TestStreamEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="INVALID"
capability="documents:write",
method="INVALID",
)
with pytest.raises(RuntimeError, match="Bad method"):
endpoint.add_routes(mock_app)

View file

@ -12,29 +12,36 @@ class TestVariableEndpoint:
"""Test cases for VariableEndpoint class"""
def test_variable_endpoint_initialization(self):
"""Test VariableEndpoint initialization"""
"""Construction records the configured capability on the
instance. VariableEndpoint is used in production for the
/api/v1/{kind} admin-scoped global service routes, so a
write-side capability is a realistic value for the test."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = VariableEndpoint(
endpoint_path="/api/variable",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability="config:write",
)
assert endpoint.path == "/api/variable"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "service"
assert endpoint.capability == "config:write"
@pytest.mark.asyncio
async def test_variable_endpoint_start_method(self):
"""Test VariableEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher)
endpoint = VariableEndpoint(
"/api/var", mock_auth, mock_dispatcher,
capability="config:write",
)
# start() should complete without error
await endpoint.start()
@ -43,10 +50,13 @@ class TestVariableEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher)
endpoint = VariableEndpoint(
"/api/variable", mock_auth, mock_dispatcher,
capability="config:write",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with POST route
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]

View file

@ -1,355 +1,179 @@
"""
Tests for Gateway Service API
Tests for gateway/service.py the Api class that wires together
the pub/sub backend, IAM auth, config receiver, dispatcher manager,
and endpoint manager.
The legacy ``GATEWAY_SECRET`` / ``default_api_token`` / allow-all
surface is gone, so the tests here focus on the Api's construction
and composition rather than the removed auth behaviour. IamAuth's
own behaviour is covered in test_auth.py.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, MagicMock, AsyncMock
from unittest.mock import AsyncMock, Mock, patch
from aiohttp import web
import pulsar
from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token
# Tests for Gateway Service API
from trustgraph.gateway.service import (
Api,
default_pulsar_host, default_prometheus_url,
default_timeout, default_port,
)
from trustgraph.gateway.auth import IamAuth
class TestApi:
"""Test cases for Api class"""
# -- constants -------------------------------------------------------------
def test_api_initialization_with_defaults(self):
"""Test Api initialization with default values"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_backend = Mock()
mock_get_pubsub.return_value = mock_backend
api = Api()
class TestDefaults:
assert api.port == default_port
assert api.timeout == default_timeout
assert api.pulsar_host == default_pulsar_host
assert api.pulsar_api_key is None
assert api.prometheus_url == default_prometheus_url + "/"
assert api.auth.allow_all is True
def test_exports_default_constants(self):
# These are consumed by CLIs / tests / docs. Sanity-check
# that they're the expected shape.
assert default_port == 8088
assert default_timeout == 600
assert default_pulsar_host.startswith("pulsar://")
assert default_prometheus_url.startswith("http")
# Verify get_pubsub was called
mock_get_pubsub.assert_called_once()
def test_api_initialization_with_custom_config(self):
"""Test Api initialization with custom configuration"""
# -- Api construction ------------------------------------------------------
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def api(mock_backend):
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
yield Api()
class TestApiConstruction:
def test_defaults(self, api):
assert api.port == default_port
assert api.timeout == default_timeout
assert api.pulsar_host == default_pulsar_host
assert api.pulsar_api_key is None
# prometheus_url gets normalised with a trailing slash
assert api.prometheus_url == default_prometheus_url + "/"
def test_auth_is_iam_backed(self, api):
# Any Api always gets an IamAuth. There is no "no auth" mode
# (GATEWAY_SECRET / allow_all has been removed — see IAM spec).
assert isinstance(api.auth, IamAuth)
def test_components_wired(self, api):
assert api.config_receiver is not None
assert api.dispatcher_manager is not None
assert api.endpoint_manager is not None
def test_dispatcher_manager_has_auth(self, api):
# The Mux uses this handle for first-frame socket auth.
assert api.dispatcher_manager.auth is api.auth
def test_custom_config(self, mock_backend):
config = {
"port": 9000,
"timeout": 300,
"pulsar_host": "pulsar://custom-host:6650",
"pulsar_api_key": "test-api-key",
"pulsar_listener": "custom-listener",
"pulsar_api_key": "custom-key",
"prometheus_url": "http://custom-prometheus:9090",
"api_token": "secret-token"
}
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
a = Api(**config)
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_backend = Mock()
mock_get_pubsub.return_value = mock_backend
assert a.port == 9000
assert a.timeout == 300
assert a.pulsar_host == "pulsar://custom-host:6650"
assert a.pulsar_api_key == "custom-key"
# Trailing slash added.
assert a.prometheus_url == "http://custom-prometheus:9090/"
api = Api(**config)
def test_prometheus_url_already_has_trailing_slash(self, mock_backend):
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
a = Api(prometheus_url="http://p:9090/")
assert a.prometheus_url == "http://p:9090/"
assert api.port == 9000
assert api.timeout == 300
assert api.pulsar_host == "pulsar://custom-host:6650"
assert api.pulsar_api_key == "test-api-key"
assert api.prometheus_url == "http://custom-prometheus:9090/"
assert api.auth.token == "secret-token"
assert api.auth.allow_all is False
def test_queue_overrides_parsed_for_config(self, mock_backend):
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
a = Api(
config_request_queue="alt-config-req",
config_response_queue="alt-config-resp",
)
overrides = a.dispatcher_manager.queue_overrides
assert overrides.get("config", {}).get("request") == "alt-config-req"
assert overrides.get("config", {}).get("response") == "alt-config-resp"
# Verify get_pubsub was called with config
mock_get_pubsub.assert_called_once_with(**config)
def test_api_initialization_with_pulsar_api_key(self):
"""Test Api initialization with Pulsar API key authentication"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
# -- app_factory -----------------------------------------------------------
api = Api(pulsar_api_key="test-key")
# Verify api key was stored
assert api.pulsar_api_key == "test-key"
mock_get_pubsub.assert_called_once()
def test_api_initialization_prometheus_url_normalization(self):
"""Test that prometheus_url gets normalized with trailing slash"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
# Test URL without trailing slash
api = Api(prometheus_url="http://prometheus:9090")
assert api.prometheus_url == "http://prometheus:9090/"
# Test URL with trailing slash
api = Api(prometheus_url="http://prometheus:9090/")
assert api.prometheus_url == "http://prometheus:9090/"
def test_api_initialization_empty_api_token_means_no_auth(self):
"""Test that empty API token results in allow_all authentication"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api(api_token="")
assert api.auth.allow_all is True
def test_api_initialization_none_api_token_means_no_auth(self):
"""Test that None API token results in allow_all authentication"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api(api_token=None)
assert api.auth.allow_all is True
class TestAppFactory:
@pytest.mark.asyncio
async def test_app_factory_creates_application(self):
"""Test that app_factory creates aiohttp application"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api()
# Mock the dependencies
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock()
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock()
app = await api.app_factory()
assert isinstance(app, web.Application)
assert app._client_max_size == 256 * 1024 * 1024
# Verify that config receiver was started
api.config_receiver.start.assert_called_once()
# Verify that endpoint manager was configured
api.endpoint_manager.add_routes.assert_called_once_with(app)
api.endpoint_manager.start.assert_called_once()
async def test_creates_aiohttp_app(self, api):
# Stub out the long-tail dependencies that reach out to IAM /
# pub/sub so we can exercise the factory in isolation.
api.auth.start = AsyncMock()
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock()
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock()
api.endpoints = []
app = await api.app_factory()
assert isinstance(app, web.Application)
assert app._client_max_size == 256 * 1024 * 1024
api.auth.start.assert_called_once()
api.config_receiver.start.assert_called_once()
api.endpoint_manager.add_routes.assert_called_once_with(app)
api.endpoint_manager.start.assert_called_once()
@pytest.mark.asyncio
async def test_app_factory_with_custom_endpoints(self):
"""Test app_factory with custom endpoints"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api()
# Mock custom endpoints
mock_endpoint1 = Mock()
mock_endpoint1.add_routes = Mock()
mock_endpoint1.start = AsyncMock()
mock_endpoint2 = Mock()
mock_endpoint2.add_routes = Mock()
mock_endpoint2.start = AsyncMock()
api.endpoints = [mock_endpoint1, mock_endpoint2]
# Mock the dependencies
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock()
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock()
app = await api.app_factory()
# Verify custom endpoints were configured
mock_endpoint1.add_routes.assert_called_once_with(app)
mock_endpoint1.start.assert_called_once()
mock_endpoint2.add_routes.assert_called_once_with(app)
mock_endpoint2.start.assert_called_once()
async def test_auth_start_runs_before_accepting_traffic(self, api):
"""``auth.start()`` fetches the IAM signing key, and must
complete (or time out) before the gateway begins accepting
requests. It's the first await in app_factory."""
order = []
def test_run_method_calls_web_run_app(self):
"""Test that run method calls web.run_app"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
# AsyncMock.side_effect expects a sync callable (its return
# value becomes the coroutine's return); a plain list.append
# avoids the "coroutine was never awaited" trap of an async
# side_effect.
api.auth.start = AsyncMock(
side_effect=lambda: order.append("auth"),
)
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock(
side_effect=lambda: order.append("config"),
)
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock(
side_effect=lambda: order.append("endpoints"),
)
api.endpoints = []
# Api.run() passes self.app_factory() — a coroutine — to
# web.run_app, which would normally consume it inside its own
# event loop. Since we mock run_app, close the coroutine here
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
def _consume_coro(coro, **kwargs):
coro.close()
mock_run_app.side_effect = _consume_coro
await api.app_factory()
api = Api(port=8080)
api.run()
# Verify run_app was called once with the correct port
mock_run_app.assert_called_once()
args, kwargs = mock_run_app.call_args
assert len(args) == 1 # Should have one positional arg (the coroutine)
assert kwargs == {'port': 8080} # Should have port keyword arg
def test_api_components_initialization(self):
"""Test that all API components are properly initialized"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api()
# Verify all components are initialized
assert api.config_receiver is not None
assert api.dispatcher_manager is not None
assert api.endpoint_manager is not None
assert api.endpoints == []
# Verify component relationships
assert api.dispatcher_manager.backend == api.pubsub_backend
assert api.dispatcher_manager.config_receiver == api.config_receiver
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
# EndpointManager doesn't store auth directly, it passes it to individual endpoints
class TestRunFunction:
"""Test cases for the run() function"""
def test_run_function_with_metrics_enabled(self):
"""Test run function with metrics enabled"""
import warnings
# Suppress the specific async warning with a broader pattern
warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
# Mock command line arguments
mock_args = Mock()
mock_args.metrics = True
mock_args.metrics_port = 8000
mock_parse_args.return_value = mock_args
# Create a simple mock instance without any async methods
mock_api_instance = Mock()
mock_api_instance.run = Mock()
# Create a mock Api class without importing the real one
mock_api = Mock(return_value=mock_api_instance)
# Patch using context manager to avoid importing the real Api class
with patch('trustgraph.gateway.service.Api', mock_api):
# Mock vars() to return a dict
with patch('builtins.vars') as mock_vars:
mock_vars.return_value = {
'metrics': True,
'metrics_port': 8000,
'pulsar_host': default_pulsar_host,
'timeout': default_timeout
}
run()
# Verify metrics server was started
mock_start_http_server.assert_called_once_with(8000)
# Verify Api was created and run was called
mock_api.assert_called_once()
mock_api_instance.run.assert_called_once()
@patch('trustgraph.gateway.service.start_http_server')
@patch('argparse.ArgumentParser.parse_args')
def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
"""Test run function with metrics disabled"""
# Mock command line arguments
mock_args = Mock()
mock_args.metrics = False
mock_parse_args.return_value = mock_args
# Create a simple mock instance without any async methods
mock_api_instance = Mock()
mock_api_instance.run = Mock()
# Patch the Api class inside the test without using decorators
with patch('trustgraph.gateway.service.Api') as mock_api:
mock_api.return_value = mock_api_instance
# Mock vars() to return a dict
with patch('builtins.vars') as mock_vars:
mock_vars.return_value = {
'metrics': False,
'metrics_port': 8000,
'pulsar_host': default_pulsar_host,
'timeout': default_timeout
}
run()
# Verify metrics server was NOT started
mock_start_http_server.assert_not_called()
# Verify Api was created and run was called
mock_api.assert_called_once()
mock_api_instance.run.assert_called_once()
@patch('argparse.ArgumentParser.parse_args')
def test_run_function_argument_parsing(self, mock_parse_args):
"""Test that run function properly parses command line arguments"""
# Mock command line arguments
mock_args = Mock()
mock_args.metrics = False
mock_parse_args.return_value = mock_args
# Create a simple mock instance without any async methods
mock_api_instance = Mock()
mock_api_instance.run = Mock()
# Mock vars() to return a dict with all expected arguments
expected_args = {
'pulsar_host': 'pulsar://test:6650',
'pulsar_api_key': 'test-key',
'pulsar_listener': 'test-listener',
'prometheus_url': 'http://test-prometheus:9090',
'port': 9000,
'timeout': 300,
'api_token': 'secret',
'log_level': 'INFO',
'metrics': False,
'metrics_port': 8001
}
# Patch the Api class inside the test without using decorators
with patch('trustgraph.gateway.service.Api') as mock_api:
mock_api.return_value = mock_api_instance
with patch('builtins.vars') as mock_vars:
mock_vars.return_value = expected_args
run()
# Verify Api was created with the parsed arguments
mock_api.assert_called_once_with(**expected_args)
mock_api_instance.run.assert_called_once()
def test_run_function_creates_argument_parser(self):
"""Test that run function creates argument parser with correct arguments"""
with patch('argparse.ArgumentParser') as mock_parser_class:
mock_parser = Mock()
mock_parser_class.return_value = mock_parser
mock_parser.parse_args.return_value = Mock(metrics=False)
with patch('trustgraph.gateway.service.Api') as mock_api, \
patch('builtins.vars') as mock_vars:
mock_vars.return_value = {'metrics': False}
mock_api.return_value = Mock()
run()
# Verify ArgumentParser was created
mock_parser_class.assert_called_once()
# Verify add_argument was called for each expected argument
expected_arguments = [
'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
'prometheus-url', 'port', 'timeout', 'api-token',
'log-level', 'metrics', 'metrics-port'
]
# Check that add_argument was called multiple times (once for each arg)
assert mock_parser.add_argument.call_count >= len(expected_arguments)
# auth.start must be first (before config receiver, before
# any endpoint starts).
assert order[0] == "auth"
# All three must have run.
assert set(order) == {"auth", "config", "endpoints"}

View file

@ -1,4 +1,15 @@
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
"""Unit tests for SocketEndpoint graceful shutdown functionality.
These tests exercise SocketEndpoint in its handshake-auth
configuration (``in_band_auth=False``) the mode used in production
for the flow import/export streaming endpoints. The mux socket at
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
handshake always accepts and authentication runs on the first
WebSocket frame; that path is covered by the Mux tests.
Every endpoint constructor here passes an explicit capability no
permissive default is relied upon.
"""
import pytest
import asyncio
@ -6,13 +17,31 @@ from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import web, WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.running import Running
from trustgraph.gateway.auth import Identity
# Representative capability used across these tests — corresponds to
# the flow-import streaming endpoint pattern that uses this class.
TEST_CAP = "graph:write"
def _valid_identity(roles=("admin",)):
return Identity(
user_id="test-user",
workspace="default",
roles=list(roles),
source="api-key",
)
@pytest.fixture
def mock_auth():
"""Mock authentication service."""
"""Mock IAM-backed authenticator. Successful by default —
``authenticate`` returns a valid admin identity. Tests that
need the auth failure path override the ``authenticate``
attribute locally."""
auth = MagicMock()
auth.permitted.return_value = True
auth.authenticate = AsyncMock(return_value=_valid_identity())
return auth
@ -25,7 +54,7 @@ def mock_dispatcher_factory():
dispatcher.receive = AsyncMock()
dispatcher.destroy = AsyncMock()
return dispatcher
return dispatcher_factory
@ -35,7 +64,8 @@ def socket_endpoint(mock_auth, mock_dispatcher_factory):
return SocketEndpoint(
endpoint_path="/test-socket",
auth=mock_auth,
dispatcher=mock_dispatcher_factory
dispatcher=mock_dispatcher_factory,
capability=TEST_CAP,
)
@ -61,7 +91,10 @@ def mock_request():
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
"""Test listener handles websocket close gracefully."""
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
socket_endpoint = SocketEndpoint(
"/test", MagicMock(), AsyncMock(),
capability=TEST_CAP,
)
# Mock websocket that closes after one message
ws = AsyncMock()
@ -99,9 +132,9 @@ async def test_listener_graceful_shutdown_on_close():
@pytest.mark.asyncio
async def test_handle_normal_flow():
"""Test normal websocket handling flow."""
"""Valid bearer → handshake accepted, dispatcher created."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
dispatcher_created = False
async def mock_dispatcher_factory(ws, running, match_info):
@ -111,7 +144,10 @@ async def test_handle_normal_flow():
dispatcher.destroy = AsyncMock()
return dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}
@ -155,7 +191,7 @@ async def test_handle_normal_flow():
async def test_handle_exception_group_cleanup():
"""Test exception group triggers dispatcher cleanup."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
@ -163,7 +199,10 @@ async def test_handle_exception_group_cleanup():
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}
@ -222,7 +261,7 @@ async def test_handle_exception_group_cleanup():
async def test_handle_dispatcher_cleanup_timeout():
"""Test dispatcher cleanup with timeout."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
# Mock dispatcher that takes long to destroy
mock_dispatcher = AsyncMock()
@ -231,7 +270,10 @@ async def test_handle_dispatcher_cleanup_timeout():
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}
@ -285,49 +327,67 @@ async def test_handle_dispatcher_cleanup_timeout():
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
"""Test handling of unauthorized requests."""
"""A bearer that the IAM layer rejects causes the handshake to
fail with 401. IamAuth surfaces an HTTPUnauthorized; the
endpoint propagates it. Note that the endpoint intentionally
does NOT distinguish 'bad token', 'expired', 'revoked', etc.
that's the IAM error-masking policy."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = False # Unauthorized
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
text='{"error":"auth failure"}',
content_type="application/json",
))
socket_endpoint = SocketEndpoint(
"/test", mock_auth, AsyncMock(),
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "invalid-token"}
result = await socket_endpoint.handle(request)
# Should return HTTP 401
assert isinstance(result, web.HTTPUnauthorized)
# Should have checked permission
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
# authenticate must have been invoked with a synthetic request
# carrying Bearer <the-token>. The endpoint wraps the query-
# string token into an Authorization header for a uniform auth
# path — the IAM layer does not look at query strings directly.
mock_auth.authenticate.assert_called_once()
passed_req = mock_auth.authenticate.call_args.args[0]
assert passed_req.headers["Authorization"] == "Bearer invalid-token"
@pytest.mark.asyncio
async def test_handle_missing_token():
"""Test handling of requests with missing token."""
"""Request with no ``token`` query param → 401 before any
IAM call is made (cheap short-circuit)."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = False
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
mock_auth.authenticate = AsyncMock(
side_effect=AssertionError(
"authenticate must not be invoked when no token is present"
),
)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, AsyncMock(),
capability=TEST_CAP,
)
request = MagicMock()
request.query = {} # No token
result = await socket_endpoint.handle(request)
# Should return HTTP 401
assert isinstance(result, web.HTTPUnauthorized)
# Should have checked permission with empty token
mock_auth.permitted.assert_called_once_with("", "socket")
mock_auth.authenticate.assert_not_called()
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
"""Test handling when websocket is already closed."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
@ -335,7 +395,10 @@ async def test_handle_websocket_already_closed():
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}