Merge remote-tracking branch 'origin/master' into ts-port

This commit is contained in:
elpresidank 2026-05-01 22:16:51 -05:00
commit 54a6e49bd3
97 changed files with 10771 additions and 1264 deletions

View file

@ -14,13 +14,13 @@ from trustgraph.embeddings.ollama.processor import Processor
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test Ollama dynamic model selection"""
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that Ollama client is initialized with correct host"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@ -36,13 +36,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
mock_client_class.assert_called_once_with(host="http://localhost:11434")
assert processor.default_model == "test-model"
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that on_embeddings uses default model when no model specified"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@ -62,13 +62,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that on_embeddings uses specified model when provided"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@ -88,13 +88,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test switching between multiple models"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@ -118,13 +118,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
assert calls[2][1]['model'] == "model-a"
assert calls[3][1]['model'] == "test-model" # Default
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test that None model parameter falls back to default"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_response = Mock()
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
mock_ollama_client.embed.return_value = mock_response
@ -143,13 +143,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
input=["test text"]
)
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
"""Test initialization without model parameter uses module default"""
# Arrange
mock_ollama_client = Mock()
mock_ollama_client = AsyncMock()
mock_client_class.return_value = mock_ollama_client
mock_async_init.return_value = None
mock_embeddings_init.return_value = None

View file

@ -277,6 +277,60 @@ class TestTripleValidation:
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
assert not is_invalid, "Invalid range should be rejected"
def test_is_valid_triple_subclass_is_accepted(self, extractor, sample_ontology_subset):
"""Domain check passes when actual type is a subclass of expected."""
sample_ontology_subset.classes["Cake"] = {
"uri": "http://purl.org/ontology/fo/Cake",
"type": "owl:Class",
"subclass_of": "Recipe",
}
sample_ontology_subset.object_properties["has_ingredient"] = {
"domain": "Recipe",
"range": "Ingredient",
}
result = extractor.is_valid_triple(
subject="cake:lemon-drizzle",
predicate="has_ingredient",
object_val="ingredient:lemon",
ontology_subset=sample_ontology_subset,
entity_types={"cake:lemon-drizzle": "Cake", "ingredient:lemon": "Ingredient"},
)
assert result is True
def test_is_valid_triple_handles_subclass_cycle_without_infinite_loop(self, extractor, sample_ontology_subset):
"""A cycle in subclass_of must return False instead of hanging."""
sample_ontology_subset.classes["A"] = {"subclass_of": "B"}
sample_ontology_subset.classes["B"] = {"subclass_of": "A"}
sample_ontology_subset.object_properties["p"] = {"domain": "Recipe", "range": "Ingredient"}
result = extractor.is_valid_triple(
subject="entity:x",
predicate="p",
object_val="ingredient:y",
ontology_subset=sample_ontology_subset,
entity_types={"entity:x": "A", "ingredient:y": "Ingredient"},
)
assert result is False
def test_is_valid_triple_entity_types_none_default(self, extractor, sample_ontology_subset):
"""entity_types=None should not raise; domain/range checks skip if type unknown."""
sample_ontology_subset.object_properties["has_ingredient"] = {
"domain": "Recipe",
"range": "Ingredient",
}
result = extractor.is_valid_triple(
subject="recipe:x",
predicate="has_ingredient",
object_val="ingredient:y",
ontology_subset=sample_ontology_subset,
)
assert result is True
class TestTripleParsing:
"""Test suite for parsing triples from LLM responses."""
@ -377,6 +431,24 @@ class TestTripleParsing:
assert triple.p.type == IRI, "Predicate should be IRI type"
assert triple.o.type == LITERAL, "Object literal should be LITERAL type"
def test_parse_and_validate_triples_collects_entity_types_from_rdf_type(self, extractor, sample_ontology_subset):
"""entity_types should be built from rdf:type triples in the same batch."""
sample_ontology_subset.object_properties["has_ingredient"] = {
"domain": "Recipe",
"range": "Ingredient",
}
triples_response = [
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},
{"subject": "ingredient:beef", "predicate": "rdf:type", "object": "Ingredient"},
{"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:beef"},
]
valid_triples = extractor.parse_and_validate_triples(
triples_response, sample_ontology_subset
)
assert len(valid_triples) == 3
class TestURIExpansionInExtraction:
"""Test suite for URI expansion during triple extraction."""

View file

@ -1,69 +1,447 @@
"""
Tests for Gateway Authentication
Tests for gateway/auth.py IamAuth, JWT verification, API key
resolution cache.
JWTs are signed with real Ed25519 keypairs generated per-test, so
the crypto path is exercised end-to-end without mocks. API-key
resolution is tested against a stubbed IamClient since the real
one requires pub/sub.
"""
import base64
import json
import time
from unittest.mock import AsyncMock, Mock, patch
import pytest
from aiohttp import web
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from trustgraph.gateway.auth import Authenticator
from trustgraph.gateway.auth import (
IamAuth, Identity,
_b64url_decode, _verify_jwt_eddsa,
API_KEY_CACHE_TTL,
)
class TestAuthenticator:
"""Test cases for Authenticator class"""
# -- helpers ---------------------------------------------------------------
def test_authenticator_initialization_with_token(self):
"""Test Authenticator initialization with valid token"""
auth = Authenticator(token="test-token-123")
assert auth.token == "test-token-123"
assert auth.allow_all is False
def test_authenticator_initialization_with_allow_all(self):
"""Test Authenticator initialization with allow_all=True"""
auth = Authenticator(allow_all=True)
assert auth.token is None
assert auth.allow_all is True
def _b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def test_authenticator_initialization_without_token_raises_error(self):
"""Test Authenticator initialization without token raises RuntimeError"""
with pytest.raises(RuntimeError, match="Need a token"):
Authenticator()
def test_authenticator_initialization_with_empty_token_raises_error(self):
"""Test Authenticator initialization with empty token raises RuntimeError"""
with pytest.raises(RuntimeError, match="Need a token"):
Authenticator(token="")
def make_keypair():
priv = ed25519.Ed25519PrivateKey.generate()
public_pem = priv.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("ascii")
return priv, public_pem
def test_permitted_with_allow_all_returns_true(self):
"""Test permitted method returns True when allow_all is enabled"""
auth = Authenticator(allow_all=True)
# Should return True regardless of token or roles
assert auth.permitted("any-token", []) is True
assert auth.permitted("different-token", ["admin"]) is True
assert auth.permitted(None, ["user"]) is True
def test_permitted_with_matching_token_returns_true(self):
"""Test permitted method returns True with matching token"""
auth = Authenticator(token="secret-token")
# Should return True when tokens match
assert auth.permitted("secret-token", []) is True
assert auth.permitted("secret-token", ["admin", "user"]) is True
def sign_jwt(priv, claims, alg="EdDSA"):
header = {"alg": alg, "typ": "JWT", "kid": "kid-test"}
h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode())
p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode())
signing_input = f"{h}.{p}".encode("ascii")
if alg == "EdDSA":
sig = priv.sign(signing_input)
else:
raise ValueError(f"test helper doesn't sign {alg}")
return f"{h}.{p}.{_b64url(sig)}"
def test_permitted_with_non_matching_token_returns_false(self):
"""Test permitted method returns False with non-matching token"""
auth = Authenticator(token="secret-token")
# Should return False when tokens don't match
assert auth.permitted("wrong-token", []) is False
assert auth.permitted("different-token", ["admin"]) is False
assert auth.permitted(None, ["user"]) is False
def test_permitted_with_token_and_allow_all_returns_true(self):
"""Test permitted method with both token and allow_all set"""
auth = Authenticator(token="test-token", allow_all=True)
# allow_all should take precedence
assert auth.permitted("any-token", []) is True
assert auth.permitted("wrong-token", ["admin"]) is True
def make_request(auth_header):
"""Minimal stand-in for an aiohttp request — IamAuth only reads
``request.headers["Authorization"]``."""
req = Mock()
req.headers = {}
if auth_header is not None:
req.headers["Authorization"] = auth_header
return req
# -- pure helpers ----------------------------------------------------------
class TestB64UrlDecode:
def test_round_trip_without_padding(self):
data = b"hello"
encoded = _b64url(data)
assert _b64url_decode(encoded) == data
def test_handles_various_lengths(self):
for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"):
assert _b64url_decode(_b64url(s)) == s
# -- JWT verification -----------------------------------------------------
class TestVerifyJwtEddsa:
def test_valid_jwt_passes(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
got = _verify_jwt_eddsa(token, pub)
assert got["sub"] == "user-1"
assert got["workspace"] == "default"
def test_expired_jwt_rejected(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"iat": int(time.time()) - 3600,
"exp": int(time.time()) - 1,
}
token = sign_jwt(priv, claims)
with pytest.raises(ValueError, match="expired"):
_verify_jwt_eddsa(token, pub)
def test_bad_signature_rejected(self):
priv_a, _ = make_keypair()
_, pub_b = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv_a, claims)
# pub_b never signed this token.
with pytest.raises(Exception):
_verify_jwt_eddsa(token, pub_b)
def test_malformed_jwt_rejected(self):
_, pub = make_keypair()
with pytest.raises(ValueError, match="malformed"):
_verify_jwt_eddsa("not-a-jwt", pub)
def test_unsupported_algorithm_rejected(self):
priv, pub = make_keypair()
# Manually build an "alg":"HS256" header — no signer needed
# since we expect it to bail before verifying.
header = {"alg": "HS256", "typ": "JWT", "kid": "x"}
payload = {
"sub": "user-1", "workspace": "default",
"iat": int(time.time()), "exp": int(time.time()) + 60,
}
h = _b64url(json.dumps(header, separators=(",", ":")).encode())
p = _b64url(json.dumps(payload, separators=(",", ":")).encode())
sig = _b64url(b"not-a-real-sig")
token = f"{h}.{p}.{sig}"
with pytest.raises(ValueError, match="unsupported alg"):
_verify_jwt_eddsa(token, pub)
# -- Identity --------------------------------------------------------------
class TestIdentity:
def test_fields(self):
i = Identity(
handle="u", workspace="w",
principal_id="u", source="api-key",
)
assert i.handle == "u"
assert i.workspace == "w"
assert i.principal_id == "u"
assert i.source == "api-key"
# -- IamAuth.authenticate --------------------------------------------------
class TestIamAuthDispatch:
"""``authenticate()`` chooses between the JWT and API-key paths
by shape of the bearer."""
@pytest.mark.asyncio
async def test_no_authorization_header_raises_401(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(None))
@pytest.mark.asyncio
async def test_non_bearer_header_raises_401(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Basic whatever"))
@pytest.mark.asyncio
async def test_empty_bearer_raises_401(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer "))
@pytest.mark.asyncio
async def test_unknown_format_raises_401(self):
# Not tg_... and not dotted-JWT shape.
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer garbage"))
@pytest.mark.asyncio
async def test_valid_jwt_resolves_to_identity(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
auth = IamAuth(backend=Mock())
auth._signing_public_pem = pub
ident = await auth.authenticate(
make_request(f"Bearer {token}")
)
assert ident.handle == "user-1"
assert ident.workspace == "default"
assert ident.principal_id == "user-1"
assert ident.source == "jwt"
@pytest.mark.asyncio
async def test_jwt_without_public_key_fails(self):
# If the gateway hasn't fetched IAM's public key yet, JWTs
# must not validate — even ones that would otherwise pass.
priv, _ = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"iat": int(time.time()), "exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
auth = IamAuth(backend=Mock())
# _signing_public_pem defaults to None
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(f"Bearer {token}"))
@pytest.mark.asyncio
async def test_api_key_path(self):
auth = IamAuth(backend=Mock())
async def fake_resolve(api_key):
assert api_key == "tg_testkey"
# Roles are returned by the regime as a hint but the
# gateway ignores them — kept here so the resolve
# protocol shape is exercised.
return ("user-xyz", "default", ["admin"])
async def fake_with_client(op):
return await op(Mock(resolve_api_key=fake_resolve))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = await auth.authenticate(
make_request("Bearer tg_testkey")
)
assert ident.handle == "user-xyz"
assert ident.workspace == "default"
assert ident.principal_id == "user-xyz"
assert ident.source == "api-key"
@pytest.mark.asyncio
async def test_api_key_rejection_masked_as_401(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
raise RuntimeError("auth-failed: unknown api key")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(
make_request("Bearer tg_bogus")
)
# -- API key cache ---------------------------------------------------------
class TestApiKeyCache:
@pytest.mark.asyncio
async def test_cache_hit_skips_iam(self):
auth = IamAuth(backend=Mock())
calls = {"n": 0}
async def fake_with_client(op):
calls["n"] += 1
return await op(Mock(
resolve_api_key=AsyncMock(
return_value=("u", "default", ["reader"]),
)
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
await auth.authenticate(make_request("Bearer tg_k1"))
await auth.authenticate(make_request("Bearer tg_k1"))
await auth.authenticate(make_request("Bearer tg_k1"))
# Only the first lookup reaches IAM; the rest are cache hits.
assert calls["n"] == 1
@pytest.mark.asyncio
async def test_different_keys_are_separately_cached(self):
auth = IamAuth(backend=Mock())
seen = []
async def fake_with_client(op):
async def resolve(plaintext):
seen.append(plaintext)
return ("u-" + plaintext, "default", ["reader"])
return await op(Mock(resolve_api_key=resolve))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
a = await auth.authenticate(make_request("Bearer tg_a"))
b = await auth.authenticate(make_request("Bearer tg_b"))
assert a.handle == "u-tg_a"
assert b.handle == "u-tg_b"
assert seen == ["tg_a", "tg_b"]
@pytest.mark.asyncio
async def test_cache_has_ttl_constant_set(self):
# Not a behaviour test — just ensures we don't accidentally
# set TTL to 0 (which would defeat the cache) or to a week.
assert 10 <= API_KEY_CACHE_TTL <= 3600
# -- IamAuth.authorise -----------------------------------------------------
class TestAuthorise:
"""``authorise()`` is the gateway's only authorisation entry
point under the IAM contract. It calls iam-svc, caches the
decision for the regime's TTL (clamped above), and raises 403
on deny / 401 on regime error (fail closed)."""
def _make_identity(self, handle="u-1", workspace="default"):
return Identity(
handle=handle, workspace=workspace,
principal_id=handle, source="api-key",
)
@pytest.mark.asyncio
async def test_allow_returns_no_exception(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
return await op(Mock(
authorise=AsyncMock(return_value=(True, 30)),
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
await auth.authorise(
self._make_identity(),
"graph:read",
{"workspace": "default"},
{},
)
@pytest.mark.asyncio
async def test_deny_raises_403(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
return await op(Mock(
authorise=AsyncMock(return_value=(False, 30)),
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPForbidden):
await auth.authorise(
self._make_identity(),
"users:admin",
{},
{"workspace": "acme"},
)
@pytest.mark.asyncio
async def test_regime_error_fails_closed_as_401(self):
# If iam-svc errors, the gateway must NOT silently allow.
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
raise RuntimeError("iam-svc down")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authorise(
self._make_identity(),
"graph:read",
{"workspace": "default"},
{},
)
@pytest.mark.asyncio
async def test_allow_decision_is_cached(self):
auth = IamAuth(backend=Mock())
calls = {"n": 0}
async def fake_with_client(op):
calls["n"] += 1
return await op(Mock(
authorise=AsyncMock(return_value=(True, 30)),
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = self._make_identity()
for _ in range(5):
await auth.authorise(
ident, "graph:read", {"workspace": "default"}, {},
)
assert calls["n"] == 1
@pytest.mark.asyncio
async def test_deny_decision_is_cached(self):
auth = IamAuth(backend=Mock())
calls = {"n": 0}
async def fake_with_client(op):
calls["n"] += 1
return await op(Mock(
authorise=AsyncMock(return_value=(False, 30)),
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = self._make_identity()
for _ in range(5):
with pytest.raises(web.HTTPForbidden):
await auth.authorise(
ident, "users:admin", {}, {"workspace": "acme"},
)
# Denies are cached too — repeated attempts don't re-hit IAM.
assert calls["n"] == 1
@pytest.mark.asyncio
async def test_different_resources_cached_separately(self):
auth = IamAuth(backend=Mock())
calls = {"n": 0}
async def fake_with_client(op):
calls["n"] += 1
return await op(Mock(
authorise=AsyncMock(return_value=(True, 30)),
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = self._make_identity()
await auth.authorise(
ident, "graph:read", {"workspace": "a"}, {},
)
await auth.authorise(
ident, "graph:read", {"workspace": "b"}, {},
)
# Different resource → different cache key → two IAM calls.
assert calls["n"] == 2

View file

@ -0,0 +1,171 @@
"""
Tests for gateway/capabilities.py the thin authorisation surface
under the IAM contract.
The gateway no longer holds policy state (roles, capability sets,
workspace scopes); those live in iam-svc. These tests cover only
what the gateway shim does itself: PUBLIC / AUTHENTICATED short-
circuiting, default-fill of workspace, and forwarding of capability
checks to ``auth.authorise``.
"""
import pytest
from aiohttp import web
from unittest.mock import AsyncMock, MagicMock
from trustgraph.gateway.capabilities import (
PUBLIC, AUTHENTICATED,
enforce, enforce_workspace,
access_denied, auth_failure,
)
# -- test fixtures ---------------------------------------------------------
class _Identity:
"""Stand-in for auth.Identity — under the IAM contract it has
just ``handle``, ``workspace``, ``principal_id``, ``source``."""
def __init__(self, handle="user-1", workspace="default"):
self.handle = handle
self.workspace = workspace
self.principal_id = handle
self.source = "api-key"
def _allow_auth(identity=None):
"""Build an Auth double that authenticates to ``identity`` and
allows every authorise() call."""
auth = MagicMock()
auth.authenticate = AsyncMock(
return_value=identity or _Identity(),
)
auth.authorise = AsyncMock(return_value=None)
return auth
def _deny_auth(identity=None):
"""Build an Auth double that authenticates but denies authorise."""
auth = MagicMock()
auth.authenticate = AsyncMock(
return_value=identity or _Identity(),
)
auth.authorise = AsyncMock(side_effect=access_denied())
return auth
# -- enforce() -------------------------------------------------------------
class TestEnforce:
@pytest.mark.asyncio
async def test_public_returns_none_no_auth(self):
auth = _allow_auth()
result = await enforce(MagicMock(), auth, PUBLIC)
assert result is None
auth.authenticate.assert_not_called()
auth.authorise.assert_not_called()
@pytest.mark.asyncio
async def test_authenticated_skips_authorise(self):
identity = _Identity()
auth = _allow_auth(identity)
result = await enforce(MagicMock(), auth, AUTHENTICATED)
assert result is identity
auth.authenticate.assert_awaited_once()
auth.authorise.assert_not_called()
@pytest.mark.asyncio
async def test_capability_calls_authorise_system_level(self):
identity = _Identity()
auth = _allow_auth(identity)
result = await enforce(MagicMock(), auth, "graph:read")
assert result is identity
auth.authorise.assert_awaited_once_with(
identity, "graph:read", {}, {},
)
@pytest.mark.asyncio
async def test_capability_denied_raises_forbidden(self):
auth = _deny_auth()
with pytest.raises(web.HTTPForbidden):
await enforce(MagicMock(), auth, "users:admin")
# -- enforce_workspace() ---------------------------------------------------
class TestEnforceWorkspace:
@pytest.mark.asyncio
async def test_default_fills_from_identity(self):
data = {"operation": "x"}
auth = _allow_auth()
await enforce_workspace(data, _Identity(workspace="default"), auth)
assert data["workspace"] == "default"
@pytest.mark.asyncio
async def test_caller_supplied_workspace_kept(self):
data = {"workspace": "acme", "operation": "x"}
auth = _allow_auth()
await enforce_workspace(data, _Identity(workspace="default"), auth)
assert data["workspace"] == "acme"
@pytest.mark.asyncio
async def test_no_capability_skips_authorise(self):
data = {"workspace": "default"}
auth = _allow_auth()
await enforce_workspace(data, _Identity(), auth)
auth.authorise.assert_not_called()
@pytest.mark.asyncio
async def test_capability_calls_authorise_with_resource(self):
data = {"workspace": "acme"}
identity = _Identity()
auth = _allow_auth(identity)
await enforce_workspace(
data, identity, auth, capability="graph:read",
)
auth.authorise.assert_awaited_once_with(
identity, "graph:read", {"workspace": "acme"}, {},
)
@pytest.mark.asyncio
async def test_capability_denied_propagates(self):
data = {"workspace": "acme"}
auth = _deny_auth()
with pytest.raises(web.HTTPForbidden):
await enforce_workspace(
data, _Identity(), auth, capability="users:admin",
)
@pytest.mark.asyncio
async def test_non_dict_passthrough(self):
auth = _allow_auth()
result = await enforce_workspace("not-a-dict", _Identity(), auth)
assert result == "not-a-dict"
auth.authorise.assert_not_called()
# -- helpers ---------------------------------------------------------------
class TestResponseHelpers:
def test_auth_failure_is_401(self):
exc = auth_failure()
assert exc.status == 401
assert "auth failure" in exc.text
def test_access_denied_is_403(self):
exc = access_denied()
assert exc.status == 403
assert "access denied" in exc.text
class TestSentinels:
def test_public_and_authenticated_are_distinct(self):
assert PUBLIC != AUTHENTICATED

View file

@ -42,7 +42,7 @@ class TestDispatcherManager:
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
assert manager.backend == mock_backend
assert manager.config_receiver == mock_config_receiver
@ -59,7 +59,10 @@ class TestDispatcherManager:
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")
manager = DispatcherManager(
mock_backend, mock_config_receiver,
auth=Mock(), prefix="custom-prefix",
)
assert manager.prefix == "custom-prefix"
@ -68,7 +71,7 @@ class TestDispatcherManager:
"""Test start_flow method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
flow_data = {"name": "test_flow", "steps": []}
@ -82,7 +85,7 @@ class TestDispatcherManager:
"""Test stop_flow method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
@ -96,7 +99,7 @@ class TestDispatcherManager:
"""Test dispatch_global_service returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_global_service()
@ -107,7 +110,7 @@ class TestDispatcherManager:
"""Test dispatch_core_export returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_core_export()
@ -118,7 +121,7 @@ class TestDispatcherManager:
"""Test dispatch_core_import returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_core_import()
@ -130,7 +133,7 @@ class TestDispatcherManager:
"""Test process_core_import method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
mock_importer = Mock()
@ -148,7 +151,7 @@ class TestDispatcherManager:
"""Test process_core_export method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
mock_exporter = Mock()
@ -166,7 +169,7 @@ class TestDispatcherManager:
"""Test process_global_service method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
manager.invoke_global_service = AsyncMock(return_value="global_result")
@ -181,7 +184,7 @@ class TestDispatcherManager:
"""Test invoke_global_service with existing dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
@ -198,7 +201,7 @@ class TestDispatcherManager:
"""Test invoke_global_service creates new dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
@ -230,7 +233,7 @@ class TestDispatcherManager:
"""Test dispatch_flow_import returns correct method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
result = manager.dispatch_flow_import()
@ -240,7 +243,7 @@ class TestDispatcherManager:
"""Test dispatch_flow_export returns correct method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
result = manager.dispatch_flow_export()
@ -250,7 +253,7 @@ class TestDispatcherManager:
"""Test dispatch_socket returns correct method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
result = manager.dispatch_socket()
@ -260,7 +263,7 @@ class TestDispatcherManager:
"""Test dispatch_flow_service returns DispatcherWrapper"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
wrapper = manager.dispatch_flow_service()
@ -272,7 +275,7 @@ class TestDispatcherManager:
"""Test process_flow_import with valid flow and kind"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -308,7 +311,7 @@ class TestDispatcherManager:
"""Test process_flow_import with invalid flow"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
params = {"flow": "invalid_flow", "kind": "triples"}
@ -323,7 +326,7 @@ class TestDispatcherManager:
warnings.simplefilter("ignore", RuntimeWarning)
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -345,7 +348,7 @@ class TestDispatcherManager:
"""Test process_flow_export with valid flow and kind"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -378,26 +381,47 @@ class TestDispatcherManager:
@pytest.mark.asyncio
async def test_process_socket(self):
"""Test process_socket method"""
"""process_socket constructs a Mux with the manager's auth
instance passed through this is the gateway's trust path
for first-frame WebSocket authentication. A Mux cannot be
built without auth (tested separately); this test pins that
the dispatcher-manager threads the correct auth value into
the Mux constructor call."""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
mock_auth = Mock()
manager = DispatcherManager(
mock_backend, mock_config_receiver, auth=mock_auth,
)
with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
mock_mux_instance = Mock()
mock_mux.return_value = mock_mux_instance
result = await manager.process_socket("ws", "running", {})
mock_mux.assert_called_once_with(manager, "ws", "running")
mock_mux.assert_called_once_with(
manager, "ws", "running", auth=mock_auth,
)
assert result == mock_mux_instance
def test_dispatcher_manager_requires_auth(self):
"""Constructing a DispatcherManager without an auth argument
must fail a no-auth DispatcherManager would produce a
Mux without authentication, silently downgrading the socket
auth path."""
mock_backend = Mock()
mock_config_receiver = Mock()
with pytest.raises(ValueError, match="auth"):
DispatcherManager(mock_backend, mock_config_receiver, auth=None)
@pytest.mark.asyncio
async def test_process_flow_service(self):
"""Test process_flow_service method"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
manager.invoke_flow_service = AsyncMock(return_value="flow_result")
@ -412,7 +436,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with existing dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Add flow to the flows dictionary
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
@ -432,7 +456,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service creates request-response dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -476,7 +500,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service creates sender dispatcher"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow
manager.flows[("default", "test_flow")] = {
@ -516,7 +540,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with invalid flow"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@ -526,7 +550,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with kind not supported by flow"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow without agent interface
manager.flows[("default", "test_flow")] = {
@ -543,7 +567,7 @@ class TestDispatcherManager:
"""Test invoke_flow_service with invalid kind"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Setup test flow with interface but unsupported kind
manager.flows[("default", "test_flow")] = {
@ -570,7 +594,7 @@ class TestDispatcherManager:
"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
async def slow_start():
# Yield to the event loop so other coroutines get a chance to run,
@ -606,7 +630,7 @@ class TestDispatcherManager:
"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
manager.flows[("default", "test_flow")] = {
"interfaces": {

View file

@ -12,6 +12,19 @@ from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE
class TestMux:
"""Test cases for Mux class"""
def test_mux_requires_auth(self):
"""Constructing a Mux without an ``auth`` argument must
fail. The Mux implements the first-frame auth protocol and
there is no no-auth mode a no-auth Mux would silently
accept every frame without authenticating it."""
with pytest.raises(ValueError, match="auth"):
Mux(
dispatcher_manager=MagicMock(),
ws=MagicMock(),
running=MagicMock(),
auth=None,
)
def test_mux_initialization(self):
"""Test Mux initialization"""
mock_dispatcher_manager = MagicMock()
@ -21,7 +34,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
assert mux.dispatcher_manager == mock_dispatcher_manager
@ -40,7 +54,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Call destroy
@ -61,7 +76,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=None,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Call destroy
@ -81,7 +97,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message with valid JSON
@ -108,7 +125,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message without request field
@ -137,7 +155,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message without id field
@ -164,7 +183,8 @@ class TestMux:
mux = Mux(
dispatcher_manager=mock_dispatcher_manager,
ws=mock_ws,
running=mock_running
running=mock_running,
auth=MagicMock(),
)
# Mock message with invalid JSON

View file

@ -13,29 +13,36 @@ class TestConstantEndpoint:
"""Test cases for ConstantEndpoint class"""
def test_constant_endpoint_initialization(self):
"""Test ConstantEndpoint initialization"""
"""Construction records the configured capability on the
instance. The capability is a required argument no
permissive default and the test passes an explicit
value to demonstrate the contract."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = ConstantEndpoint(
endpoint_path="/api/test",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability="config:read",
)
assert endpoint.path == "/api/test"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "service"
assert endpoint.capability == "config:read"
@pytest.mark.asyncio
async def test_constant_endpoint_start_method(self):
"""Test ConstantEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
endpoint = ConstantEndpoint(
"/api/test", mock_auth, mock_dispatcher,
capability="config:read",
)
# start() should complete without error
await endpoint.start()
@ -44,10 +51,13 @@ class TestConstantEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
endpoint = ConstantEndpoint(
"/api/test", mock_auth, mock_dispatcher,
capability="config:read",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with POST route
mock_app.add_routes.assert_called_once()
# The call should include web.post with the path and handler

View file

@ -1,4 +1,12 @@
"""Tests for Gateway i18n pack endpoint."""
"""Tests for Gateway i18n pack endpoint.
Production registers this endpoint with ``capability=PUBLIC``: the
login UI needs to render its own i18n strings before any user has
authenticated, so the endpoint is deliberately pre-auth. These
tests exercise the PUBLIC configuration that is the production
contract. Behaviour of authenticated endpoints is covered by the
IamAuth tests in ``test_auth.py``.
"""
import json
from unittest.mock import MagicMock
@ -7,6 +15,7 @@ import pytest
from aiohttp import web
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
from trustgraph.gateway.capabilities import PUBLIC
class TestI18nPackEndpoint:
@ -17,23 +26,28 @@ class TestI18nPackEndpoint:
endpoint = I18nPackEndpoint(
endpoint_path="/api/v1/i18n/packs/{lang}",
auth=mock_auth,
capability=PUBLIC,
)
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
assert endpoint.auth == mock_auth
assert endpoint.operation == "service"
assert endpoint.capability == PUBLIC
@pytest.mark.asyncio
async def test_i18n_endpoint_start_method(self):
mock_auth = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
await endpoint.start()
def test_add_routes_registers_get_handler(self):
mock_auth = MagicMock()
mock_app = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
endpoint.add_routes(mock_app)
mock_app.add_routes.assert_called_once()
@ -41,35 +55,55 @@ class TestI18nPackEndpoint:
assert len(call_args) == 1
@pytest.mark.asyncio
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
async def test_handle_returns_pack_without_authenticating(self):
"""The PUBLIC endpoint serves the language pack without
invoking the auth handler at all pre-login UI must be
reachable. The test uses an auth mock that raises if
touched, so any auth attempt by the endpoint is caught."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
def _should_not_be_called(*args, **kwargs):
raise AssertionError(
"PUBLIC endpoint must not invoke auth.authenticate"
)
mock_auth.authenticate = _should_not_be_called
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
# A caller-supplied Authorization header of any form should
# be ignored — PUBLIC means we don't look at it.
request.headers = {"Authorization": "Token abc"}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert isinstance(resp, web.HTTPUnauthorized)
@pytest.mark.asyncio
async def test_handle_returns_pack_when_permitted(self):
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
request.headers = {}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert resp.status == 200
payload = json.loads(resp.body.decode("utf-8"))
assert isinstance(payload, dict)
assert "cli.verify_system_status.title" in payload
@pytest.mark.asyncio
async def test_handle_rejects_path_traversal(self):
"""The ``lang`` path parameter is reflected through to the
filesystem-backed pack loader. The endpoint contains an
explicit defense against ``/`` and ``..`` in the value; this
test pins that defense in place."""
mock_auth = MagicMock()
endpoint = I18nPackEndpoint(
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
)
for bad in ("../../etc/passwd", "en/../fr", "a/b"):
request = MagicMock()
request.path = f"/api/v1/i18n/packs/{bad}"
request.headers = {}
request.match_info = {"lang": bad}
resp = await endpoint.handle(request)
assert isinstance(resp, web.HTTPBadRequest), (
f"path-traversal defense did not reject lang={bad!r}"
)

View file

@ -12,30 +12,24 @@ class TestEndpointManager:
"""Test cases for EndpointManager class"""
def test_endpoint_manager_initialization(self):
"""Test EndpointManager initialization creates all endpoints"""
"""EndpointManager wires up the full endpoint set and
records dispatcher_manager / timeout on the instance."""
mock_dispatcher_manager = MagicMock()
mock_auth = MagicMock()
# Mock dispatcher methods
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
# The dispatcher_manager exposes a small set of factory
# methods — MagicMock auto-creates them, returning fresh
# MagicMocks on each call.
manager = EndpointManager(
dispatcher_manager=mock_dispatcher_manager,
auth=mock_auth,
prometheus_url="http://prometheus:9090",
timeout=300
timeout=300,
)
assert manager.dispatcher_manager == mock_dispatcher_manager
assert manager.timeout == 300
assert manager.services == {}
assert len(manager.endpoints) > 0 # Should have multiple endpoints
assert len(manager.endpoints) > 0
def test_endpoint_manager_with_default_timeout(self):
"""Test EndpointManager with default timeout value"""
@ -79,9 +73,17 @@ class TestEndpointManager:
prometheus_url="http://test:9090"
)
# Verify all dispatcher methods were called during initialization
# Each dispatcher factory is invoked once per endpoint that
# needs a dedicated wire. dispatch_auth_iam is shared by
# two endpoints — AuthEndpoints (login / bootstrap /
# change-password) and IamEndpoint (registry-driven
# /api/v1/iam) — so it's expected to be called twice.
# Both forwarders pin the dispatcher to kind=iam and reuse
# the same factory; they're distinct from
# dispatch_global_service (the generic /api/v1/{kind} route).
mock_dispatcher_manager.dispatch_global_service.assert_called_once()
mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice
assert mock_dispatcher_manager.dispatch_auth_iam.call_count == 2
mock_dispatcher_manager.dispatch_socket.assert_called_once()
mock_dispatcher_manager.dispatch_flow_service.assert_called_once()
mock_dispatcher_manager.dispatch_flow_import.assert_called_once()
mock_dispatcher_manager.dispatch_flow_export.assert_called_once()

View file

@ -12,31 +12,35 @@ class TestMetricsEndpoint:
"""Test cases for MetricsEndpoint class"""
def test_metrics_endpoint_initialization(self):
"""Test MetricsEndpoint initialization"""
"""Construction records the configured capability on the
instance. In production MetricsEndpoint is gated by
'metrics:read' so that's the natural value to pass."""
mock_auth = MagicMock()
endpoint = MetricsEndpoint(
prometheus_url="http://prometheus:9090",
endpoint_path="/metrics",
auth=mock_auth
auth=mock_auth,
capability="metrics:read",
)
assert endpoint.prometheus_url == "http://prometheus:9090"
assert endpoint.path == "/metrics"
assert endpoint.auth == mock_auth
assert endpoint.operation == "service"
assert endpoint.capability == "metrics:read"
@pytest.mark.asyncio
async def test_metrics_endpoint_start_method(self):
"""Test MetricsEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
endpoint = MetricsEndpoint(
prometheus_url="http://localhost:9090",
endpoint_path="/metrics",
auth=mock_auth
auth=mock_auth,
capability="metrics:read",
)
# start() should complete without error
await endpoint.start()
@ -44,15 +48,16 @@ class TestMetricsEndpoint:
"""Test add_routes method registers GET route with wildcard path"""
mock_auth = MagicMock()
mock_app = MagicMock()
endpoint = MetricsEndpoint(
prometheus_url="http://prometheus:9090",
endpoint_path="/metrics",
auth=mock_auth
auth=mock_auth,
capability="metrics:read",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with GET route
mock_app.add_routes.assert_called_once()
# The call should include web.get with wildcard path pattern

View file

@ -1,5 +1,12 @@
"""
Tests for Gateway Socket Endpoint
Tests for Gateway Socket Endpoint.
In production the only SocketEndpoint registered with HTTP-layer
auth is ``/api/v1/socket`` using ``capability=AUTHENTICATED`` with
``in_band_auth=True`` (first-frame auth over the websocket frames,
not at the handshake). The tests below use AUTHENTICATED as the
representative capability; construction / worker / listener
behaviour is independent of which capability is configured.
"""
import pytest
@ -7,41 +14,47 @@ from unittest.mock import MagicMock, AsyncMock
from aiohttp import WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.capabilities import AUTHENTICATED
class TestSocketEndpoint:
"""Test cases for SocketEndpoint class"""
def test_socket_endpoint_initialization(self):
"""Test SocketEndpoint initialization"""
"""Construction records the configured capability on the
instance. No permissive default is applied."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = SocketEndpoint(
endpoint_path="/api/socket",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability=AUTHENTICATED,
)
assert endpoint.path == "/api/socket"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "socket"
assert endpoint.capability == AUTHENTICATED
@pytest.mark.asyncio
async def test_worker_method(self):
"""Test SocketEndpoint worker method"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
mock_ws = MagicMock()
mock_running = MagicMock()
# Call worker method
await endpoint.worker(mock_ws, mock_dispatcher, mock_running)
# Verify dispatcher.run was called
mock_dispatcher.run.assert_called_once()
@ -50,8 +63,11 @@ class TestSocketEndpoint:
"""Test SocketEndpoint listener method with text message"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
# Mock websocket with text message
mock_msg = MagicMock()
@ -80,8 +96,11 @@ class TestSocketEndpoint:
"""Test SocketEndpoint listener method with binary message"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
# Mock websocket with binary message
mock_msg = MagicMock()
@ -110,8 +129,11 @@ class TestSocketEndpoint:
"""Test SocketEndpoint listener method with close message"""
mock_auth = MagicMock()
mock_dispatcher = AsyncMock()
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
endpoint = SocketEndpoint(
"/api/socket", mock_auth, mock_dispatcher,
capability=AUTHENTICATED,
)
# Mock websocket with close message
mock_msg = MagicMock()

View file

@ -12,48 +12,57 @@ class TestStreamEndpoint:
"""Test cases for StreamEndpoint class"""
def test_stream_endpoint_initialization_with_post(self):
"""Test StreamEndpoint initialization with POST method"""
"""Construction records the configured capability on the
instance. StreamEndpoint is used in production for the
core-import / core-export / document-stream routes; a
document-write capability is a realistic value for a POST
stream (e.g. core-import)."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="POST"
capability="documents:write",
method="POST",
)
assert endpoint.path == "/api/stream"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "service"
assert endpoint.capability == "documents:write"
assert endpoint.method == "POST"
def test_stream_endpoint_initialization_with_get(self):
"""Test StreamEndpoint initialization with GET method"""
"""GET stream — export-style endpoint, read capability."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="GET"
capability="documents:read",
method="GET",
)
assert endpoint.method == "GET"
def test_stream_endpoint_initialization_default_method(self):
"""Test StreamEndpoint initialization with default POST method"""
"""Test StreamEndpoint initialization with default POST method.
The method default is cosmetic; the capability is not
defaulted it is always required."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability="documents:write",
)
assert endpoint.method == "POST" # Default value
@pytest.mark.asyncio
@ -61,9 +70,12 @@ class TestStreamEndpoint:
"""Test StreamEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
endpoint = StreamEndpoint(
"/api/stream", mock_auth, mock_dispatcher,
capability="documents:write",
)
# start() should complete without error
await endpoint.start()
@ -72,16 +84,17 @@ class TestStreamEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="POST"
capability="documents:write",
method="POST",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with POST route
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]
@ -92,16 +105,17 @@ class TestStreamEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="GET"
capability="documents:read",
method="GET",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with GET route
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]
@ -112,13 +126,14 @@ class TestStreamEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = StreamEndpoint(
endpoint_path="/api/stream",
auth=mock_auth,
dispatcher=mock_dispatcher,
method="INVALID"
capability="documents:write",
method="INVALID",
)
with pytest.raises(RuntimeError, match="Bad method"):
endpoint.add_routes(mock_app)

View file

@ -12,29 +12,36 @@ class TestVariableEndpoint:
"""Test cases for VariableEndpoint class"""
def test_variable_endpoint_initialization(self):
"""Test VariableEndpoint initialization"""
"""Construction records the configured capability on the
instance. VariableEndpoint is used in production for the
/api/v1/{kind} admin-scoped global service routes, so a
write-side capability is a realistic value for the test."""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = VariableEndpoint(
endpoint_path="/api/variable",
auth=mock_auth,
dispatcher=mock_dispatcher
dispatcher=mock_dispatcher,
capability="config:write",
)
assert endpoint.path == "/api/variable"
assert endpoint.auth == mock_auth
assert endpoint.dispatcher == mock_dispatcher
assert endpoint.operation == "service"
assert endpoint.capability == "config:write"
@pytest.mark.asyncio
async def test_variable_endpoint_start_method(self):
"""Test VariableEndpoint start method (should be no-op)"""
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher)
endpoint = VariableEndpoint(
"/api/var", mock_auth, mock_dispatcher,
capability="config:write",
)
# start() should complete without error
await endpoint.start()
@ -43,10 +50,13 @@ class TestVariableEndpoint:
mock_auth = MagicMock()
mock_dispatcher = MagicMock()
mock_app = MagicMock()
endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher)
endpoint = VariableEndpoint(
"/api/variable", mock_auth, mock_dispatcher,
capability="config:write",
)
endpoint.add_routes(mock_app)
# Verify add_routes was called with POST route
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]

View file

@ -1,355 +1,179 @@
"""
Tests for Gateway Service API
Tests for gateway/service.py the Api class that wires together
the pub/sub backend, IAM auth, config receiver, dispatcher manager,
and endpoint manager.
The legacy ``GATEWAY_SECRET`` / ``default_api_token`` / allow-all
surface is gone, so the tests here focus on the Api's construction
and composition rather than the removed auth behaviour. IamAuth's
own behaviour is covered in test_auth.py.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch, MagicMock, AsyncMock
from unittest.mock import AsyncMock, Mock, patch
from aiohttp import web
import pulsar
from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token
# Tests for Gateway Service API
from trustgraph.gateway.service import (
Api,
default_pulsar_host, default_prometheus_url,
default_timeout, default_port,
)
from trustgraph.gateway.auth import IamAuth
class TestApi:
"""Test cases for Api class"""
# -- constants -------------------------------------------------------------
def test_api_initialization_with_defaults(self):
"""Test Api initialization with default values"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_backend = Mock()
mock_get_pubsub.return_value = mock_backend
api = Api()
class TestDefaults:
assert api.port == default_port
assert api.timeout == default_timeout
assert api.pulsar_host == default_pulsar_host
assert api.pulsar_api_key is None
assert api.prometheus_url == default_prometheus_url + "/"
assert api.auth.allow_all is True
def test_exports_default_constants(self):
# These are consumed by CLIs / tests / docs. Sanity-check
# that they're the expected shape.
assert default_port == 8088
assert default_timeout == 600
assert default_pulsar_host.startswith("pulsar://")
assert default_prometheus_url.startswith("http")
# Verify get_pubsub was called
mock_get_pubsub.assert_called_once()
def test_api_initialization_with_custom_config(self):
"""Test Api initialization with custom configuration"""
# -- Api construction ------------------------------------------------------
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def api(mock_backend):
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
yield Api()
class TestApiConstruction:
def test_defaults(self, api):
assert api.port == default_port
assert api.timeout == default_timeout
assert api.pulsar_host == default_pulsar_host
assert api.pulsar_api_key is None
# prometheus_url gets normalised with a trailing slash
assert api.prometheus_url == default_prometheus_url + "/"
def test_auth_is_iam_backed(self, api):
# Any Api always gets an IamAuth. There is no "no auth" mode
# (GATEWAY_SECRET / allow_all has been removed — see IAM spec).
assert isinstance(api.auth, IamAuth)
def test_components_wired(self, api):
assert api.config_receiver is not None
assert api.dispatcher_manager is not None
assert api.endpoint_manager is not None
def test_dispatcher_manager_has_auth(self, api):
# The Mux uses this handle for first-frame socket auth.
assert api.dispatcher_manager.auth is api.auth
def test_custom_config(self, mock_backend):
config = {
"port": 9000,
"timeout": 300,
"pulsar_host": "pulsar://custom-host:6650",
"pulsar_api_key": "test-api-key",
"pulsar_listener": "custom-listener",
"pulsar_api_key": "custom-key",
"prometheus_url": "http://custom-prometheus:9090",
"api_token": "secret-token"
}
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
a = Api(**config)
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_backend = Mock()
mock_get_pubsub.return_value = mock_backend
assert a.port == 9000
assert a.timeout == 300
assert a.pulsar_host == "pulsar://custom-host:6650"
assert a.pulsar_api_key == "custom-key"
# Trailing slash added.
assert a.prometheus_url == "http://custom-prometheus:9090/"
api = Api(**config)
def test_prometheus_url_already_has_trailing_slash(self, mock_backend):
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
a = Api(prometheus_url="http://p:9090/")
assert a.prometheus_url == "http://p:9090/"
assert api.port == 9000
assert api.timeout == 300
assert api.pulsar_host == "pulsar://custom-host:6650"
assert api.pulsar_api_key == "test-api-key"
assert api.prometheus_url == "http://custom-prometheus:9090/"
assert api.auth.token == "secret-token"
assert api.auth.allow_all is False
def test_queue_overrides_parsed_for_config(self, mock_backend):
with patch(
"trustgraph.gateway.service.get_pubsub",
return_value=mock_backend,
):
a = Api(
config_request_queue="alt-config-req",
config_response_queue="alt-config-resp",
)
overrides = a.dispatcher_manager.queue_overrides
assert overrides.get("config", {}).get("request") == "alt-config-req"
assert overrides.get("config", {}).get("response") == "alt-config-resp"
# Verify get_pubsub was called with config
mock_get_pubsub.assert_called_once_with(**config)
def test_api_initialization_with_pulsar_api_key(self):
"""Test Api initialization with Pulsar API key authentication"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
# -- app_factory -----------------------------------------------------------
api = Api(pulsar_api_key="test-key")
# Verify api key was stored
assert api.pulsar_api_key == "test-key"
mock_get_pubsub.assert_called_once()
def test_api_initialization_prometheus_url_normalization(self):
"""Test that prometheus_url gets normalized with trailing slash"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
# Test URL without trailing slash
api = Api(prometheus_url="http://prometheus:9090")
assert api.prometheus_url == "http://prometheus:9090/"
# Test URL with trailing slash
api = Api(prometheus_url="http://prometheus:9090/")
assert api.prometheus_url == "http://prometheus:9090/"
def test_api_initialization_empty_api_token_means_no_auth(self):
"""Test that empty API token results in allow_all authentication"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api(api_token="")
assert api.auth.allow_all is True
def test_api_initialization_none_api_token_means_no_auth(self):
"""Test that None API token results in allow_all authentication"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api(api_token=None)
assert api.auth.allow_all is True
class TestAppFactory:
@pytest.mark.asyncio
async def test_app_factory_creates_application(self):
"""Test that app_factory creates aiohttp application"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api()
# Mock the dependencies
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock()
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock()
app = await api.app_factory()
assert isinstance(app, web.Application)
assert app._client_max_size == 256 * 1024 * 1024
# Verify that config receiver was started
api.config_receiver.start.assert_called_once()
# Verify that endpoint manager was configured
api.endpoint_manager.add_routes.assert_called_once_with(app)
api.endpoint_manager.start.assert_called_once()
async def test_creates_aiohttp_app(self, api):
# Stub out the long-tail dependencies that reach out to IAM /
# pub/sub so we can exercise the factory in isolation.
api.auth.start = AsyncMock()
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock()
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock()
api.endpoints = []
app = await api.app_factory()
assert isinstance(app, web.Application)
assert app._client_max_size == 256 * 1024 * 1024
api.auth.start.assert_called_once()
api.config_receiver.start.assert_called_once()
api.endpoint_manager.add_routes.assert_called_once_with(app)
api.endpoint_manager.start.assert_called_once()
@pytest.mark.asyncio
async def test_app_factory_with_custom_endpoints(self):
"""Test app_factory with custom endpoints"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api()
# Mock custom endpoints
mock_endpoint1 = Mock()
mock_endpoint1.add_routes = Mock()
mock_endpoint1.start = AsyncMock()
mock_endpoint2 = Mock()
mock_endpoint2.add_routes = Mock()
mock_endpoint2.start = AsyncMock()
api.endpoints = [mock_endpoint1, mock_endpoint2]
# Mock the dependencies
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock()
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock()
app = await api.app_factory()
# Verify custom endpoints were configured
mock_endpoint1.add_routes.assert_called_once_with(app)
mock_endpoint1.start.assert_called_once()
mock_endpoint2.add_routes.assert_called_once_with(app)
mock_endpoint2.start.assert_called_once()
async def test_auth_start_runs_before_accepting_traffic(self, api):
"""``auth.start()`` fetches the IAM signing key, and must
complete (or time out) before the gateway begins accepting
requests. It's the first await in app_factory."""
order = []
def test_run_method_calls_web_run_app(self):
"""Test that run method calls web.run_app"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
# AsyncMock.side_effect expects a sync callable (its return
# value becomes the coroutine's return); a plain list.append
# avoids the "coroutine was never awaited" trap of an async
# side_effect.
api.auth.start = AsyncMock(
side_effect=lambda: order.append("auth"),
)
api.config_receiver = Mock()
api.config_receiver.start = AsyncMock(
side_effect=lambda: order.append("config"),
)
api.endpoint_manager = Mock()
api.endpoint_manager.add_routes = Mock()
api.endpoint_manager.start = AsyncMock(
side_effect=lambda: order.append("endpoints"),
)
api.endpoints = []
# Api.run() passes self.app_factory() — a coroutine — to
# web.run_app, which would normally consume it inside its own
# event loop. Since we mock run_app, close the coroutine here
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
def _consume_coro(coro, **kwargs):
coro.close()
mock_run_app.side_effect = _consume_coro
await api.app_factory()
api = Api(port=8080)
api.run()
# Verify run_app was called once with the correct port
mock_run_app.assert_called_once()
args, kwargs = mock_run_app.call_args
assert len(args) == 1 # Should have one positional arg (the coroutine)
assert kwargs == {'port': 8080} # Should have port keyword arg
def test_api_components_initialization(self):
"""Test that all API components are properly initialized"""
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_get_pubsub.return_value = Mock()
api = Api()
# Verify all components are initialized
assert api.config_receiver is not None
assert api.dispatcher_manager is not None
assert api.endpoint_manager is not None
assert api.endpoints == []
# Verify component relationships
assert api.dispatcher_manager.backend == api.pubsub_backend
assert api.dispatcher_manager.config_receiver == api.config_receiver
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
# EndpointManager doesn't store auth directly, it passes it to individual endpoints
class TestRunFunction:
"""Test cases for the run() function"""
def test_run_function_with_metrics_enabled(self):
"""Test run function with metrics enabled"""
import warnings
# Suppress the specific async warning with a broader pattern
warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
# Mock command line arguments
mock_args = Mock()
mock_args.metrics = True
mock_args.metrics_port = 8000
mock_parse_args.return_value = mock_args
# Create a simple mock instance without any async methods
mock_api_instance = Mock()
mock_api_instance.run = Mock()
# Create a mock Api class without importing the real one
mock_api = Mock(return_value=mock_api_instance)
# Patch using context manager to avoid importing the real Api class
with patch('trustgraph.gateway.service.Api', mock_api):
# Mock vars() to return a dict
with patch('builtins.vars') as mock_vars:
mock_vars.return_value = {
'metrics': True,
'metrics_port': 8000,
'pulsar_host': default_pulsar_host,
'timeout': default_timeout
}
run()
# Verify metrics server was started
mock_start_http_server.assert_called_once_with(8000)
# Verify Api was created and run was called
mock_api.assert_called_once()
mock_api_instance.run.assert_called_once()
@patch('trustgraph.gateway.service.start_http_server')
@patch('argparse.ArgumentParser.parse_args')
def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
"""Test run function with metrics disabled"""
# Mock command line arguments
mock_args = Mock()
mock_args.metrics = False
mock_parse_args.return_value = mock_args
# Create a simple mock instance without any async methods
mock_api_instance = Mock()
mock_api_instance.run = Mock()
# Patch the Api class inside the test without using decorators
with patch('trustgraph.gateway.service.Api') as mock_api:
mock_api.return_value = mock_api_instance
# Mock vars() to return a dict
with patch('builtins.vars') as mock_vars:
mock_vars.return_value = {
'metrics': False,
'metrics_port': 8000,
'pulsar_host': default_pulsar_host,
'timeout': default_timeout
}
run()
# Verify metrics server was NOT started
mock_start_http_server.assert_not_called()
# Verify Api was created and run was called
mock_api.assert_called_once()
mock_api_instance.run.assert_called_once()
@patch('argparse.ArgumentParser.parse_args')
def test_run_function_argument_parsing(self, mock_parse_args):
"""Test that run function properly parses command line arguments"""
# Mock command line arguments
mock_args = Mock()
mock_args.metrics = False
mock_parse_args.return_value = mock_args
# Create a simple mock instance without any async methods
mock_api_instance = Mock()
mock_api_instance.run = Mock()
# Mock vars() to return a dict with all expected arguments
expected_args = {
'pulsar_host': 'pulsar://test:6650',
'pulsar_api_key': 'test-key',
'pulsar_listener': 'test-listener',
'prometheus_url': 'http://test-prometheus:9090',
'port': 9000,
'timeout': 300,
'api_token': 'secret',
'log_level': 'INFO',
'metrics': False,
'metrics_port': 8001
}
# Patch the Api class inside the test without using decorators
with patch('trustgraph.gateway.service.Api') as mock_api:
mock_api.return_value = mock_api_instance
with patch('builtins.vars') as mock_vars:
mock_vars.return_value = expected_args
run()
# Verify Api was created with the parsed arguments
mock_api.assert_called_once_with(**expected_args)
mock_api_instance.run.assert_called_once()
def test_run_function_creates_argument_parser(self):
"""Test that run function creates argument parser with correct arguments"""
with patch('argparse.ArgumentParser') as mock_parser_class:
mock_parser = Mock()
mock_parser_class.return_value = mock_parser
mock_parser.parse_args.return_value = Mock(metrics=False)
with patch('trustgraph.gateway.service.Api') as mock_api, \
patch('builtins.vars') as mock_vars:
mock_vars.return_value = {'metrics': False}
mock_api.return_value = Mock()
run()
# Verify ArgumentParser was created
mock_parser_class.assert_called_once()
# Verify add_argument was called for each expected argument
expected_arguments = [
'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
'prometheus-url', 'port', 'timeout', 'api-token',
'log-level', 'metrics', 'metrics-port'
]
# Check that add_argument was called multiple times (once for each arg)
assert mock_parser.add_argument.call_count >= len(expected_arguments)
# auth.start must be first (before config receiver, before
# any endpoint starts).
assert order[0] == "auth"
# All three must have run.
assert set(order) == {"auth", "config", "endpoints"}

View file

@ -1,4 +1,15 @@
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
"""Unit tests for SocketEndpoint graceful shutdown functionality.
These tests exercise SocketEndpoint in its handshake-auth
configuration (``in_band_auth=False``) the mode used in production
for the flow import/export streaming endpoints. The mux socket at
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
handshake always accepts and authentication runs on the first
WebSocket frame; that path is covered by the Mux tests.
Every endpoint constructor here passes an explicit capability no
permissive default is relied upon.
"""
import pytest
import asyncio
@ -6,13 +17,32 @@ from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import web, WSMsgType
from trustgraph.gateway.endpoint.socket import SocketEndpoint
from trustgraph.gateway.running import Running
from trustgraph.gateway.auth import Identity
# Representative capability used across these tests — corresponds to
# the flow-import streaming endpoint pattern that uses this class.
TEST_CAP = "graph:write"
def _valid_identity():
return Identity(
handle="test-user",
workspace="default",
principal_id="test-user",
source="api-key",
)
@pytest.fixture
def mock_auth():
"""Mock authentication service."""
"""Mock IAM-backed authenticator. Successful by default —
``authenticate`` returns a valid identity and ``authorise``
allows everything. Tests that need the failure paths override
the relevant attribute locally."""
auth = MagicMock()
auth.permitted.return_value = True
auth.authenticate = AsyncMock(return_value=_valid_identity())
auth.authorise = AsyncMock(return_value=None)
return auth
@ -25,7 +55,7 @@ def mock_dispatcher_factory():
dispatcher.receive = AsyncMock()
dispatcher.destroy = AsyncMock()
return dispatcher
return dispatcher_factory
@ -35,7 +65,8 @@ def socket_endpoint(mock_auth, mock_dispatcher_factory):
return SocketEndpoint(
endpoint_path="/test-socket",
auth=mock_auth,
dispatcher=mock_dispatcher_factory
dispatcher=mock_dispatcher_factory,
capability=TEST_CAP,
)
@ -61,7 +92,10 @@ def mock_request():
@pytest.mark.asyncio
async def test_listener_graceful_shutdown_on_close():
"""Test listener handles websocket close gracefully."""
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
socket_endpoint = SocketEndpoint(
"/test", MagicMock(), AsyncMock(),
capability=TEST_CAP,
)
# Mock websocket that closes after one message
ws = AsyncMock()
@ -99,9 +133,10 @@ async def test_listener_graceful_shutdown_on_close():
@pytest.mark.asyncio
async def test_handle_normal_flow():
"""Test normal websocket handling flow."""
"""Valid bearer → handshake accepted, dispatcher created."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
mock_auth.authorise = AsyncMock(return_value=None)
dispatcher_created = False
async def mock_dispatcher_factory(ws, running, match_info):
@ -111,7 +146,10 @@ async def test_handle_normal_flow():
dispatcher.destroy = AsyncMock()
return dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}
@ -155,7 +193,8 @@ async def test_handle_normal_flow():
async def test_handle_exception_group_cleanup():
"""Test exception group triggers dispatcher cleanup."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
mock_auth.authorise = AsyncMock(return_value=None)
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
@ -163,7 +202,10 @@ async def test_handle_exception_group_cleanup():
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}
@ -222,7 +264,8 @@ async def test_handle_exception_group_cleanup():
async def test_handle_dispatcher_cleanup_timeout():
"""Test dispatcher cleanup with timeout."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
mock_auth.authorise = AsyncMock(return_value=None)
# Mock dispatcher that takes long to destroy
mock_dispatcher = AsyncMock()
@ -231,7 +274,10 @@ async def test_handle_dispatcher_cleanup_timeout():
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}
@ -285,49 +331,68 @@ async def test_handle_dispatcher_cleanup_timeout():
@pytest.mark.asyncio
async def test_handle_unauthorized_request():
"""Test handling of unauthorized requests."""
"""A bearer that the IAM layer rejects causes the handshake to
fail with 401. IamAuth surfaces an HTTPUnauthorized; the
endpoint propagates it. Note that the endpoint intentionally
does NOT distinguish 'bad token', 'expired', 'revoked', etc.
that's the IAM error-masking policy."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = False # Unauthorized
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
text='{"error":"auth failure"}',
content_type="application/json",
))
socket_endpoint = SocketEndpoint(
"/test", mock_auth, AsyncMock(),
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "invalid-token"}
result = await socket_endpoint.handle(request)
# Should return HTTP 401
assert isinstance(result, web.HTTPUnauthorized)
# Should have checked permission
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
# authenticate must have been invoked with a synthetic request
# carrying Bearer <the-token>. The endpoint wraps the query-
# string token into an Authorization header for a uniform auth
# path — the IAM layer does not look at query strings directly.
mock_auth.authenticate.assert_called_once()
passed_req = mock_auth.authenticate.call_args.args[0]
assert passed_req.headers["Authorization"] == "Bearer invalid-token"
@pytest.mark.asyncio
async def test_handle_missing_token():
"""Test handling of requests with missing token."""
"""Request with no ``token`` query param → 401 before any
IAM call is made (cheap short-circuit)."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = False
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
mock_auth.authenticate = AsyncMock(
side_effect=AssertionError(
"authenticate must not be invoked when no token is present"
),
)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, AsyncMock(),
capability=TEST_CAP,
)
request = MagicMock()
request.query = {} # No token
result = await socket_endpoint.handle(request)
# Should return HTTP 401
assert isinstance(result, web.HTTPUnauthorized)
# Should have checked permission with empty token
mock_auth.permitted.assert_called_once_with("", "socket")
mock_auth.authenticate.assert_not_called()
@pytest.mark.asyncio
async def test_handle_websocket_already_closed():
"""Test handling when websocket is already closed."""
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
mock_auth.authorise = AsyncMock(return_value=None)
mock_dispatcher = AsyncMock()
mock_dispatcher.destroy = AsyncMock()
@ -335,7 +400,10 @@ async def test_handle_websocket_already_closed():
async def mock_dispatcher_factory(ws, running, match_info):
return mock_dispatcher
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
socket_endpoint = SocketEndpoint(
"/test", mock_auth, mock_dispatcher_factory,
capability=TEST_CAP,
)
request = MagicMock()
request.query = {"token": "valid-token"}

View file

@ -15,13 +15,13 @@ from trustgraph.base import LlmResult
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
"""Test Ollama processor functionality"""
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test basic processor initialization"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
# Mock the parent class initialization
@ -44,13 +44,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'llm')
mock_client_class.assert_called_once_with(host='http://localhost:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test successful content generation"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Generated response from Ollama',
'prompt_eval_count': 15,
@ -83,13 +83,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'llama2'
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test handling of generic exceptions"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client.generate.side_effect = Exception("Connection error")
mock_client_class.return_value = mock_client
@ -110,13 +110,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(Exception, match="Connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
@ -137,13 +137,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'mistral'
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test processor initialization with default values"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
@ -164,13 +164,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once()
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test content generation with empty prompts"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Default response',
'prompt_eval_count': 2,
@ -205,13 +205,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# The prompt should be "" + "\n\n" + "" = "\n\n"
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test token counting from Ollama response"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Test response',
'prompt_eval_count': 50,
@ -243,13 +243,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.out_token == 25
assert result.model == 'llama2'
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test that Ollama client is initialized correctly"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
@ -273,13 +273,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.llm == mock_client
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with system instructions',
'prompt_eval_count': 25,
@ -312,13 +312,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Verify the combined prompt
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with custom temperature',
'prompt_eval_count': 20,
@ -360,13 +360,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
options={'temperature': 0.8} # Should use runtime override
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test model parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with custom model',
'prompt_eval_count': 18,
@ -408,13 +408,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
options={'temperature': 0.1} # Should use processor default
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_client = MagicMock()
mock_client = AsyncMock()
mock_response = {
'response': 'Response with both overrides',
'prompt_eval_count': 22,