mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 17:06:22 +02:00
Fix Metadata/EntityEmbeddings schema migration tail and add regression tests (#777)
The Metadata dataclass dropped its `metadata: list[Triple]` field
and EntityEmbeddings/ChunkEmbeddings settled on a singular
`vector: list[float]` field, but several call sites kept passing
`Metadata(metadata=...)` and `EntityEmbeddings(vectors=...)`. The
bugs were latent until a websocket client first hit
`/api/v1/flow/default/import/entity-contexts`, at which point the
dispatcher TypeError'd on construction.
Production fixes (5 call sites on the same migration tail):
* trustgraph-flow gateway dispatchers entity_contexts_import.py
and graph_embeddings_import.py — drop the stale
Metadata(metadata=...) kwarg; switch graph_embeddings_import
to the singular `vector` wire key.
* trustgraph-base messaging translators knowledge.py and
document_loading.py — fix decode side to read the singular
`"vector"` key, matching what their own encode sides have
always written.
* trustgraph-flow tables/knowledge.py — fix Cassandra row
deserialiser to construct EntityEmbeddings(vector=...)
instead of vectors=.
* trustgraph-flow gateway core_import/core_export — switch the
kg-core msgpack wire format to the singular `"v"`/`"vector"`
key and drop the dead `m["m"]` envelope field that referenced
the removed Metadata.metadata triples list (it was a
guaranteed KeyError on the export side).
Defense-in-depth regression coverage (32 new tests across 7 files):
* tests/contract/test_schema_field_contracts.py — pin the field
set of Metadata, EntityEmbeddings, ChunkEmbeddings,
EntityContext so any future schema rename fails CI loudly
with a clear diff.
* tests/unit/test_translators/test_knowledge_translator_roundtrip.py
and test_document_embeddings_translator_roundtrip.py -
encode→decode round-trip the affected translators end to end,
locking in the singular `"vector"` wire key.
* tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py
and test_graph_embeddings_import_dispatcher.py — exercise the
websocket dispatchers' receive() path with realistic
payloads, the direct regression test for the original
production crash.
* tests/unit/test_gateway/test_core_import_export_roundtrip.py
— pack/unpack the kg-core msgpack format through the real
dispatcher classes (with KnowledgeRequestor mocked),
including a full export→import round-trip.
* tests/unit/test_tables/test_knowledge_table_store.py —
exercise the Cassandra row → schema conversion via __new__ to
bypass the live cluster connection.
Also fixes an unrelated leaked-coroutine RuntimeWarning in
test_gateway/test_service.py::test_run_method_calls_web_run_app: the
mocked aiohttp.web.run_app now closes the coroutine that Api.run() hands
it, mirroring what the real run_app would do, instead of leaving it for
the GC to complain about.
This commit is contained in:
parent
0994d4b05f
commit
c23e28aa66
17 changed files with 1415 additions and 17 deletions
418
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
418
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
"""
|
||||
Round-trip unit tests for the core msgpack import/export gateway endpoints.
|
||||
|
||||
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
|
||||
the responder callback and packs them into msgpack tuples. The kg-core
|
||||
import endpoint takes msgpack tuples back off the wire and rebuilds
|
||||
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
|
||||
(whose translator decodes them into real dataclasses).
|
||||
|
||||
Regression coverage: the previous wire format used `"vectors"` (plural)
|
||||
in the entity blobs and embedded a stale `"m"` field that referenced the
|
||||
removed `Metadata.metadata` triples-list field. The export side hit a
|
||||
KeyError on first message; the import side built dicts that the
|
||||
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
|
||||
tests pin both halves of the wire protocol.
|
||||
"""
|
||||
|
||||
import msgpack
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.core_export import CoreExport
|
||||
from trustgraph.gateway.dispatch.core_import import CoreImport
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
|
||||
# would emit). The vector wire key is *singular* on purpose; the export
|
||||
# side previously read the wrong key and crashed.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ge_response_dict():
|
||||
return {
|
||||
"graph-embeddings": {
|
||||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"t": "i", "i": "http://example.org/alice"},
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
},
|
||||
{
|
||||
"entity": {"t": "i", "i": "http://example.org/bob"},
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _triples_response_dict():
|
||||
return {
|
||||
"triples": {
|
||||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"triples": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
"p": {"t": "i", "i": "http://example.org/knows"},
|
||||
"o": {"t": "i", "i": "http://example.org/bob"},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _make_request(id_="doc-1", user="alice"):
|
||||
request = Mock()
|
||||
request.query = {"id": id_, "user": user}
|
||||
return request
|
||||
|
||||
|
||||
def _make_data_reader(payload: bytes):
|
||||
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
|
||||
chunks = [payload, b""]
|
||||
|
||||
data = Mock()
|
||||
|
||||
async def fake_read(n):
|
||||
return chunks.pop(0) if chunks else b""
|
||||
|
||||
data.read = fake_read
|
||||
return data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export side: translator-shaped dict -> msgpack bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreExportWireFormat:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_export_packs_graph_embeddings_with_singular_vector(
|
||||
self, mock_kr_class,
|
||||
):
|
||||
"""The export side must read `ent["vector"]` and emit `v`. The
|
||||
previous bug was reading `ent["vectors"]` which KeyErrored against
|
||||
the translator output."""
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict, responder):
|
||||
await responder(_ge_response_dict(), True)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
captured.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
ok = AsyncMock(return_value=response)
|
||||
error = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
# Did not raise, did not call error()
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
unpacker.feed(captured[0])
|
||||
items = list(unpacker)
|
||||
|
||||
assert len(items) == 1
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "ge"
|
||||
|
||||
# Metadata envelope: only id/user/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
|
||||
# Entities: each carries the *singular* `v` and the term envelope
|
||||
assert len(payload["e"]) == 2
|
||||
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
|
||||
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
|
||||
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_export_packs_triples(self, mock_kr_class):
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict, responder):
|
||||
await responder(_triples_response_dict(), True)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
captured.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
ok = AsyncMock(return_value=response)
|
||||
error = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(), error=error, ok=ok, request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
unpacker.feed(captured[0])
|
||||
items = list(unpacker)
|
||||
assert len(items) == 1
|
||||
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "t"
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
assert len(payload["t"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import side: msgpack bytes -> translator-shaped dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportWireFormat:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
async def test_import_unpacks_graph_embeddings_to_singular_vector(
|
||||
self, mock_kr_class,
|
||||
):
|
||||
"""The import side must build dicts whose entity blobs have the
|
||||
singular `vector` key — that's what the KnowledgeRequestTranslator
|
||||
decode side reads. Previous bug emitted `vectors`."""
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict):
|
||||
captured.append(req_dict)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
# Build a msgpack tuple matching the new wire format
|
||||
payload = msgpack.packb((
|
||||
"ge",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"e": [
|
||||
{
|
||||
"e": {"t": "i", "i": "http://example.org/alice"},
|
||||
"v": [0.1, 0.2, 0.3],
|
||||
},
|
||||
],
|
||||
},
|
||||
))
|
||||
|
||||
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
|
||||
error = AsyncMock()
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(payload),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
req = captured[0]
|
||||
assert req["operation"] == "put-kg-core"
|
||||
assert req["user"] == "alice"
|
||||
assert req["id"] == "doc-1"
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# Metadata envelope must NOT contain a stale `metadata` key
|
||||
# referencing the removed Metadata.metadata field.
|
||||
assert "metadata" not in ge["metadata"]
|
||||
assert ge["metadata"] == {
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "default",
|
||||
}
|
||||
|
||||
# Entity blob carries the singular `vector` key
|
||||
assert len(ge["entities"]) == 1
|
||||
ent = ge["entities"][0]
|
||||
assert ent["vector"] == [0.1, 0.2, 0.3]
|
||||
assert "vectors" not in ent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
async def test_import_unpacks_triples(self, mock_kr_class):
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict):
|
||||
captured.append(req_dict)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
payload = msgpack.packb((
|
||||
"t",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"t": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
"p": {"t": "i", "i": "http://example.org/knows"},
|
||||
"o": {"t": "i", "i": "http://example.org/bob"},
|
||||
},
|
||||
],
|
||||
},
|
||||
))
|
||||
|
||||
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
|
||||
error = AsyncMock()
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(payload),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
req = captured[0]
|
||||
triples = req["triples"]
|
||||
assert "metadata" not in triples["metadata"] # no stale field
|
||||
assert len(triples["triples"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full round-trip: export bytes feed directly into import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportExportRoundTrip:
|
||||
"""End-to-end: produce bytes via core_export, consume them via
|
||||
core_import, and verify the dict that lands at the import-side
|
||||
translator is structurally equivalent to what went in. This is the
|
||||
test that catches asymmetries between the two halves."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_graph_embeddings_round_trip(
|
||||
self, mock_export_kr_class, mock_import_kr_class,
|
||||
):
|
||||
# ----- export side: capture bytes -----
|
||||
export_bytes = []
|
||||
|
||||
async def fake_export_process(req_dict, responder):
|
||||
await responder(_ge_response_dict(), True)
|
||||
|
||||
export_kr = AsyncMock()
|
||||
export_kr.start = AsyncMock()
|
||||
export_kr.stop = AsyncMock()
|
||||
export_kr.process = fake_export_process
|
||||
mock_export_kr_class.return_value = export_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
export_bytes.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(),
|
||||
error=AsyncMock(),
|
||||
ok=AsyncMock(return_value=response),
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
assert len(export_bytes) == 1
|
||||
|
||||
# ----- import side: feed those bytes back in -----
|
||||
import_captured = []
|
||||
|
||||
async def fake_import_process(req_dict):
|
||||
import_captured.append(req_dict)
|
||||
|
||||
import_kr = AsyncMock()
|
||||
import_kr.start = AsyncMock()
|
||||
import_kr.stop = AsyncMock()
|
||||
import_kr.process = fake_import_process
|
||||
mock_import_kr_class.return_value = import_kr
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(export_bytes[0]),
|
||||
error=AsyncMock(),
|
||||
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
# ----- verify the dict the importer would hand to the translator -----
|
||||
assert len(import_captured) == 1
|
||||
req = import_captured[0]
|
||||
|
||||
original = _ge_response_dict()["graph-embeddings"]
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# The import side overrides id/user from the URL query (intentional),
|
||||
# so we only round-trip the entity payload itself.
|
||||
assert ge["metadata"]["id"] == original["metadata"]["id"]
|
||||
assert ge["metadata"]["user"] == original["metadata"]["user"]
|
||||
|
||||
assert len(ge["entities"]) == len(original["entities"])
|
||||
for got, want in zip(ge["entities"], original["entities"]):
|
||||
assert got["vector"] == want["vector"]
|
||||
assert got["entity"] == want["entity"]
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
Unit tests for entity contexts import dispatcher.
|
||||
|
||||
Tests the business logic of EntityContextsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version constructed Metadata(metadata=...)
|
||||
which raised TypeError at runtime as soon as a message was received. These
|
||||
tests exercise receive() end-to-end so any future schema/kwarg drift in
|
||||
the Metadata or EntityContexts construction is caught immediately.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
|
||||
from trustgraph.schema import EntityContexts, EntityContext, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
"""Sample entity-contexts websocket message."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-123",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"v": "http://example.org/alice", "e": True},
|
||||
"context": "Alice is a person.",
|
||||
},
|
||||
{
|
||||
"entity": {"v": "http://example.org/bob", "e": True},
|
||||
"context": "Bob is a person.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_entities_message():
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-empty",
|
||||
"user": "u",
|
||||
"collection": "c",
|
||||
},
|
||||
"entities": [],
|
||||
}
|
||||
|
||||
|
||||
class TestEntityContextsImportInitialization:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="ec-queue",
|
||||
)
|
||||
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_backend,
|
||||
topic="ec-queue",
|
||||
schema=EntityContexts,
|
||||
)
|
||||
assert dispatcher.ws is mock_websocket
|
||||
assert dispatcher.running is mock_running
|
||||
assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestEntityContextsImportLifecycle:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.start()
|
||||
instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(
|
||||
self, mock_publisher_class, mock_backend, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=None, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestEntityContextsImportMessageProcessing:
|
||||
"""Regression coverage for receive(): catches Metadata/schema drift."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_constructs_entity_contexts_correctly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
# If Metadata or EntityContexts gain/lose kwargs, this raises
|
||||
# TypeError — exactly the regression we want to catch.
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
call_args = instance.send.call_args
|
||||
assert call_args[0][0] is None
|
||||
|
||||
sent = call_args[0][1]
|
||||
assert isinstance(sent, EntityContexts)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
assert all(isinstance(e, EntityContext) for e in sent.entities)
|
||||
assert sent.entities[0].context == "Alice is a person."
|
||||
assert sent.entities[1].context == "Bob is a person."
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_entities(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, empty_entities_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_entities_message
|
||||
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
sent = instance.send.call_args[0][1]
|
||||
assert isinstance(sent, EntityContexts)
|
||||
assert sent.entities == []
|
||||
assert sent.metadata.id == "doc-empty"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
with pytest.raises(RuntimeError, match="publish failed"):
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Unit tests for graph embeddings import dispatcher.
|
||||
|
||||
Tests the business logic of GraphEmbeddingsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version of EntityContextsImport
|
||||
constructed Metadata(metadata=...) which raised TypeError at runtime as
|
||||
soon as a message was received. The same shape of bug can occur here, so
|
||||
these tests exercise receive() end-to-end to catch any future schema or
|
||||
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
|
||||
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
"""Sample graph-embeddings websocket message."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-123",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"v": "http://example.org/alice", "e": True},
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
},
|
||||
{
|
||||
"entity": {"v": "http://example.org/bob", "e": True},
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_entities_message():
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-empty",
|
||||
"user": "u",
|
||||
"collection": "c",
|
||||
},
|
||||
"entities": [],
|
||||
}
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportInitialization:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="ge-queue",
|
||||
)
|
||||
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_backend,
|
||||
topic="ge-queue",
|
||||
schema=GraphEmbeddings,
|
||||
)
|
||||
assert dispatcher.ws is mock_websocket
|
||||
assert dispatcher.running is mock_running
|
||||
assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportLifecycle:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.start()
|
||||
instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(
|
||||
self, mock_publisher_class, mock_backend, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=None, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportMessageProcessing:
|
||||
"""Regression coverage for receive(): catches Metadata/schema drift."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_constructs_graph_embeddings_correctly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
|
||||
# kwargs, this raises TypeError — exactly the regression we want
|
||||
# to catch.
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
call_args = instance.send.call_args
|
||||
assert call_args[0][0] is None
|
||||
|
||||
sent = call_args[0][1]
|
||||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
|
||||
# Lock in the wire format: incoming "vector" key (singular,
|
||||
# list[float]) maps to EntityEmbeddings.vector. This mirrors
|
||||
# serialize_graph_embeddings() on the export side.
|
||||
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
|
||||
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_entities(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, empty_entities_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_entities_message
|
||||
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
sent = instance.send.call_args[0][1]
|
||||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert sent.entities == []
|
||||
assert sent.metadata.id == "doc-empty"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
with pytest.raises(RuntimeError, match="publish failed"):
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
|
@ -171,6 +171,14 @@ class TestApi:
|
|||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
# Api.run() passes self.app_factory() — a coroutine — to
|
||||
# web.run_app, which would normally consume it inside its own
|
||||
# event loop. Since we mock run_app, close the coroutine here
|
||||
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
|
||||
def _consume_coro(coro, **kwargs):
|
||||
coro.close()
|
||||
mock_run_app.side_effect = _consume_coro
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue