mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
Merge remote-tracking branch 'origin/master' into ts-port
This commit is contained in:
commit
54a6e49bd3
97 changed files with 10771 additions and 1264 deletions
|
|
@ -14,13 +14,13 @@ from trustgraph.embeddings.ollama.processor import Processor
|
|||
class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||
"""Test Ollama dynamic model selection"""
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||
async def test_client_initialized_with_host(self, mock_embeddings_init, mock_async_init, mock_client_class):
|
||||
"""Test that Ollama client is initialized with correct host"""
|
||||
# Arrange
|
||||
mock_ollama_client = Mock()
|
||||
mock_ollama_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
mock_ollama_client.embed.return_value = mock_response
|
||||
|
|
@ -36,13 +36,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
mock_client_class.assert_called_once_with(host="http://localhost:11434")
|
||||
assert processor.default_model == "test-model"
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
|
||||
"""Test that on_embeddings uses default model when no model specified"""
|
||||
# Arrange
|
||||
mock_ollama_client = Mock()
|
||||
mock_ollama_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
mock_ollama_client.embed.return_value = mock_response
|
||||
|
|
@ -62,13 +62,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
)
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_client_class):
|
||||
"""Test that on_embeddings uses specified model when provided"""
|
||||
# Arrange
|
||||
mock_ollama_client = Mock()
|
||||
mock_ollama_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
mock_ollama_client.embed.return_value = mock_response
|
||||
|
|
@ -88,13 +88,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
)
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_client_class):
|
||||
"""Test switching between multiple models"""
|
||||
# Arrange
|
||||
mock_ollama_client = Mock()
|
||||
mock_ollama_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
mock_ollama_client.embed.return_value = mock_response
|
||||
|
|
@ -118,13 +118,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
assert calls[2][1]['model'] == "model-a"
|
||||
assert calls[3][1]['model'] == "test-model" # Default
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
|
||||
"""Test that None model parameter falls back to default"""
|
||||
# Arrange
|
||||
mock_ollama_client = Mock()
|
||||
mock_ollama_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.embeddings = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
mock_ollama_client.embed.return_value = mock_response
|
||||
|
|
@ -143,13 +143,13 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
input=["test text"]
|
||||
)
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.embeddings.ollama.processor.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_client_class):
|
||||
"""Test initialization without model parameter uses module default"""
|
||||
# Arrange
|
||||
mock_ollama_client = Mock()
|
||||
mock_ollama_client = AsyncMock()
|
||||
mock_client_class.return_value = mock_ollama_client
|
||||
mock_async_init.return_value = None
|
||||
mock_embeddings_init.return_value = None
|
||||
|
|
|
|||
|
|
@ -277,6 +277,60 @@ class TestTripleValidation:
|
|||
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
|
||||
assert not is_invalid, "Invalid range should be rejected"
|
||||
|
||||
def test_is_valid_triple_subclass_is_accepted(self, extractor, sample_ontology_subset):
|
||||
"""Domain check passes when actual type is a subclass of expected."""
|
||||
sample_ontology_subset.classes["Cake"] = {
|
||||
"uri": "http://purl.org/ontology/fo/Cake",
|
||||
"type": "owl:Class",
|
||||
"subclass_of": "Recipe",
|
||||
}
|
||||
sample_ontology_subset.object_properties["has_ingredient"] = {
|
||||
"domain": "Recipe",
|
||||
"range": "Ingredient",
|
||||
}
|
||||
|
||||
result = extractor.is_valid_triple(
|
||||
subject="cake:lemon-drizzle",
|
||||
predicate="has_ingredient",
|
||||
object_val="ingredient:lemon",
|
||||
ontology_subset=sample_ontology_subset,
|
||||
entity_types={"cake:lemon-drizzle": "Cake", "ingredient:lemon": "Ingredient"},
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_valid_triple_handles_subclass_cycle_without_infinite_loop(self, extractor, sample_ontology_subset):
|
||||
"""A cycle in subclass_of must return False instead of hanging."""
|
||||
sample_ontology_subset.classes["A"] = {"subclass_of": "B"}
|
||||
sample_ontology_subset.classes["B"] = {"subclass_of": "A"}
|
||||
sample_ontology_subset.object_properties["p"] = {"domain": "Recipe", "range": "Ingredient"}
|
||||
|
||||
result = extractor.is_valid_triple(
|
||||
subject="entity:x",
|
||||
predicate="p",
|
||||
object_val="ingredient:y",
|
||||
ontology_subset=sample_ontology_subset,
|
||||
entity_types={"entity:x": "A", "ingredient:y": "Ingredient"},
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_valid_triple_entity_types_none_default(self, extractor, sample_ontology_subset):
|
||||
"""entity_types=None should not raise; domain/range checks skip if type unknown."""
|
||||
sample_ontology_subset.object_properties["has_ingredient"] = {
|
||||
"domain": "Recipe",
|
||||
"range": "Ingredient",
|
||||
}
|
||||
|
||||
result = extractor.is_valid_triple(
|
||||
subject="recipe:x",
|
||||
predicate="has_ingredient",
|
||||
object_val="ingredient:y",
|
||||
ontology_subset=sample_ontology_subset,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestTripleParsing:
|
||||
"""Test suite for parsing triples from LLM responses."""
|
||||
|
|
@ -377,6 +431,24 @@ class TestTripleParsing:
|
|||
assert triple.p.type == IRI, "Predicate should be IRI type"
|
||||
assert triple.o.type == LITERAL, "Object literal should be LITERAL type"
|
||||
|
||||
def test_parse_and_validate_triples_collects_entity_types_from_rdf_type(self, extractor, sample_ontology_subset):
|
||||
"""entity_types should be built from rdf:type triples in the same batch."""
|
||||
sample_ontology_subset.object_properties["has_ingredient"] = {
|
||||
"domain": "Recipe",
|
||||
"range": "Ingredient",
|
||||
}
|
||||
triples_response = [
|
||||
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},
|
||||
{"subject": "ingredient:beef", "predicate": "rdf:type", "object": "Ingredient"},
|
||||
{"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:beef"},
|
||||
]
|
||||
|
||||
valid_triples = extractor.parse_and_validate_triples(
|
||||
triples_response, sample_ontology_subset
|
||||
)
|
||||
|
||||
assert len(valid_triples) == 3
|
||||
|
||||
|
||||
class TestURIExpansionInExtraction:
|
||||
"""Test suite for URI expansion during triple extraction."""
|
||||
|
|
|
|||
|
|
@ -1,69 +1,447 @@
|
|||
"""
|
||||
Tests for Gateway Authentication
|
||||
Tests for gateway/auth.py — IamAuth, JWT verification, API key
|
||||
resolution cache.
|
||||
|
||||
JWTs are signed with real Ed25519 keypairs generated per-test, so
|
||||
the crypto path is exercised end-to-end without mocks. API-key
|
||||
resolution is tested against a stubbed IamClient since the real
|
||||
one requires pub/sub.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from trustgraph.gateway.auth import Authenticator
|
||||
from trustgraph.gateway.auth import (
|
||||
IamAuth, Identity,
|
||||
_b64url_decode, _verify_jwt_eddsa,
|
||||
API_KEY_CACHE_TTL,
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticator:
|
||||
"""Test cases for Authenticator class"""
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
def test_authenticator_initialization_with_token(self):
|
||||
"""Test Authenticator initialization with valid token"""
|
||||
auth = Authenticator(token="test-token-123")
|
||||
|
||||
assert auth.token == "test-token-123"
|
||||
assert auth.allow_all is False
|
||||
|
||||
def test_authenticator_initialization_with_allow_all(self):
|
||||
"""Test Authenticator initialization with allow_all=True"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
assert auth.token is None
|
||||
assert auth.allow_all is True
|
||||
def _b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
def test_authenticator_initialization_without_token_raises_error(self):
|
||||
"""Test Authenticator initialization without token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator()
|
||||
|
||||
def test_authenticator_initialization_with_empty_token_raises_error(self):
|
||||
"""Test Authenticator initialization with empty token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator(token="")
|
||||
def make_keypair():
|
||||
priv = ed25519.Ed25519PrivateKey.generate()
|
||||
public_pem = priv.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode("ascii")
|
||||
return priv, public_pem
|
||||
|
||||
def test_permitted_with_allow_all_returns_true(self):
|
||||
"""Test permitted method returns True when allow_all is enabled"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
# Should return True regardless of token or roles
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("different-token", ["admin"]) is True
|
||||
assert auth.permitted(None, ["user"]) is True
|
||||
|
||||
def test_permitted_with_matching_token_returns_true(self):
|
||||
"""Test permitted method returns True with matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return True when tokens match
|
||||
assert auth.permitted("secret-token", []) is True
|
||||
assert auth.permitted("secret-token", ["admin", "user"]) is True
|
||||
def sign_jwt(priv, claims, alg="EdDSA"):
|
||||
header = {"alg": alg, "typ": "JWT", "kid": "kid-test"}
|
||||
h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode())
|
||||
p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode())
|
||||
signing_input = f"{h}.{p}".encode("ascii")
|
||||
if alg == "EdDSA":
|
||||
sig = priv.sign(signing_input)
|
||||
else:
|
||||
raise ValueError(f"test helper doesn't sign {alg}")
|
||||
return f"{h}.{p}.{_b64url(sig)}"
|
||||
|
||||
def test_permitted_with_non_matching_token_returns_false(self):
|
||||
"""Test permitted method returns False with non-matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return False when tokens don't match
|
||||
assert auth.permitted("wrong-token", []) is False
|
||||
assert auth.permitted("different-token", ["admin"]) is False
|
||||
assert auth.permitted(None, ["user"]) is False
|
||||
|
||||
def test_permitted_with_token_and_allow_all_returns_true(self):
|
||||
"""Test permitted method with both token and allow_all set"""
|
||||
auth = Authenticator(token="test-token", allow_all=True)
|
||||
|
||||
# allow_all should take precedence
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("wrong-token", ["admin"]) is True
|
||||
def make_request(auth_header):
|
||||
"""Minimal stand-in for an aiohttp request — IamAuth only reads
|
||||
``request.headers["Authorization"]``."""
|
||||
req = Mock()
|
||||
req.headers = {}
|
||||
if auth_header is not None:
|
||||
req.headers["Authorization"] = auth_header
|
||||
return req
|
||||
|
||||
|
||||
# -- pure helpers ----------------------------------------------------------
|
||||
|
||||
|
||||
class TestB64UrlDecode:
|
||||
|
||||
def test_round_trip_without_padding(self):
|
||||
data = b"hello"
|
||||
encoded = _b64url(data)
|
||||
assert _b64url_decode(encoded) == data
|
||||
|
||||
def test_handles_various_lengths(self):
|
||||
for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"):
|
||||
assert _b64url_decode(_b64url(s)) == s
|
||||
|
||||
|
||||
# -- JWT verification -----------------------------------------------------
|
||||
|
||||
|
||||
class TestVerifyJwtEddsa:
|
||||
|
||||
def test_valid_jwt_passes(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
got = _verify_jwt_eddsa(token, pub)
|
||||
assert got["sub"] == "user-1"
|
||||
assert got["workspace"] == "default"
|
||||
|
||||
def test_expired_jwt_rejected(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"iat": int(time.time()) - 3600,
|
||||
"exp": int(time.time()) - 1,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
with pytest.raises(ValueError, match="expired"):
|
||||
_verify_jwt_eddsa(token, pub)
|
||||
|
||||
def test_bad_signature_rejected(self):
|
||||
priv_a, _ = make_keypair()
|
||||
_, pub_b = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv_a, claims)
|
||||
# pub_b never signed this token.
|
||||
with pytest.raises(Exception):
|
||||
_verify_jwt_eddsa(token, pub_b)
|
||||
|
||||
def test_malformed_jwt_rejected(self):
|
||||
_, pub = make_keypair()
|
||||
with pytest.raises(ValueError, match="malformed"):
|
||||
_verify_jwt_eddsa("not-a-jwt", pub)
|
||||
|
||||
def test_unsupported_algorithm_rejected(self):
|
||||
priv, pub = make_keypair()
|
||||
# Manually build an "alg":"HS256" header — no signer needed
|
||||
# since we expect it to bail before verifying.
|
||||
header = {"alg": "HS256", "typ": "JWT", "kid": "x"}
|
||||
payload = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"iat": int(time.time()), "exp": int(time.time()) + 60,
|
||||
}
|
||||
h = _b64url(json.dumps(header, separators=(",", ":")).encode())
|
||||
p = _b64url(json.dumps(payload, separators=(",", ":")).encode())
|
||||
sig = _b64url(b"not-a-real-sig")
|
||||
token = f"{h}.{p}.{sig}"
|
||||
with pytest.raises(ValueError, match="unsupported alg"):
|
||||
_verify_jwt_eddsa(token, pub)
|
||||
|
||||
|
||||
# -- Identity --------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIdentity:
|
||||
|
||||
def test_fields(self):
|
||||
i = Identity(
|
||||
handle="u", workspace="w",
|
||||
principal_id="u", source="api-key",
|
||||
)
|
||||
assert i.handle == "u"
|
||||
assert i.workspace == "w"
|
||||
assert i.principal_id == "u"
|
||||
assert i.source == "api-key"
|
||||
|
||||
|
||||
# -- IamAuth.authenticate --------------------------------------------------
|
||||
|
||||
|
||||
class TestIamAuthDispatch:
|
||||
"""``authenticate()`` chooses between the JWT and API-key paths
|
||||
by shape of the bearer."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_authorization_header_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(None))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_bearer_header_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Basic whatever"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_bearer_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer "))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_format_raises_401(self):
|
||||
# Not tg_... and not dotted-JWT shape.
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer garbage"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_jwt_resolves_to_identity(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
|
||||
auth = IamAuth(backend=Mock())
|
||||
auth._signing_public_pem = pub
|
||||
|
||||
ident = await auth.authenticate(
|
||||
make_request(f"Bearer {token}")
|
||||
)
|
||||
assert ident.handle == "user-1"
|
||||
assert ident.workspace == "default"
|
||||
assert ident.principal_id == "user-1"
|
||||
assert ident.source == "jwt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_without_public_key_fails(self):
|
||||
# If the gateway hasn't fetched IAM's public key yet, JWTs
|
||||
# must not validate — even ones that would otherwise pass.
|
||||
priv, _ = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"iat": int(time.time()), "exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
auth = IamAuth(backend=Mock())
|
||||
# _signing_public_pem defaults to None
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(f"Bearer {token}"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_path(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_resolve(api_key):
|
||||
assert api_key == "tg_testkey"
|
||||
# Roles are returned by the regime as a hint but the
|
||||
# gateway ignores them — kept here so the resolve
|
||||
# protocol shape is exercised.
|
||||
return ("user-xyz", "default", ["admin"])
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(resolve_api_key=fake_resolve))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = await auth.authenticate(
|
||||
make_request("Bearer tg_testkey")
|
||||
)
|
||||
assert ident.handle == "user-xyz"
|
||||
assert ident.workspace == "default"
|
||||
assert ident.principal_id == "user-xyz"
|
||||
assert ident.source == "api-key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_rejection_masked_as_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: unknown api key")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(
|
||||
make_request("Bearer tg_bogus")
|
||||
)
|
||||
|
||||
|
||||
# -- API key cache ---------------------------------------------------------
|
||||
|
||||
|
||||
class TestApiKeyCache:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_skips_iam(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_with_client(op):
|
||||
calls["n"] += 1
|
||||
return await op(Mock(
|
||||
resolve_api_key=AsyncMock(
|
||||
return_value=("u", "default", ["reader"]),
|
||||
)
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
await auth.authenticate(make_request("Bearer tg_k1"))
|
||||
await auth.authenticate(make_request("Bearer tg_k1"))
|
||||
await auth.authenticate(make_request("Bearer tg_k1"))
|
||||
|
||||
# Only the first lookup reaches IAM; the rest are cache hits.
|
||||
assert calls["n"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_keys_are_separately_cached(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
seen = []
|
||||
|
||||
async def fake_with_client(op):
|
||||
async def resolve(plaintext):
|
||||
seen.append(plaintext)
|
||||
return ("u-" + plaintext, "default", ["reader"])
|
||||
return await op(Mock(resolve_api_key=resolve))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
a = await auth.authenticate(make_request("Bearer tg_a"))
|
||||
b = await auth.authenticate(make_request("Bearer tg_b"))
|
||||
|
||||
assert a.handle == "u-tg_a"
|
||||
assert b.handle == "u-tg_b"
|
||||
assert seen == ["tg_a", "tg_b"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_has_ttl_constant_set(self):
|
||||
# Not a behaviour test — just ensures we don't accidentally
|
||||
# set TTL to 0 (which would defeat the cache) or to a week.
|
||||
assert 10 <= API_KEY_CACHE_TTL <= 3600
|
||||
|
||||
|
||||
# -- IamAuth.authorise -----------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthorise:
|
||||
"""``authorise()`` is the gateway's only authorisation entry
|
||||
point under the IAM contract. It calls iam-svc, caches the
|
||||
decision for the regime's TTL (clamped above), and raises 403
|
||||
on deny / 401 on regime error (fail closed)."""
|
||||
|
||||
def _make_identity(self, handle="u-1", workspace="default"):
|
||||
return Identity(
|
||||
handle=handle, workspace=workspace,
|
||||
principal_id=handle, source="api-key",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_returns_no_exception(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(
|
||||
authorise=AsyncMock(return_value=(True, 30)),
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
await auth.authorise(
|
||||
self._make_identity(),
|
||||
"graph:read",
|
||||
{"workspace": "default"},
|
||||
{},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_raises_403(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(
|
||||
authorise=AsyncMock(return_value=(False, 30)),
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
await auth.authorise(
|
||||
self._make_identity(),
|
||||
"users:admin",
|
||||
{},
|
||||
{"workspace": "acme"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regime_error_fails_closed_as_401(self):
|
||||
# If iam-svc errors, the gateway must NOT silently allow.
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("iam-svc down")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authorise(
|
||||
self._make_identity(),
|
||||
"graph:read",
|
||||
{"workspace": "default"},
|
||||
{},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_decision_is_cached(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_with_client(op):
|
||||
calls["n"] += 1
|
||||
return await op(Mock(
|
||||
authorise=AsyncMock(return_value=(True, 30)),
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = self._make_identity()
|
||||
for _ in range(5):
|
||||
await auth.authorise(
|
||||
ident, "graph:read", {"workspace": "default"}, {},
|
||||
)
|
||||
|
||||
assert calls["n"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_decision_is_cached(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_with_client(op):
|
||||
calls["n"] += 1
|
||||
return await op(Mock(
|
||||
authorise=AsyncMock(return_value=(False, 30)),
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = self._make_identity()
|
||||
for _ in range(5):
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
await auth.authorise(
|
||||
ident, "users:admin", {}, {"workspace": "acme"},
|
||||
)
|
||||
|
||||
# Denies are cached too — repeated attempts don't re-hit IAM.
|
||||
assert calls["n"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_resources_cached_separately(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_with_client(op):
|
||||
calls["n"] += 1
|
||||
return await op(Mock(
|
||||
authorise=AsyncMock(return_value=(True, 30)),
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = self._make_identity()
|
||||
await auth.authorise(
|
||||
ident, "graph:read", {"workspace": "a"}, {},
|
||||
)
|
||||
await auth.authorise(
|
||||
ident, "graph:read", {"workspace": "b"}, {},
|
||||
)
|
||||
|
||||
# Different resource → different cache key → two IAM calls.
|
||||
assert calls["n"] == 2
|
||||
|
|
|
|||
171
tests/unit/test_gateway/test_capabilities.py
Normal file
171
tests/unit/test_gateway/test_capabilities.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
"""
|
||||
Tests for gateway/capabilities.py — the thin authorisation surface
|
||||
under the IAM contract.
|
||||
|
||||
The gateway no longer holds policy state (roles, capability sets,
|
||||
workspace scopes); those live in iam-svc. These tests cover only
|
||||
what the gateway shim does itself: PUBLIC / AUTHENTICATED short-
|
||||
circuiting, default-fill of workspace, and forwarding of capability
|
||||
checks to ``auth.authorise``.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.gateway.capabilities import (
|
||||
PUBLIC, AUTHENTICATED,
|
||||
enforce, enforce_workspace,
|
||||
access_denied, auth_failure,
|
||||
)
|
||||
|
||||
|
||||
# -- test fixtures ---------------------------------------------------------
|
||||
|
||||
|
||||
class _Identity:
|
||||
"""Stand-in for auth.Identity — under the IAM contract it has
|
||||
just ``handle``, ``workspace``, ``principal_id``, ``source``."""
|
||||
|
||||
def __init__(self, handle="user-1", workspace="default"):
|
||||
self.handle = handle
|
||||
self.workspace = workspace
|
||||
self.principal_id = handle
|
||||
self.source = "api-key"
|
||||
|
||||
|
||||
def _allow_auth(identity=None):
|
||||
"""Build an Auth double that authenticates to ``identity`` and
|
||||
allows every authorise() call."""
|
||||
auth = MagicMock()
|
||||
auth.authenticate = AsyncMock(
|
||||
return_value=identity or _Identity(),
|
||||
)
|
||||
auth.authorise = AsyncMock(return_value=None)
|
||||
return auth
|
||||
|
||||
|
||||
def _deny_auth(identity=None):
|
||||
"""Build an Auth double that authenticates but denies authorise."""
|
||||
auth = MagicMock()
|
||||
auth.authenticate = AsyncMock(
|
||||
return_value=identity or _Identity(),
|
||||
)
|
||||
auth.authorise = AsyncMock(side_effect=access_denied())
|
||||
return auth
|
||||
|
||||
|
||||
# -- enforce() -------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnforce:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_returns_none_no_auth(self):
|
||||
auth = _allow_auth()
|
||||
result = await enforce(MagicMock(), auth, PUBLIC)
|
||||
assert result is None
|
||||
auth.authenticate.assert_not_called()
|
||||
auth.authorise.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticated_skips_authorise(self):
|
||||
identity = _Identity()
|
||||
auth = _allow_auth(identity)
|
||||
result = await enforce(MagicMock(), auth, AUTHENTICATED)
|
||||
assert result is identity
|
||||
auth.authenticate.assert_awaited_once()
|
||||
auth.authorise.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_capability_calls_authorise_system_level(self):
|
||||
identity = _Identity()
|
||||
auth = _allow_auth(identity)
|
||||
result = await enforce(MagicMock(), auth, "graph:read")
|
||||
assert result is identity
|
||||
auth.authorise.assert_awaited_once_with(
|
||||
identity, "graph:read", {}, {},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_capability_denied_raises_forbidden(self):
|
||||
auth = _deny_auth()
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
await enforce(MagicMock(), auth, "users:admin")
|
||||
|
||||
|
||||
# -- enforce_workspace() ---------------------------------------------------
|
||||
|
||||
|
||||
class TestEnforceWorkspace:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_fills_from_identity(self):
|
||||
data = {"operation": "x"}
|
||||
auth = _allow_auth()
|
||||
await enforce_workspace(data, _Identity(workspace="default"), auth)
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caller_supplied_workspace_kept(self):
|
||||
data = {"workspace": "acme", "operation": "x"}
|
||||
auth = _allow_auth()
|
||||
await enforce_workspace(data, _Identity(workspace="default"), auth)
|
||||
assert data["workspace"] == "acme"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_capability_skips_authorise(self):
|
||||
data = {"workspace": "default"}
|
||||
auth = _allow_auth()
|
||||
await enforce_workspace(data, _Identity(), auth)
|
||||
auth.authorise.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_capability_calls_authorise_with_resource(self):
|
||||
data = {"workspace": "acme"}
|
||||
identity = _Identity()
|
||||
auth = _allow_auth(identity)
|
||||
await enforce_workspace(
|
||||
data, identity, auth, capability="graph:read",
|
||||
)
|
||||
auth.authorise.assert_awaited_once_with(
|
||||
identity, "graph:read", {"workspace": "acme"}, {},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_capability_denied_propagates(self):
|
||||
data = {"workspace": "acme"}
|
||||
auth = _deny_auth()
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
await enforce_workspace(
|
||||
data, _Identity(), auth, capability="users:admin",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_dict_passthrough(self):
|
||||
auth = _allow_auth()
|
||||
result = await enforce_workspace("not-a-dict", _Identity(), auth)
|
||||
assert result == "not-a-dict"
|
||||
auth.authorise.assert_not_called()
|
||||
|
||||
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseHelpers:
|
||||
|
||||
def test_auth_failure_is_401(self):
|
||||
exc = auth_failure()
|
||||
assert exc.status == 401
|
||||
assert "auth failure" in exc.text
|
||||
|
||||
def test_access_denied_is_403(self):
|
||||
exc = access_denied()
|
||||
assert exc.status == 403
|
||||
assert "access denied" in exc.text
|
||||
|
||||
|
||||
class TestSentinels:
|
||||
|
||||
def test_public_and_authenticated_are_distinct(self):
|
||||
assert PUBLIC != AUTHENTICATED
|
||||
|
|
@ -42,7 +42,7 @@ class TestDispatcherManager:
|
|||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
assert manager.backend == mock_backend
|
||||
assert manager.config_receiver == mock_config_receiver
|
||||
|
|
@ -59,7 +59,10 @@ class TestDispatcherManager:
|
|||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")
|
||||
manager = DispatcherManager(
|
||||
mock_backend, mock_config_receiver,
|
||||
auth=Mock(), prefix="custom-prefix",
|
||||
)
|
||||
|
||||
assert manager.prefix == "custom-prefix"
|
||||
|
||||
|
|
@ -68,7 +71,7 @@ class TestDispatcherManager:
|
|||
"""Test start_flow method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
|
@ -82,7 +85,7 @@ class TestDispatcherManager:
|
|||
"""Test stop_flow method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
|
@ -96,7 +99,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_global_service()
|
||||
|
||||
|
|
@ -107,7 +110,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_core_export returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_core_export()
|
||||
|
||||
|
|
@ -118,7 +121,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_core_import returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_core_import()
|
||||
|
||||
|
|
@ -130,7 +133,7 @@ class TestDispatcherManager:
|
|||
"""Test process_core_import method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
|
||||
mock_importer = Mock()
|
||||
|
|
@ -148,7 +151,7 @@ class TestDispatcherManager:
|
|||
"""Test process_core_export method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
|
||||
mock_exporter = Mock()
|
||||
|
|
@ -166,7 +169,7 @@ class TestDispatcherManager:
|
|||
"""Test process_global_service method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
manager.invoke_global_service = AsyncMock(return_value="global_result")
|
||||
|
||||
|
|
@ -181,7 +184,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_global_service with existing dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
|
|
@ -198,7 +201,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_global_service creates new dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
|
|
@ -230,7 +233,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_flow_import returns correct method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
result = manager.dispatch_flow_import()
|
||||
|
||||
|
|
@ -240,7 +243,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_flow_export returns correct method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
result = manager.dispatch_flow_export()
|
||||
|
||||
|
|
@ -250,7 +253,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_socket returns correct method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
result = manager.dispatch_socket()
|
||||
|
||||
|
|
@ -260,7 +263,7 @@ class TestDispatcherManager:
|
|||
"""Test dispatch_flow_service returns DispatcherWrapper"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
wrapper = manager.dispatch_flow_service()
|
||||
|
||||
|
|
@ -272,7 +275,7 @@ class TestDispatcherManager:
|
|||
"""Test process_flow_import with valid flow and kind"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -308,7 +311,7 @@ class TestDispatcherManager:
|
|||
"""Test process_flow_import with invalid flow"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
params = {"flow": "invalid_flow", "kind": "triples"}
|
||||
|
||||
|
|
@ -323,7 +326,7 @@ class TestDispatcherManager:
|
|||
warnings.simplefilter("ignore", RuntimeWarning)
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -345,7 +348,7 @@ class TestDispatcherManager:
|
|||
"""Test process_flow_export with valid flow and kind"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -378,26 +381,47 @@ class TestDispatcherManager:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_socket(self):
|
||||
"""Test process_socket method"""
|
||||
"""process_socket constructs a Mux with the manager's auth
|
||||
instance passed through — this is the gateway's trust path
|
||||
for first-frame WebSocket authentication. A Mux cannot be
|
||||
built without auth (tested separately); this test pins that
|
||||
the dispatcher-manager threads the correct auth value into
|
||||
the Mux constructor call."""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
mock_auth = Mock()
|
||||
manager = DispatcherManager(
|
||||
mock_backend, mock_config_receiver, auth=mock_auth,
|
||||
)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
|
||||
mock_mux_instance = Mock()
|
||||
mock_mux.return_value = mock_mux_instance
|
||||
|
||||
|
||||
result = await manager.process_socket("ws", "running", {})
|
||||
|
||||
mock_mux.assert_called_once_with(manager, "ws", "running")
|
||||
|
||||
mock_mux.assert_called_once_with(
|
||||
manager, "ws", "running", auth=mock_auth,
|
||||
)
|
||||
assert result == mock_mux_instance
|
||||
|
||||
def test_dispatcher_manager_requires_auth(self):
|
||||
"""Constructing a DispatcherManager without an auth argument
|
||||
must fail — a no-auth DispatcherManager would produce a
|
||||
Mux without authentication, silently downgrading the socket
|
||||
auth path."""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="auth"):
|
||||
DispatcherManager(mock_backend, mock_config_receiver, auth=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_service(self):
|
||||
"""Test process_flow_service method"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
manager.invoke_flow_service = AsyncMock(return_value="flow_result")
|
||||
|
||||
|
|
@ -412,7 +436,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with existing dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
|
||||
|
|
@ -432,7 +456,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service creates request-response dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -476,7 +500,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service creates sender dispatcher"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -516,7 +540,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with invalid flow"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
|
||||
|
|
@ -526,7 +550,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with kind not supported by flow"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -543,7 +567,7 @@ class TestDispatcherManager:
|
|||
"""Test invoke_flow_service with invalid kind"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
|
|
@ -570,7 +594,7 @@ class TestDispatcherManager:
|
|||
"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
async def slow_start():
|
||||
# Yield to the event loop so other coroutines get a chance to run,
|
||||
|
|
@ -606,7 +630,7 @@ class TestDispatcherManager:
|
|||
"""
|
||||
mock_backend = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
|
||||
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,19 @@ from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE
|
|||
class TestMux:
|
||||
"""Test cases for Mux class"""
|
||||
|
||||
def test_mux_requires_auth(self):
|
||||
"""Constructing a Mux without an ``auth`` argument must
|
||||
fail. The Mux implements the first-frame auth protocol and
|
||||
there is no no-auth mode — a no-auth Mux would silently
|
||||
accept every frame without authenticating it."""
|
||||
with pytest.raises(ValueError, match="auth"):
|
||||
Mux(
|
||||
dispatcher_manager=MagicMock(),
|
||||
ws=MagicMock(),
|
||||
running=MagicMock(),
|
||||
auth=None,
|
||||
)
|
||||
|
||||
def test_mux_initialization(self):
|
||||
"""Test Mux initialization"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
|
|
@ -21,7 +34,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
assert mux.dispatcher_manager == mock_dispatcher_manager
|
||||
|
|
@ -40,7 +54,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
|
|
@ -61,7 +76,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=None,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
|
|
@ -81,7 +97,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message with valid JSON
|
||||
|
|
@ -108,7 +125,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message without request field
|
||||
|
|
@ -137,7 +155,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message without id field
|
||||
|
|
@ -164,7 +183,8 @@ class TestMux:
|
|||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
running=mock_running,
|
||||
auth=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock message with invalid JSON
|
||||
|
|
|
|||
|
|
@ -13,29 +13,36 @@ class TestConstantEndpoint:
|
|||
"""Test cases for ConstantEndpoint class"""
|
||||
|
||||
def test_constant_endpoint_initialization(self):
|
||||
"""Test ConstantEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. The capability is a required argument — no
|
||||
permissive default — and the test passes an explicit
|
||||
value to demonstrate the contract."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
endpoint_path="/api/test",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability="config:read",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/test"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "config:read"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constant_endpoint_start_method(self):
|
||||
"""Test ConstantEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
"/api/test", mock_auth, mock_dispatcher,
|
||||
capability="config:read",
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -44,10 +51,13 @@ class TestConstantEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
"/api/test", mock_auth, mock_dispatcher,
|
||||
capability="config:read",
|
||||
)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.post with the path and handler
|
||||
|
|
|
|||
|
|
@ -1,4 +1,12 @@
|
|||
"""Tests for Gateway i18n pack endpoint."""
|
||||
"""Tests for Gateway i18n pack endpoint.
|
||||
|
||||
Production registers this endpoint with ``capability=PUBLIC``: the
|
||||
login UI needs to render its own i18n strings before any user has
|
||||
authenticated, so the endpoint is deliberately pre-auth. These
|
||||
tests exercise the PUBLIC configuration — that is the production
|
||||
contract. Behaviour of authenticated endpoints is covered by the
|
||||
IamAuth tests in ``test_auth.py``.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
|
@ -7,6 +15,7 @@ import pytest
|
|||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
|
||||
from trustgraph.gateway.capabilities import PUBLIC
|
||||
|
||||
|
||||
class TestI18nPackEndpoint:
|
||||
|
|
@ -17,23 +26,28 @@ class TestI18nPackEndpoint:
|
|||
endpoint = I18nPackEndpoint(
|
||||
endpoint_path="/api/v1/i18n/packs/{lang}",
|
||||
auth=mock_auth,
|
||||
capability=PUBLIC,
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == PUBLIC
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_i18n_endpoint_start_method(self):
|
||||
mock_auth = MagicMock()
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_get_handler(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
mock_app.add_routes.assert_called_once()
|
||||
|
|
@ -41,35 +55,55 @@ class TestI18nPackEndpoint:
|
|||
assert len(call_args) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
|
||||
async def test_handle_returns_pack_without_authenticating(self):
|
||||
"""The PUBLIC endpoint serves the language pack without
|
||||
invoking the auth handler at all — pre-login UI must be
|
||||
reachable. The test uses an auth mock that raises if
|
||||
touched, so any auth attempt by the endpoint is caught."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
def _should_not_be_called(*args, **kwargs):
|
||||
raise AssertionError(
|
||||
"PUBLIC endpoint must not invoke auth.authenticate"
|
||||
)
|
||||
mock_auth.authenticate = _should_not_be_called
|
||||
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.path = "/api/v1/i18n/packs/en"
|
||||
# A caller-supplied Authorization header of any form should
|
||||
# be ignored — PUBLIC means we don't look at it.
|
||||
request.headers = {"Authorization": "Token abc"}
|
||||
request.match_info = {"lang": "en"}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
assert isinstance(resp, web.HTTPUnauthorized)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_returns_pack_when_permitted(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
|
||||
request = MagicMock()
|
||||
request.path = "/api/v1/i18n/packs/en"
|
||||
request.headers = {}
|
||||
request.match_info = {"lang": "en"}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
|
||||
assert resp.status == 200
|
||||
payload = json.loads(resp.body.decode("utf-8"))
|
||||
assert isinstance(payload, dict)
|
||||
assert "cli.verify_system_status.title" in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_rejects_path_traversal(self):
|
||||
"""The ``lang`` path parameter is reflected through to the
|
||||
filesystem-backed pack loader. The endpoint contains an
|
||||
explicit defense against ``/`` and ``..`` in the value; this
|
||||
test pins that defense in place."""
|
||||
mock_auth = MagicMock()
|
||||
endpoint = I18nPackEndpoint(
|
||||
"/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
|
||||
)
|
||||
|
||||
for bad in ("../../etc/passwd", "en/../fr", "a/b"):
|
||||
request = MagicMock()
|
||||
request.path = f"/api/v1/i18n/packs/{bad}"
|
||||
request.headers = {}
|
||||
request.match_info = {"lang": bad}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
assert isinstance(resp, web.HTTPBadRequest), (
|
||||
f"path-traversal defense did not reject lang={bad!r}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,30 +12,24 @@ class TestEndpointManager:
|
|||
"""Test cases for EndpointManager class"""
|
||||
|
||||
def test_endpoint_manager_initialization(self):
|
||||
"""Test EndpointManager initialization creates all endpoints"""
|
||||
"""EndpointManager wires up the full endpoint set and
|
||||
records dispatcher_manager / timeout on the instance."""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
|
||||
# The dispatcher_manager exposes a small set of factory
|
||||
# methods — MagicMock auto-creates them, returning fresh
|
||||
# MagicMocks on each call.
|
||||
manager = EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://prometheus:9090",
|
||||
timeout=300
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
|
||||
assert manager.dispatcher_manager == mock_dispatcher_manager
|
||||
assert manager.timeout == 300
|
||||
assert manager.services == {}
|
||||
assert len(manager.endpoints) > 0 # Should have multiple endpoints
|
||||
assert len(manager.endpoints) > 0
|
||||
|
||||
def test_endpoint_manager_with_default_timeout(self):
|
||||
"""Test EndpointManager with default timeout value"""
|
||||
|
|
@ -79,9 +73,17 @@ class TestEndpointManager:
|
|||
prometheus_url="http://test:9090"
|
||||
)
|
||||
|
||||
# Verify all dispatcher methods were called during initialization
|
||||
# Each dispatcher factory is invoked once per endpoint that
|
||||
# needs a dedicated wire. dispatch_auth_iam is shared by
|
||||
# two endpoints — AuthEndpoints (login / bootstrap /
|
||||
# change-password) and IamEndpoint (registry-driven
|
||||
# /api/v1/iam) — so it's expected to be called twice.
|
||||
# Both forwarders pin the dispatcher to kind=iam and reuse
|
||||
# the same factory; they're distinct from
|
||||
# dispatch_global_service (the generic /api/v1/{kind} route).
|
||||
mock_dispatcher_manager.dispatch_global_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice
|
||||
assert mock_dispatcher_manager.dispatch_auth_iam.call_count == 2
|
||||
mock_dispatcher_manager.dispatch_socket.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_import.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_export.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -12,31 +12,35 @@ class TestMetricsEndpoint:
|
|||
"""Test cases for MetricsEndpoint class"""
|
||||
|
||||
def test_metrics_endpoint_initialization(self):
|
||||
"""Test MetricsEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. In production MetricsEndpoint is gated by
|
||||
'metrics:read' so that's the natural value to pass."""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
auth=mock_auth,
|
||||
capability="metrics:read",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.prometheus_url == "http://prometheus:9090"
|
||||
assert endpoint.path == "/metrics"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "metrics:read"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint_start_method(self):
|
||||
"""Test MetricsEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://localhost:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
auth=mock_auth,
|
||||
capability="metrics:read",
|
||||
)
|
||||
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -44,15 +48,16 @@ class TestMetricsEndpoint:
|
|||
"""Test add_routes method registers GET route with wildcard path"""
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
auth=mock_auth,
|
||||
capability="metrics:read",
|
||||
)
|
||||
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.get with wildcard path pattern
|
||||
|
|
|
|||
|
|
@ -1,5 +1,12 @@
|
|||
"""
|
||||
Tests for Gateway Socket Endpoint
|
||||
Tests for Gateway Socket Endpoint.
|
||||
|
||||
In production the only SocketEndpoint registered with HTTP-layer
|
||||
auth is ``/api/v1/socket`` using ``capability=AUTHENTICATED`` with
|
||||
``in_band_auth=True`` (first-frame auth over the websocket frames,
|
||||
not at the handshake). The tests below use AUTHENTICATED as the
|
||||
representative capability; construction / worker / listener
|
||||
behaviour is independent of which capability is configured.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -7,41 +14,47 @@ from unittest.mock import MagicMock, AsyncMock
|
|||
from aiohttp import WSMsgType
|
||||
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
from trustgraph.gateway.capabilities import AUTHENTICATED
|
||||
|
||||
|
||||
class TestSocketEndpoint:
|
||||
"""Test cases for SocketEndpoint class"""
|
||||
|
||||
def test_socket_endpoint_initialization(self):
|
||||
"""Test SocketEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. No permissive default is applied."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
endpoint_path="/api/socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/socket"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "socket"
|
||||
assert endpoint.capability == AUTHENTICATED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_method(self):
|
||||
"""Test SocketEndpoint worker method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
mock_ws = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
|
||||
# Call worker method
|
||||
await endpoint.worker(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
|
||||
# Verify dispatcher.run was called
|
||||
mock_dispatcher.run.assert_called_once()
|
||||
|
||||
|
|
@ -50,8 +63,11 @@ class TestSocketEndpoint:
|
|||
"""Test SocketEndpoint listener method with text message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
# Mock websocket with text message
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -80,8 +96,11 @@ class TestSocketEndpoint:
|
|||
"""Test SocketEndpoint listener method with binary message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
# Mock websocket with binary message
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -110,8 +129,11 @@ class TestSocketEndpoint:
|
|||
"""Test SocketEndpoint listener method with close message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
"/api/socket", mock_auth, mock_dispatcher,
|
||||
capability=AUTHENTICATED,
|
||||
)
|
||||
|
||||
# Mock websocket with close message
|
||||
mock_msg = MagicMock()
|
||||
|
|
|
|||
|
|
@ -12,48 +12,57 @@ class TestStreamEndpoint:
|
|||
"""Test cases for StreamEndpoint class"""
|
||||
|
||||
def test_stream_endpoint_initialization_with_post(self):
|
||||
"""Test StreamEndpoint initialization with POST method"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. StreamEndpoint is used in production for the
|
||||
core-import / core-export / document-stream routes; a
|
||||
document-write capability is a realistic value for a POST
|
||||
stream (e.g. core-import)."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
capability="documents:write",
|
||||
method="POST",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/stream"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "documents:write"
|
||||
assert endpoint.method == "POST"
|
||||
|
||||
def test_stream_endpoint_initialization_with_get(self):
|
||||
"""Test StreamEndpoint initialization with GET method"""
|
||||
"""GET stream — export-style endpoint, read capability."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
capability="documents:read",
|
||||
method="GET",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.method == "GET"
|
||||
|
||||
def test_stream_endpoint_initialization_default_method(self):
|
||||
"""Test StreamEndpoint initialization with default POST method"""
|
||||
"""Test StreamEndpoint initialization with default POST method.
|
||||
The method default is cosmetic; the capability is not
|
||||
defaulted — it is always required."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability="documents:write",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.method == "POST" # Default value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -61,9 +70,12 @@ class TestStreamEndpoint:
|
|||
"""Test StreamEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
"/api/stream", mock_auth, mock_dispatcher,
|
||||
capability="documents:write",
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -72,16 +84,17 @@ class TestStreamEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
capability="documents:write",
|
||||
method="POST",
|
||||
)
|
||||
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
|
|
@ -92,16 +105,17 @@ class TestStreamEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
capability="documents:read",
|
||||
method="GET",
|
||||
)
|
||||
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
|
|
@ -112,13 +126,14 @@ class TestStreamEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="INVALID"
|
||||
capability="documents:write",
|
||||
method="INVALID",
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError, match="Bad method"):
|
||||
endpoint.add_routes(mock_app)
|
||||
|
|
@ -12,29 +12,36 @@ class TestVariableEndpoint:
|
|||
"""Test cases for VariableEndpoint class"""
|
||||
|
||||
def test_variable_endpoint_initialization(self):
|
||||
"""Test VariableEndpoint initialization"""
|
||||
"""Construction records the configured capability on the
|
||||
instance. VariableEndpoint is used in production for the
|
||||
/api/v1/{kind} admin-scoped global service routes, so a
|
||||
write-side capability is a realistic value for the test."""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
endpoint_path="/api/variable",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
dispatcher=mock_dispatcher,
|
||||
capability="config:write",
|
||||
)
|
||||
|
||||
|
||||
assert endpoint.path == "/api/variable"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.capability == "config:write"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_endpoint_start_method(self):
|
||||
"""Test VariableEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher)
|
||||
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
"/api/var", mock_auth, mock_dispatcher,
|
||||
capability="config:write",
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
|
|
@ -43,10 +50,13 @@ class TestVariableEndpoint:
|
|||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher)
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
"/api/variable", mock_auth, mock_dispatcher,
|
||||
capability="config:write",
|
||||
)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
|
|
|
|||
|
|
@ -1,355 +1,179 @@
|
|||
"""
|
||||
Tests for Gateway Service API
|
||||
Tests for gateway/service.py — the Api class that wires together
|
||||
the pub/sub backend, IAM auth, config receiver, dispatcher manager,
|
||||
and endpoint manager.
|
||||
|
||||
The legacy ``GATEWAY_SECRET`` / ``default_api_token`` / allow-all
|
||||
surface is gone, so the tests here focus on the Api's construction
|
||||
and composition rather than the removed auth behaviour. IamAuth's
|
||||
own behaviour is covered in test_auth.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from aiohttp import web
|
||||
import pulsar
|
||||
|
||||
from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token
|
||||
|
||||
# Tests for Gateway Service API
|
||||
from trustgraph.gateway.service import (
|
||||
Api,
|
||||
default_pulsar_host, default_prometheus_url,
|
||||
default_timeout, default_port,
|
||||
)
|
||||
from trustgraph.gateway.auth import IamAuth
|
||||
|
||||
|
||||
class TestApi:
|
||||
"""Test cases for Api class"""
|
||||
|
||||
# -- constants -------------------------------------------------------------
|
||||
|
||||
def test_api_initialization_with_defaults(self):
|
||||
"""Test Api initialization with default values"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_backend = Mock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
api = Api()
|
||||
class TestDefaults:
|
||||
|
||||
assert api.port == default_port
|
||||
assert api.timeout == default_timeout
|
||||
assert api.pulsar_host == default_pulsar_host
|
||||
assert api.pulsar_api_key is None
|
||||
assert api.prometheus_url == default_prometheus_url + "/"
|
||||
assert api.auth.allow_all is True
|
||||
def test_exports_default_constants(self):
|
||||
# These are consumed by CLIs / tests / docs. Sanity-check
|
||||
# that they're the expected shape.
|
||||
assert default_port == 8088
|
||||
assert default_timeout == 600
|
||||
assert default_pulsar_host.startswith("pulsar://")
|
||||
assert default_prometheus_url.startswith("http")
|
||||
|
||||
# Verify get_pubsub was called
|
||||
mock_get_pubsub.assert_called_once()
|
||||
|
||||
def test_api_initialization_with_custom_config(self):
|
||||
"""Test Api initialization with custom configuration"""
|
||||
# -- Api construction ------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api(mock_backend):
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
yield Api()
|
||||
|
||||
|
||||
class TestApiConstruction:
|
||||
|
||||
def test_defaults(self, api):
|
||||
assert api.port == default_port
|
||||
assert api.timeout == default_timeout
|
||||
assert api.pulsar_host == default_pulsar_host
|
||||
assert api.pulsar_api_key is None
|
||||
# prometheus_url gets normalised with a trailing slash
|
||||
assert api.prometheus_url == default_prometheus_url + "/"
|
||||
|
||||
def test_auth_is_iam_backed(self, api):
|
||||
# Any Api always gets an IamAuth. There is no "no auth" mode
|
||||
# (GATEWAY_SECRET / allow_all has been removed — see IAM spec).
|
||||
assert isinstance(api.auth, IamAuth)
|
||||
|
||||
def test_components_wired(self, api):
|
||||
assert api.config_receiver is not None
|
||||
assert api.dispatcher_manager is not None
|
||||
assert api.endpoint_manager is not None
|
||||
|
||||
def test_dispatcher_manager_has_auth(self, api):
|
||||
# The Mux uses this handle for first-frame socket auth.
|
||||
assert api.dispatcher_manager.auth is api.auth
|
||||
|
||||
def test_custom_config(self, mock_backend):
|
||||
config = {
|
||||
"port": 9000,
|
||||
"timeout": 300,
|
||||
"pulsar_host": "pulsar://custom-host:6650",
|
||||
"pulsar_api_key": "test-api-key",
|
||||
"pulsar_listener": "custom-listener",
|
||||
"pulsar_api_key": "custom-key",
|
||||
"prometheus_url": "http://custom-prometheus:9090",
|
||||
"api_token": "secret-token"
|
||||
}
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
a = Api(**config)
|
||||
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_backend = Mock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
assert a.port == 9000
|
||||
assert a.timeout == 300
|
||||
assert a.pulsar_host == "pulsar://custom-host:6650"
|
||||
assert a.pulsar_api_key == "custom-key"
|
||||
# Trailing slash added.
|
||||
assert a.prometheus_url == "http://custom-prometheus:9090/"
|
||||
|
||||
api = Api(**config)
|
||||
def test_prometheus_url_already_has_trailing_slash(self, mock_backend):
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
a = Api(prometheus_url="http://p:9090/")
|
||||
assert a.prometheus_url == "http://p:9090/"
|
||||
|
||||
assert api.port == 9000
|
||||
assert api.timeout == 300
|
||||
assert api.pulsar_host == "pulsar://custom-host:6650"
|
||||
assert api.pulsar_api_key == "test-api-key"
|
||||
assert api.prometheus_url == "http://custom-prometheus:9090/"
|
||||
assert api.auth.token == "secret-token"
|
||||
assert api.auth.allow_all is False
|
||||
def test_queue_overrides_parsed_for_config(self, mock_backend):
|
||||
with patch(
|
||||
"trustgraph.gateway.service.get_pubsub",
|
||||
return_value=mock_backend,
|
||||
):
|
||||
a = Api(
|
||||
config_request_queue="alt-config-req",
|
||||
config_response_queue="alt-config-resp",
|
||||
)
|
||||
overrides = a.dispatcher_manager.queue_overrides
|
||||
assert overrides.get("config", {}).get("request") == "alt-config-req"
|
||||
assert overrides.get("config", {}).get("response") == "alt-config-resp"
|
||||
|
||||
# Verify get_pubsub was called with config
|
||||
mock_get_pubsub.assert_called_once_with(**config)
|
||||
|
||||
def test_api_initialization_with_pulsar_api_key(self):
|
||||
"""Test Api initialization with Pulsar API key authentication"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
# -- app_factory -----------------------------------------------------------
|
||||
|
||||
api = Api(pulsar_api_key="test-key")
|
||||
|
||||
# Verify api key was stored
|
||||
assert api.pulsar_api_key == "test-key"
|
||||
mock_get_pubsub.assert_called_once()
|
||||
|
||||
def test_api_initialization_prometheus_url_normalization(self):
|
||||
"""Test that prometheus_url gets normalized with trailing slash"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
# Test URL without trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
# Test URL with trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090/")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
def test_api_initialization_empty_api_token_means_no_auth(self):
|
||||
"""Test that empty API token results in allow_all authentication"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api(api_token="")
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
def test_api_initialization_none_api_token_means_no_auth(self):
|
||||
"""Test that None API token results in allow_all authentication"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api(api_token=None)
|
||||
assert api.auth.allow_all is True
|
||||
class TestAppFactory:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_creates_application(self):
|
||||
"""Test that app_factory creates aiohttp application"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
assert isinstance(app, web.Application)
|
||||
assert app._client_max_size == 256 * 1024 * 1024
|
||||
|
||||
# Verify that config receiver was started
|
||||
api.config_receiver.start.assert_called_once()
|
||||
|
||||
# Verify that endpoint manager was configured
|
||||
api.endpoint_manager.add_routes.assert_called_once_with(app)
|
||||
api.endpoint_manager.start.assert_called_once()
|
||||
async def test_creates_aiohttp_app(self, api):
|
||||
# Stub out the long-tail dependencies that reach out to IAM /
|
||||
# pub/sub so we can exercise the factory in isolation.
|
||||
api.auth.start = AsyncMock()
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
api.endpoints = []
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
assert isinstance(app, web.Application)
|
||||
assert app._client_max_size == 256 * 1024 * 1024
|
||||
api.auth.start.assert_called_once()
|
||||
api.config_receiver.start.assert_called_once()
|
||||
api.endpoint_manager.add_routes.assert_called_once_with(app)
|
||||
api.endpoint_manager.start.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_with_custom_endpoints(self):
|
||||
"""Test app_factory with custom endpoints"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock custom endpoints
|
||||
mock_endpoint1 = Mock()
|
||||
mock_endpoint1.add_routes = Mock()
|
||||
mock_endpoint1.start = AsyncMock()
|
||||
|
||||
mock_endpoint2 = Mock()
|
||||
mock_endpoint2.add_routes = Mock()
|
||||
mock_endpoint2.start = AsyncMock()
|
||||
|
||||
api.endpoints = [mock_endpoint1, mock_endpoint2]
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
# Verify custom endpoints were configured
|
||||
mock_endpoint1.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint1.start.assert_called_once()
|
||||
mock_endpoint2.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint2.start.assert_called_once()
|
||||
async def test_auth_start_runs_before_accepting_traffic(self, api):
|
||||
"""``auth.start()`` fetches the IAM signing key, and must
|
||||
complete (or time out) before the gateway begins accepting
|
||||
requests. It's the first await in app_factory."""
|
||||
order = []
|
||||
|
||||
def test_run_method_calls_web_run_app(self):
|
||||
"""Test that run method calls web.run_app"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
|
||||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
# AsyncMock.side_effect expects a sync callable (its return
|
||||
# value becomes the coroutine's return); a plain list.append
|
||||
# avoids the "coroutine was never awaited" trap of an async
|
||||
# side_effect.
|
||||
api.auth.start = AsyncMock(
|
||||
side_effect=lambda: order.append("auth"),
|
||||
)
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock(
|
||||
side_effect=lambda: order.append("config"),
|
||||
)
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock(
|
||||
side_effect=lambda: order.append("endpoints"),
|
||||
)
|
||||
api.endpoints = []
|
||||
|
||||
# Api.run() passes self.app_factory() — a coroutine — to
|
||||
# web.run_app, which would normally consume it inside its own
|
||||
# event loop. Since we mock run_app, close the coroutine here
|
||||
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
|
||||
def _consume_coro(coro, **kwargs):
|
||||
coro.close()
|
||||
mock_run_app.side_effect = _consume_coro
|
||||
await api.app_factory()
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
# Verify run_app was called once with the correct port
|
||||
mock_run_app.assert_called_once()
|
||||
args, kwargs = mock_run_app.call_args
|
||||
assert len(args) == 1 # Should have one positional arg (the coroutine)
|
||||
assert kwargs == {'port': 8080} # Should have port keyword arg
|
||||
|
||||
def test_api_components_initialization(self):
|
||||
"""Test that all API components are properly initialized"""
|
||||
with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Verify all components are initialized
|
||||
assert api.config_receiver is not None
|
||||
assert api.dispatcher_manager is not None
|
||||
assert api.endpoint_manager is not None
|
||||
assert api.endpoints == []
|
||||
|
||||
# Verify component relationships
|
||||
assert api.dispatcher_manager.backend == api.pubsub_backend
|
||||
assert api.dispatcher_manager.config_receiver == api.config_receiver
|
||||
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
|
||||
# EndpointManager doesn't store auth directly, it passes it to individual endpoints
|
||||
|
||||
|
||||
class TestRunFunction:
|
||||
"""Test cases for the run() function"""
|
||||
|
||||
def test_run_function_with_metrics_enabled(self):
|
||||
"""Test run function with metrics enabled"""
|
||||
import warnings
|
||||
# Suppress the specific async warning with a broader pattern
|
||||
warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
|
||||
|
||||
with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
|
||||
patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
|
||||
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = True
|
||||
mock_args.metrics_port = 8000
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Create a mock Api class without importing the real one
|
||||
mock_api = Mock(return_value=mock_api_instance)
|
||||
|
||||
# Patch using context manager to avoid importing the real Api class
|
||||
with patch('trustgraph.gateway.service.Api', mock_api):
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': True,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was started
|
||||
mock_start_http_server.assert_called_once_with(8000)
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.service.start_http_server')
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
|
||||
"""Test run function with metrics disabled"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': False,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was NOT started
|
||||
mock_start_http_server.assert_not_called()
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_argument_parsing(self, mock_parse_args):
|
||||
"""Test that run function properly parses command line arguments"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Mock vars() to return a dict with all expected arguments
|
||||
expected_args = {
|
||||
'pulsar_host': 'pulsar://test:6650',
|
||||
'pulsar_api_key': 'test-key',
|
||||
'pulsar_listener': 'test-listener',
|
||||
'prometheus_url': 'http://test-prometheus:9090',
|
||||
'port': 9000,
|
||||
'timeout': 300,
|
||||
'api_token': 'secret',
|
||||
'log_level': 'INFO',
|
||||
'metrics': False,
|
||||
'metrics_port': 8001
|
||||
}
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = expected_args
|
||||
|
||||
run()
|
||||
|
||||
# Verify Api was created with the parsed arguments
|
||||
mock_api.assert_called_once_with(**expected_args)
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
def test_run_function_creates_argument_parser(self):
|
||||
"""Test that run function creates argument parser with correct arguments"""
|
||||
with patch('argparse.ArgumentParser') as mock_parser_class:
|
||||
mock_parser = Mock()
|
||||
mock_parser_class.return_value = mock_parser
|
||||
mock_parser.parse_args.return_value = Mock(metrics=False)
|
||||
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api, \
|
||||
patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {'metrics': False}
|
||||
mock_api.return_value = Mock()
|
||||
|
||||
run()
|
||||
|
||||
# Verify ArgumentParser was created
|
||||
mock_parser_class.assert_called_once()
|
||||
|
||||
# Verify add_argument was called for each expected argument
|
||||
expected_arguments = [
|
||||
'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
|
||||
'prometheus-url', 'port', 'timeout', 'api-token',
|
||||
'log-level', 'metrics', 'metrics-port'
|
||||
]
|
||||
|
||||
# Check that add_argument was called multiple times (once for each arg)
|
||||
assert mock_parser.add_argument.call_count >= len(expected_arguments)
|
||||
# auth.start must be first (before config receiver, before
|
||||
# any endpoint starts).
|
||||
assert order[0] == "auth"
|
||||
# All three must have run.
|
||||
assert set(order) == {"auth", "config", "endpoints"}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,15 @@
|
|||
"""Unit tests for SocketEndpoint graceful shutdown functionality."""
|
||||
"""Unit tests for SocketEndpoint graceful shutdown functionality.
|
||||
|
||||
These tests exercise SocketEndpoint in its handshake-auth
|
||||
configuration (``in_band_auth=False``) — the mode used in production
|
||||
for the flow import/export streaming endpoints. The mux socket at
|
||||
``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
|
||||
handshake always accepts and authentication runs on the first
|
||||
WebSocket frame; that path is covered by the Mux tests.
|
||||
|
||||
Every endpoint constructor here passes an explicit capability — no
|
||||
permissive default is relied upon.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
|
|
@ -6,13 +17,32 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from aiohttp import web, WSMsgType
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
from trustgraph.gateway.running import Running
|
||||
from trustgraph.gateway.auth import Identity
|
||||
|
||||
|
||||
# Representative capability used across these tests — corresponds to
|
||||
# the flow-import streaming endpoint pattern that uses this class.
|
||||
TEST_CAP = "graph:write"
|
||||
|
||||
|
||||
def _valid_identity():
|
||||
return Identity(
|
||||
handle="test-user",
|
||||
workspace="default",
|
||||
principal_id="test-user",
|
||||
source="api-key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
"""Mock authentication service."""
|
||||
"""Mock IAM-backed authenticator. Successful by default —
|
||||
``authenticate`` returns a valid identity and ``authorise``
|
||||
allows everything. Tests that need the failure paths override
|
||||
the relevant attribute locally."""
|
||||
auth = MagicMock()
|
||||
auth.permitted.return_value = True
|
||||
auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
auth.authorise = AsyncMock(return_value=None)
|
||||
return auth
|
||||
|
||||
|
||||
|
|
@ -25,7 +55,7 @@ def mock_dispatcher_factory():
|
|||
dispatcher.receive = AsyncMock()
|
||||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
|
||||
return dispatcher_factory
|
||||
|
||||
|
||||
|
|
@ -35,7 +65,8 @@ def socket_endpoint(mock_auth, mock_dispatcher_factory):
|
|||
return SocketEndpoint(
|
||||
endpoint_path="/test-socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher_factory
|
||||
dispatcher=mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -61,7 +92,10 @@ def mock_request():
|
|||
@pytest.mark.asyncio
|
||||
async def test_listener_graceful_shutdown_on_close():
|
||||
"""Test listener handles websocket close gracefully."""
|
||||
socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", MagicMock(), AsyncMock(),
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
# Mock websocket that closes after one message
|
||||
ws = AsyncMock()
|
||||
|
|
@ -99,9 +133,10 @@ async def test_listener_graceful_shutdown_on_close():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_normal_flow():
|
||||
"""Test normal websocket handling flow."""
|
||||
"""Valid bearer → handshake accepted, dispatcher created."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
mock_auth.authorise = AsyncMock(return_value=None)
|
||||
|
||||
dispatcher_created = False
|
||||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
|
|
@ -111,7 +146,10 @@ async def test_handle_normal_flow():
|
|||
dispatcher.destroy = AsyncMock()
|
||||
return dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
@ -155,7 +193,8 @@ async def test_handle_normal_flow():
|
|||
async def test_handle_exception_group_cleanup():
|
||||
"""Test exception group triggers dispatcher cleanup."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
mock_auth.authorise = AsyncMock(return_value=None)
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
|
@ -163,7 +202,10 @@ async def test_handle_exception_group_cleanup():
|
|||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
@ -222,7 +264,8 @@ async def test_handle_exception_group_cleanup():
|
|||
async def test_handle_dispatcher_cleanup_timeout():
|
||||
"""Test dispatcher cleanup with timeout."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
mock_auth.authorise = AsyncMock(return_value=None)
|
||||
|
||||
# Mock dispatcher that takes long to destroy
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
|
@ -231,7 +274,10 @@ async def test_handle_dispatcher_cleanup_timeout():
|
|||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
@ -285,49 +331,68 @@ async def test_handle_dispatcher_cleanup_timeout():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unauthorized_request():
|
||||
"""Test handling of unauthorized requests."""
|
||||
"""A bearer that the IAM layer rejects causes the handshake to
|
||||
fail with 401. IamAuth surfaces an HTTPUnauthorized; the
|
||||
endpoint propagates it. Note that the endpoint intentionally
|
||||
does NOT distinguish 'bad token', 'expired', 'revoked', etc. —
|
||||
that's the IAM error-masking policy."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = False # Unauthorized
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||
|
||||
mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
|
||||
text='{"error":"auth failure"}',
|
||||
content_type="application/json",
|
||||
))
|
||||
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, AsyncMock(),
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "invalid-token"}
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should return HTTP 401
|
||||
|
||||
assert isinstance(result, web.HTTPUnauthorized)
|
||||
|
||||
# Should have checked permission
|
||||
mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
|
||||
# authenticate must have been invoked with a synthetic request
|
||||
# carrying Bearer <the-token>. The endpoint wraps the query-
|
||||
# string token into an Authorization header for a uniform auth
|
||||
# path — the IAM layer does not look at query strings directly.
|
||||
mock_auth.authenticate.assert_called_once()
|
||||
passed_req = mock_auth.authenticate.call_args.args[0]
|
||||
assert passed_req.headers["Authorization"] == "Bearer invalid-token"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_missing_token():
|
||||
"""Test handling of requests with missing token."""
|
||||
"""Request with no ``token`` query param → 401 before any
|
||||
IAM call is made (cheap short-circuit)."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = False
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
|
||||
|
||||
mock_auth.authenticate = AsyncMock(
|
||||
side_effect=AssertionError(
|
||||
"authenticate must not be invoked when no token is present"
|
||||
),
|
||||
)
|
||||
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, AsyncMock(),
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {} # No token
|
||||
|
||||
|
||||
result = await socket_endpoint.handle(request)
|
||||
|
||||
# Should return HTTP 401
|
||||
|
||||
assert isinstance(result, web.HTTPUnauthorized)
|
||||
|
||||
# Should have checked permission with empty token
|
||||
mock_auth.permitted.assert_called_once_with("", "socket")
|
||||
mock_auth.authenticate.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_websocket_already_closed():
|
||||
"""Test handling when websocket is already closed."""
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
mock_auth.authenticate = AsyncMock(return_value=_valid_identity())
|
||||
mock_auth.authorise = AsyncMock(return_value=None)
|
||||
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.destroy = AsyncMock()
|
||||
|
|
@ -335,7 +400,10 @@ async def test_handle_websocket_already_closed():
|
|||
async def mock_dispatcher_factory(ws, running, match_info):
|
||||
return mock_dispatcher
|
||||
|
||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||
socket_endpoint = SocketEndpoint(
|
||||
"/test", mock_auth, mock_dispatcher_factory,
|
||||
capability=TEST_CAP,
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"token": "valid-token"}
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ from trustgraph.base import LlmResult
|
|||
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Ollama processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the parent class initialization
|
||||
|
|
@ -44,13 +44,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert hasattr(processor, 'llm')
|
||||
mock_client_class.assert_called_once_with(host='http://localhost:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Generated response from Ollama',
|
||||
'prompt_eval_count': 15,
|
||||
|
|
@ -83,13 +83,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.model == 'llama2'
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.generate.side_effect = Exception("Connection error")
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
|
|
@ -110,13 +110,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
|
|
@ -137,13 +137,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'mistral'
|
||||
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
|
|
@ -164,13 +164,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Default response',
|
||||
'prompt_eval_count': 2,
|
||||
|
|
@ -205,13 +205,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test token counting from Ollama response"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Test response',
|
||||
'prompt_eval_count': 50,
|
||||
|
|
@ -243,13 +243,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.out_token == 25
|
||||
assert result.model == 'llama2'
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test that Ollama client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
|
|
@ -273,13 +273,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Verify processor has the client
|
||||
assert processor.llm == mock_client
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Response with system instructions',
|
||||
'prompt_eval_count': 25,
|
||||
|
|
@ -312,13 +312,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom temperature',
|
||||
'prompt_eval_count': 20,
|
||||
|
|
@ -360,13 +360,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
options={'temperature': 0.8} # Should use runtime override
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom model',
|
||||
'prompt_eval_count': 18,
|
||||
|
|
@ -408,13 +408,13 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
options={'temperature': 0.1} # Should use processor default
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.AsyncClient')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
'response': 'Response with both overrides',
|
||||
'prompt_eval_count': 22,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue