diff --git a/tests/contract/test_schema_field_contracts.py b/tests/contract/test_schema_field_contracts.py new file mode 100644 index 00000000..4b7c3da5 --- /dev/null +++ b/tests/contract/test_schema_field_contracts.py @@ -0,0 +1,73 @@ +""" +Contract tests for schema dataclass field sets. + +These pin the *field names* of small, widely-constructed schema dataclasses +so that any rename, removal, or accidental addition fails CI loudly instead +of waiting for a runtime TypeError on the next websocket message. + +Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]` +field but several call sites kept passing `Metadata(metadata=...)`. The bug +was only discovered when a websocket import dispatcher received its first +real message in production. A trivial structural assertion of the kind +below would have caught it at unit-test time. + +Add to this file whenever a schema rename burns you. The cost of a frozen +field set is a one-line update when you intentionally evolve the schema; the +benefit is that every call site is forced to come along for the ride. +""" + +import dataclasses +import pytest + +from trustgraph.schema import ( + Metadata, + EntityContext, + EntityEmbeddings, + ChunkEmbeddings, +) + + +def _field_names(dc): + return {f.name for f in dataclasses.fields(dc)} + + +@pytest.mark.contract +class TestSchemaFieldContracts: + """Pin the field set of dataclasses that get constructed all over the + codebase. If you intentionally change one of these, update the + expected set in the same commit — that diff will surface every call + site that needs to come along.""" + + def test_metadata_fields(self): + # NOTE: there is no `metadata` field. A previous regression + # constructed Metadata(metadata=...) and crashed at runtime. + assert _field_names(Metadata) == { + "id", + "root", + "user", + "collection", + } + + def test_entity_embeddings_fields(self): + # NOTE: the embedding field is `vector` (singular, list[float]). + # There is no `vectors` field. Several call sites historically + # passed `vectors=` and crashed at runtime. + assert _field_names(EntityEmbeddings) == { + "entity", + "vector", + "chunk_id", + } + + def test_chunk_embeddings_fields(self): + # Same `vector` (singular) convention as EntityEmbeddings. + assert _field_names(ChunkEmbeddings) == { + "chunk_id", + "vector", + } + + def test_entity_context_fields(self): + assert _field_names(EntityContext) == { + "entity", + "context", + "chunk_id", + } diff --git a/tests/unit/test_gateway/test_core_import_export_roundtrip.py b/tests/unit/test_gateway/test_core_import_export_roundtrip.py new file mode 100644 index 00000000..843a2b7b --- /dev/null +++ b/tests/unit/test_gateway/test_core_import_export_roundtrip.py @@ -0,0 +1,418 @@ +""" +Round-trip unit tests for the core msgpack import/export gateway endpoints. + +The kg-core export endpoint receives KnowledgeResponse-shaped dicts from +the responder callback and packs them into msgpack tuples. The kg-core +import endpoint takes msgpack tuples back off the wire and rebuilds +KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor +(whose translator decodes them into real dataclasses). + +Regression coverage: the previous wire format used `"vectors"` (plural) +in the entity blobs and embedded a stale `"m"` field that referenced the +removed `Metadata.metadata` triples-list field. The export side hit a +KeyError on first message; the import side built dicts that the +KnowledgeRequestTranslator (separately fixed) couldn't decode. These +tests pin both halves of the wire protocol. +""" + +import msgpack +import pytest +from unittest.mock import AsyncMock, Mock, patch + +from trustgraph.gateway.dispatch.core_export import CoreExport +from trustgraph.gateway.dispatch.core_import import CoreImport + + +# --------------------------------------------------------------------------- +# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator +# would emit). The vector wire key is *singular* on purpose; the export +# side previously read the wrong key and crashed. +# --------------------------------------------------------------------------- + + +def _ge_response_dict(): + return { + "graph-embeddings": { + "metadata": { + "id": "doc-1", + "root": "", + "user": "alice", + "collection": "testcoll", + }, + "entities": [ + { + "entity": {"t": "i", "i": "http://example.org/alice"}, + "vector": [0.1, 0.2, 0.3], + }, + { + "entity": {"t": "i", "i": "http://example.org/bob"}, + "vector": [0.4, 0.5, 0.6], + }, + ], + } + } + + +def _triples_response_dict(): + return { + "triples": { + "metadata": { + "id": "doc-1", + "root": "", + "user": "alice", + "collection": "testcoll", + }, + "triples": [ + { + "s": {"t": "i", "i": "http://example.org/alice"}, + "p": {"t": "i", "i": "http://example.org/knows"}, + "o": {"t": "i", "i": "http://example.org/bob"}, + }, + ], + } + } + + +def _make_request(id_="doc-1", user="alice"): + request = Mock() + request.query = {"id": id_, "user": user} + return request + + +def _make_data_reader(payload: bytes): + """Mock the aiohttp StreamReader: returns payload once, then EOF.""" + chunks = [payload, b""] + + data = Mock() + + async def fake_read(n): + return chunks.pop(0) if chunks else b"" + + data.read = fake_read + return data + + +# --------------------------------------------------------------------------- +# Export side: translator-shaped dict -> msgpack bytes +# --------------------------------------------------------------------------- + + +class TestCoreExportWireFormat: + + @pytest.mark.asyncio + @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor") + async def test_export_packs_graph_embeddings_with_singular_vector( + self, mock_kr_class, + ): + """The export side must read `ent["vector"]` and emit `v`. The + previous bug was reading `ent["vectors"]` which KeyErrored against + the translator output.""" + captured = [] + + async def fake_kr_process(req_dict, responder): + await responder(_ge_response_dict(), True) + + mock_kr = AsyncMock() + mock_kr.start = AsyncMock() + mock_kr.stop = AsyncMock() + mock_kr.process = fake_kr_process + mock_kr_class.return_value = mock_kr + + response = AsyncMock() + + async def fake_write(b): + captured.append(b) + + response.write = fake_write + response.write_eof = AsyncMock() + + ok = AsyncMock(return_value=response) + error = AsyncMock() + + exporter = CoreExport(backend=Mock()) + await exporter.process( + data=Mock(), + error=error, + ok=ok, + request=_make_request(), + ) + + # Did not raise, did not call error() + error.assert_not_called() + assert len(captured) == 1 + + unpacker = msgpack.Unpacker() + unpacker.feed(captured[0]) + items = list(unpacker) + + assert len(items) == 1 + msg_type, payload = items[0] + assert msg_type == "ge" + + # Metadata envelope: only id/user/collection — no stale `m["m"]`. + assert payload["m"] == { + "i": "doc-1", + "u": "alice", + "c": "testcoll", + } + + # Entities: each carries the *singular* `v` and the term envelope + assert len(payload["e"]) == 2 + assert payload["e"][0]["v"] == [0.1, 0.2, 0.3] + assert payload["e"][1]["v"] == [0.4, 0.5, 0.6] + assert payload["e"][0]["e"]["i"] == "http://example.org/alice" + + @pytest.mark.asyncio + @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor") + async def test_export_packs_triples(self, mock_kr_class): + captured = [] + + async def fake_kr_process(req_dict, responder): + await responder(_triples_response_dict(), True) + + mock_kr = AsyncMock() + mock_kr.start = AsyncMock() + mock_kr.stop = AsyncMock() + mock_kr.process = fake_kr_process + mock_kr_class.return_value = mock_kr + + response = AsyncMock() + + async def fake_write(b): + captured.append(b) + + response.write = fake_write + response.write_eof = AsyncMock() + + ok = AsyncMock(return_value=response) + error = AsyncMock() + + exporter = CoreExport(backend=Mock()) + await exporter.process( + data=Mock(), error=error, ok=ok, request=_make_request(), + ) + + error.assert_not_called() + assert len(captured) == 1 + + unpacker = msgpack.Unpacker() + unpacker.feed(captured[0]) + items = list(unpacker) + assert len(items) == 1 + + msg_type, payload = items[0] + assert msg_type == "t" + assert payload["m"] == { + "i": "doc-1", + "u": "alice", + "c": "testcoll", + } + assert len(payload["t"]) == 1 + + +# --------------------------------------------------------------------------- +# Import side: msgpack bytes -> translator-shaped dict +# --------------------------------------------------------------------------- + + +class TestCoreImportWireFormat: + + @pytest.mark.asyncio + @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor") + async def test_import_unpacks_graph_embeddings_to_singular_vector( + self, mock_kr_class, + ): + """The import side must build dicts whose entity blobs have the + singular `vector` key — that's what the KnowledgeRequestTranslator + decode side reads. Previous bug emitted `vectors`.""" + captured = [] + + async def fake_kr_process(req_dict): + captured.append(req_dict) + + mock_kr = AsyncMock() + mock_kr.start = AsyncMock() + mock_kr.stop = AsyncMock() + mock_kr.process = fake_kr_process + mock_kr_class.return_value = mock_kr + + # Build a msgpack tuple matching the new wire format + payload = msgpack.packb(( + "ge", + { + "m": {"i": "doc-1", "u": "alice", "c": "testcoll"}, + "e": [ + { + "e": {"t": "i", "i": "http://example.org/alice"}, + "v": [0.1, 0.2, 0.3], + }, + ], + }, + )) + + ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())) + error = AsyncMock() + + importer = CoreImport(backend=Mock()) + await importer.process( + data=_make_data_reader(payload), + error=error, + ok=ok, + request=_make_request(), + ) + + error.assert_not_called() + assert len(captured) == 1 + + req = captured[0] + assert req["operation"] == "put-kg-core" + assert req["user"] == "alice" + assert req["id"] == "doc-1" + + ge = req["graph-embeddings"] + # Metadata envelope must NOT contain a stale `metadata` key + # referencing the removed Metadata.metadata field. + assert "metadata" not in ge["metadata"] + assert ge["metadata"] == { + "id": "doc-1", + "user": "alice", + "collection": "default", + } + + # Entity blob carries the singular `vector` key + assert len(ge["entities"]) == 1 + ent = ge["entities"][0] + assert ent["vector"] == [0.1, 0.2, 0.3] + assert "vectors" not in ent + + @pytest.mark.asyncio + @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor") + async def test_import_unpacks_triples(self, mock_kr_class): + captured = [] + + async def fake_kr_process(req_dict): + captured.append(req_dict) + + mock_kr = AsyncMock() + mock_kr.start = AsyncMock() + mock_kr.stop = AsyncMock() + mock_kr.process = fake_kr_process + mock_kr_class.return_value = mock_kr + + payload = msgpack.packb(( + "t", + { + "m": {"i": "doc-1", "u": "alice", "c": "testcoll"}, + "t": [ + { + "s": {"t": "i", "i": "http://example.org/alice"}, + "p": {"t": "i", "i": "http://example.org/knows"}, + "o": {"t": "i", "i": "http://example.org/bob"}, + }, + ], + }, + )) + + ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())) + error = AsyncMock() + + importer = CoreImport(backend=Mock()) + await importer.process( + data=_make_data_reader(payload), + error=error, + ok=ok, + request=_make_request(), + ) + + error.assert_not_called() + assert len(captured) == 1 + + req = captured[0] + triples = req["triples"] + assert "metadata" not in triples["metadata"] # no stale field + assert len(triples["triples"]) == 1 + + +# --------------------------------------------------------------------------- +# Full round-trip: export bytes feed directly into import +# --------------------------------------------------------------------------- + + +class TestCoreImportExportRoundTrip: + """End-to-end: produce bytes via core_export, consume them via + core_import, and verify the dict that lands at the import-side + translator is structurally equivalent to what went in. This is the + test that catches asymmetries between the two halves.""" + + @pytest.mark.asyncio + @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor") + @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor") + async def test_graph_embeddings_round_trip( + self, mock_export_kr_class, mock_import_kr_class, + ): + # ----- export side: capture bytes ----- + export_bytes = [] + + async def fake_export_process(req_dict, responder): + await responder(_ge_response_dict(), True) + + export_kr = AsyncMock() + export_kr.start = AsyncMock() + export_kr.stop = AsyncMock() + export_kr.process = fake_export_process + mock_export_kr_class.return_value = export_kr + + response = AsyncMock() + + async def fake_write(b): + export_bytes.append(b) + + response.write = fake_write + response.write_eof = AsyncMock() + + exporter = CoreExport(backend=Mock()) + await exporter.process( + data=Mock(), + error=AsyncMock(), + ok=AsyncMock(return_value=response), + request=_make_request(), + ) + + assert len(export_bytes) == 1 + + # ----- import side: feed those bytes back in ----- + import_captured = [] + + async def fake_import_process(req_dict): + import_captured.append(req_dict) + + import_kr = AsyncMock() + import_kr.start = AsyncMock() + import_kr.stop = AsyncMock() + import_kr.process = fake_import_process + mock_import_kr_class.return_value = import_kr + + importer = CoreImport(backend=Mock()) + await importer.process( + data=_make_data_reader(export_bytes[0]), + error=AsyncMock(), + ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())), + request=_make_request(), + ) + + # ----- verify the dict the importer would hand to the translator ----- + assert len(import_captured) == 1 + req = import_captured[0] + + original = _ge_response_dict()["graph-embeddings"] + + ge = req["graph-embeddings"] + # The import side overrides id/user from the URL query (intentional), + # so we only round-trip the entity payload itself. + assert ge["metadata"]["id"] == original["metadata"]["id"] + assert ge["metadata"]["user"] == original["metadata"]["user"] + + assert len(ge["entities"]) == len(original["entities"]) + for got, want in zip(ge["entities"], original["entities"]): + assert got["vector"] == want["vector"] + assert got["entity"] == want["entity"] diff --git a/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py new file mode 100644 index 00000000..8eddeba9 --- /dev/null +++ b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py @@ -0,0 +1,242 @@ +""" +Unit tests for entity contexts import dispatcher. + +Tests the business logic of EntityContextsImport while mocking the +Publisher and websocket components. + +Regression coverage: a previous version constructed Metadata(metadata=...) +which raised TypeError at runtime as soon as a message was received. These +tests exercise receive() end-to-end so any future schema/kwarg drift in +the Metadata or EntityContexts construction is caught immediately. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport +from trustgraph.schema import EntityContexts, EntityContext, Metadata + + +@pytest.fixture +def mock_backend(): + return Mock() + + +@pytest.fixture +def mock_running(): + running = Mock() + running.get.return_value = True + running.stop = Mock() + return running + + +@pytest.fixture +def mock_websocket(): + ws = Mock() + ws.close = AsyncMock() + return ws + + +@pytest.fixture +def sample_message(): + """Sample entity-contexts websocket message.""" + return { + "metadata": { + "id": "doc-123", + "user": "testuser", + "collection": "testcollection", + }, + "entities": [ + { + "entity": {"v": "http://example.org/alice", "e": True}, + "context": "Alice is a person.", + }, + { + "entity": {"v": "http://example.org/bob", "e": True}, + "context": "Bob is a person.", + }, + ], + } + + +@pytest.fixture +def empty_entities_message(): + return { + "metadata": { + "id": "doc-empty", + "user": "u", + "collection": "c", + }, + "entities": [], + } + + +class TestEntityContextsImportInitialization: + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + def test_init_creates_publisher_with_correct_params( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, + running=mock_running, + backend=mock_backend, + queue="ec-queue", + ) + + mock_publisher_class.assert_called_once_with( + mock_backend, + topic="ec-queue", + schema=EntityContexts, + ) + assert dispatcher.ws is mock_websocket + assert dispatcher.running is mock_running + assert dispatcher.publisher is instance + + +class TestEntityContextsImportLifecycle: + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_start_calls_publisher_start( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.start = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.start() + instance.start.assert_called_once() + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_stops_and_closes_properly( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + mock_websocket.close.assert_called_once() + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_handles_none_websocket( + self, mock_publisher_class, mock_backend, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=None, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + + +class TestEntityContextsImportMessageProcessing: + """Regression coverage for receive(): catches Metadata/schema drift.""" + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_receive_constructs_entity_contexts_correctly( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + # If Metadata or EntityContexts gain/lose kwargs, this raises + # TypeError — exactly the regression we want to catch. + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + call_args = instance.send.call_args + assert call_args[0][0] is None + + sent = call_args[0][1] + assert isinstance(sent, EntityContexts) + assert isinstance(sent.metadata, Metadata) + assert sent.metadata.id == "doc-123" + assert sent.metadata.user == "testuser" + assert sent.metadata.collection == "testcollection" + + assert len(sent.entities) == 2 + assert all(isinstance(e, EntityContext) for e in sent.entities) + assert sent.entities[0].context == "Alice is a person." + assert sent.entities[1].context == "Bob is a person." + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_receive_handles_empty_entities( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, empty_entities_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = empty_entities_message + + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + sent = instance.send.call_args[0][1] + assert isinstance(sent, EntityContexts) + assert sent.entities == [] + assert sent.metadata.id == "doc-empty" + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_receive_propagates_publisher_errors( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock(side_effect=RuntimeError("publish failed")) + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + with pytest.raises(RuntimeError, match="publish failed"): + await dispatcher.receive(mock_msg) diff --git a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py new file mode 100644 index 00000000..fa277178 --- /dev/null +++ b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py @@ -0,0 +1,247 @@ +""" +Unit tests for graph embeddings import dispatcher. + +Tests the business logic of GraphEmbeddingsImport while mocking the +Publisher and websocket components. + +Regression coverage: a previous version of EntityContextsImport +constructed Metadata(metadata=...) which raised TypeError at runtime as +soon as a message was received. The same shape of bug can occur here, so +these tests exercise receive() end-to-end to catch any future schema or +kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport +from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata + + +@pytest.fixture +def mock_backend(): + return Mock() + + +@pytest.fixture +def mock_running(): + running = Mock() + running.get.return_value = True + running.stop = Mock() + return running + + +@pytest.fixture +def mock_websocket(): + ws = Mock() + ws.close = AsyncMock() + return ws + + +@pytest.fixture +def sample_message(): + """Sample graph-embeddings websocket message.""" + return { + "metadata": { + "id": "doc-123", + "user": "testuser", + "collection": "testcollection", + }, + "entities": [ + { + "entity": {"v": "http://example.org/alice", "e": True}, + "vector": [0.1, 0.2, 0.3], + }, + { + "entity": {"v": "http://example.org/bob", "e": True}, + "vector": [0.4, 0.5, 0.6], + }, + ], + } + + +@pytest.fixture +def empty_entities_message(): + return { + "metadata": { + "id": "doc-empty", + "user": "u", + "collection": "c", + }, + "entities": [], + } + + +class TestGraphEmbeddingsImportInitialization: + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + def test_init_creates_publisher_with_correct_params( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, + running=mock_running, + backend=mock_backend, + queue="ge-queue", + ) + + mock_publisher_class.assert_called_once_with( + mock_backend, + topic="ge-queue", + schema=GraphEmbeddings, + ) + assert dispatcher.ws is mock_websocket + assert dispatcher.running is mock_running + assert dispatcher.publisher is instance + + +class TestGraphEmbeddingsImportLifecycle: + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_start_calls_publisher_start( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.start = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.start() + instance.start.assert_called_once() + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_stops_and_closes_properly( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + mock_websocket.close.assert_called_once() + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_handles_none_websocket( + self, mock_publisher_class, mock_backend, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=None, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + + +class TestGraphEmbeddingsImportMessageProcessing: + """Regression coverage for receive(): catches Metadata/schema drift.""" + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_receive_constructs_graph_embeddings_correctly( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + # If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose + # kwargs, this raises TypeError — exactly the regression we want + # to catch. + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + call_args = instance.send.call_args + assert call_args[0][0] is None + + sent = call_args[0][1] + assert isinstance(sent, GraphEmbeddings) + assert isinstance(sent.metadata, Metadata) + assert sent.metadata.id == "doc-123" + assert sent.metadata.user == "testuser" + assert sent.metadata.collection == "testcollection" + + assert len(sent.entities) == 2 + assert all(isinstance(e, EntityEmbeddings) for e in sent.entities) + # Lock in the wire format: incoming "vector" key (singular, + # list[float]) maps to EntityEmbeddings.vector. This mirrors + # serialize_graph_embeddings() on the export side. + assert sent.entities[0].vector == [0.1, 0.2, 0.3] + assert sent.entities[1].vector == [0.4, 0.5, 0.6] + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_receive_handles_empty_entities( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, empty_entities_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = empty_entities_message + + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + sent = instance.send.call_args[0][1] + assert isinstance(sent, GraphEmbeddings) + assert sent.entities == [] + assert sent.metadata.id == "doc-empty" + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_receive_propagates_publisher_errors( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock(side_effect=RuntimeError("publish failed")) + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + with pytest.raises(RuntimeError, match="publish failed"): + await dispatcher.receive(mock_msg) diff --git a/tests/unit/test_gateway/test_service.py b/tests/unit/test_gateway/test_service.py index 22d9ab04..71428db4 100644 --- a/tests/unit/test_gateway/test_service.py +++ b/tests/unit/test_gateway/test_service.py @@ -171,6 +171,14 @@ class TestApi: patch('aiohttp.web.run_app') as mock_run_app: mock_get_pubsub.return_value = Mock() + # Api.run() passes self.app_factory() — a coroutine — to + # web.run_app, which would normally consume it inside its own + # event loop. Since we mock run_app, close the coroutine here + # so it doesn't leak as an "unawaited coroutine" RuntimeWarning. + def _consume_coro(coro, **kwargs): + coro.close() + mock_run_app.side_effect = _consume_coro + api = Api(port=8080) api.run() diff --git a/tests/unit/test_tables/__init__.py b/tests/unit/test_tables/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py new file mode 100644 index 00000000..5129b01e --- /dev/null +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -0,0 +1,197 @@ +""" +Unit tests for KnowledgeTableStore row deserialization. + +Regression coverage: a previous version of get_graph_embeddings constructed +EntityEmbeddings(vectors=ent[1]) — the schema field is `vector` (singular), +so any real Cassandra row would crash on read. These tests bypass the live +Cassandra connection entirely and exercise the row -> schema conversion +with hand-built fake rows. +""" + +import pytest +from unittest.mock import Mock + +from trustgraph.tables.knowledge import KnowledgeTableStore +from trustgraph.schema import ( + EntityEmbeddings, + GraphEmbeddings, + Triples, + Triple, + Metadata, + IRI, + LITERAL, +) + + +def _make_store(): + """ + Build a KnowledgeTableStore without invoking __init__ (which connects + to Cassandra). Tests inject only the attributes the method under test + actually touches. + """ + return KnowledgeTableStore.__new__(KnowledgeTableStore) + + +class TestGetGraphEmbeddings: + + @pytest.mark.asyncio + async def test_row_converts_to_entity_embeddings_with_singular_vector(self): + """ + Cassandra rows return entities as a list of [entity_tuple, vector] + pairs in row[3]. The deserializer must construct EntityEmbeddings + with `vector=` (singular) — the schema field name. A previous + version used `vectors=` and TypeError'd at runtime. + """ + # Arrange — fake row matching the get_triples_stmt result shape: + # row[0..2] are unused by the method, row[3] is the entities blob + fake_row = ( + None, None, None, + [ + # ((value, is_uri), vector) + (("http://example.org/alice", True), [0.1, 0.2, 0.3]), + (("http://example.org/bob", True), [0.4, 0.5, 0.6]), + (("a literal entity", False), [0.7, 0.8, 0.9]), + ], + ) + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=[fake_row]) + store.get_graph_embeddings_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + # Act + await store.get_graph_embeddings( + user="alice", + document_id="doc-1", + receiver=receiver, + ) + + # Assert + store.cassandra.execute.assert_called_once_with( + store.get_graph_embeddings_stmt, + ("alice", "doc-1"), + ) + + assert len(received) == 1 + ge = received[0] + assert isinstance(ge, GraphEmbeddings) + assert isinstance(ge.metadata, Metadata) + assert ge.metadata.id == "doc-1" + assert ge.metadata.user == "alice" + + assert len(ge.entities) == 3 + assert all(isinstance(e, EntityEmbeddings) for e in ge.entities) + + # Vectors land in the singular `vector` field — this is the + # explicit regression assertion for the original bug. + assert ge.entities[0].vector == [0.1, 0.2, 0.3] + assert ge.entities[1].vector == [0.4, 0.5, 0.6] + assert ge.entities[2].vector == [0.7, 0.8, 0.9] + + # Term type round-trips through tuple_to_term + assert ge.entities[0].entity.type == IRI + assert ge.entities[0].entity.iri == "http://example.org/alice" + assert ge.entities[1].entity.type == IRI + assert ge.entities[1].entity.iri == "http://example.org/bob" + assert ge.entities[2].entity.type == LITERAL + assert ge.entities[2].entity.value == "a literal entity" + + @pytest.mark.asyncio + async def test_empty_entities_blob_yields_empty_list(self): + """row[3] being None / empty must produce a GraphEmbeddings with + no entities, not raise.""" + fake_row = (None, None, None, None) + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=[fake_row]) + store.get_graph_embeddings_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_graph_embeddings("u", "d", receiver) + + assert len(received) == 1 + assert received[0].entities == [] + + @pytest.mark.asyncio + async def test_multiple_rows_each_emit_one_message(self): + fake_rows = [ + (None, None, None, [ + (("http://example.org/a", True), [1.0]), + ]), + (None, None, None, [ + (("http://example.org/b", True), [2.0]), + ]), + ] + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=fake_rows) + store.get_graph_embeddings_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_graph_embeddings("u", "d", receiver) + + assert len(received) == 2 + assert received[0].entities[0].entity.iri == "http://example.org/a" + assert received[0].entities[0].vector == [1.0] + assert received[1].entities[0].entity.iri == "http://example.org/b" + assert received[1].entities[0].vector == [2.0] + + +class TestGetTriples: + """Bonus: the sibling get_triples path uses the same row[3] shape and + the same Metadata construction. Cover it for parity.""" + + @pytest.mark.asyncio + async def test_row_converts_to_triples(self): + # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri) + fake_row = ( + None, None, None, + [ + ( + "http://example.org/alice", True, + "http://example.org/knows", True, + "http://example.org/bob", True, + ), + ], + ) + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=[fake_row]) + store.get_triples_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_triples("alice", "doc-1", receiver) + + assert len(received) == 1 + triples_msg = received[0] + assert isinstance(triples_msg, Triples) + assert isinstance(triples_msg.metadata, Metadata) + assert triples_msg.metadata.id == "doc-1" + assert triples_msg.metadata.user == "alice" + + assert len(triples_msg.triples) == 1 + t = triples_msg.triples[0] + assert isinstance(t, Triple) + assert t.s.iri == "http://example.org/alice" + assert t.p.iri == "http://example.org/knows" + assert t.o.iri == "http://example.org/bob" diff --git a/tests/unit/test_translators/__init__.py b/tests/unit/test_translators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py new file mode 100644 index 00000000..72f4796b --- /dev/null +++ b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py @@ -0,0 +1,66 @@ +""" +Round-trip unit tests for DocumentEmbeddingsTranslator. + +Regression coverage: a previous version of the decode side constructed +ChunkEmbeddings(vectors=...) — the schema field is `vector` (singular), +so any real DocumentEmbeddings message would crash on decode. The encode +side already wrote `"vector"`, so encode→decode was asymmetric. +""" + +import pytest + +from trustgraph.messaging.translators.document_loading import ( + DocumentEmbeddingsTranslator, +) +from trustgraph.schema import ( + DocumentEmbeddings, + ChunkEmbeddings, + Metadata, +) + + +@pytest.fixture +def translator(): + return DocumentEmbeddingsTranslator() + + +@pytest.fixture +def sample(): + return DocumentEmbeddings( + metadata=Metadata( + id="doc-1", + root="", + user="alice", + collection="testcoll", + ), + chunks=[ + ChunkEmbeddings(chunk_id="c1", vector=[0.1, 0.2, 0.3]), + ChunkEmbeddings(chunk_id="c2", vector=[0.4, 0.5, 0.6]), + ], + ) + + +class TestDocumentEmbeddingsTranslator: + + def test_encode_uses_singular_vector_key(self, translator, sample): + encoded = translator.encode(sample) + chunks = encoded["chunks"] + assert all("vector" in c for c in chunks) + assert all("vectors" not in c for c in chunks) + assert chunks[0]["vector"] == [0.1, 0.2, 0.3] + + def test_roundtrip_preserves_document_embeddings(self, translator, sample): + encoded = translator.encode(sample) + decoded = translator.decode(encoded) + + assert isinstance(decoded, DocumentEmbeddings) + assert isinstance(decoded.metadata, Metadata) + assert decoded.metadata.id == "doc-1" + assert decoded.metadata.user == "alice" + assert decoded.metadata.collection == "testcoll" + + assert len(decoded.chunks) == 2 + assert decoded.chunks[0].chunk_id == "c1" + assert decoded.chunks[0].vector == [0.1, 0.2, 0.3] + assert decoded.chunks[1].chunk_id == "c2" + assert decoded.chunks[1].vector == [0.4, 0.5, 0.6] diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py new file mode 100644 index 00000000..57e7ae17 --- /dev/null +++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py @@ -0,0 +1,153 @@ +""" +Round-trip unit tests for KnowledgeRequestTranslator. + +Regression coverage: a previous version of the decode side constructed +EntityEmbeddings(vectors=...) — the schema field is `vector` (singular), +so any real graph-embeddings KnowledgeRequest would crash on first +message. The encode side already wrote `"vector"`, so encode→decode was +asymmetric. + +These tests build a real KnowledgeRequest with graph-embeddings, encode +it, decode the result, and assert the round-trip is lossless. They also +exercise the triples path so any future schema drift in Metadata or +Triples breaks the test. +""" + +import pytest + +from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator +from trustgraph.schema import ( + KnowledgeRequest, + GraphEmbeddings, + EntityEmbeddings, + Triples, + Triple, + Metadata, + Term, + IRI, +) + + +def _term_iri(uri): + return Term(type=IRI, iri=uri) + + +@pytest.fixture +def translator(): + return KnowledgeRequestTranslator() + + +@pytest.fixture +def graph_embeddings_request(): + return KnowledgeRequest( + operation="put-kg-core", + user="alice", + id="doc-1", + flow="default", + collection="testcoll", + graph_embeddings=GraphEmbeddings( + metadata=Metadata( + id="doc-1", + root="", + user="alice", + collection="testcoll", + ), + entities=[ + EntityEmbeddings( + entity=_term_iri("http://example.org/alice"), + vector=[0.1, 0.2, 0.3], + ), + EntityEmbeddings( + entity=_term_iri("http://example.org/bob"), + vector=[0.4, 0.5, 0.6], + ), + ], + ), + ) + + +@pytest.fixture +def triples_request(): + return KnowledgeRequest( + operation="put-kg-core", + user="alice", + id="doc-1", + flow="default", + collection="testcoll", + triples=Triples( + metadata=Metadata( + id="doc-1", + root="", + user="alice", + collection="testcoll", + ), + triples=[ + Triple( + s=_term_iri("http://example.org/alice"), + p=_term_iri("http://example.org/knows"), + o=_term_iri("http://example.org/bob"), + ), + ], + ), + ) + + +class TestKnowledgeRequestTranslatorGraphEmbeddings: + + def test_encode_produces_singular_vector_key( + self, translator, graph_embeddings_request, + ): + """The wire key must be `vector`, never `vectors`.""" + encoded = translator.encode(graph_embeddings_request) + entities = encoded["graph-embeddings"]["entities"] + assert all("vector" in e for e in entities) + assert all("vectors" not in e for e in entities) + assert entities[0]["vector"] == [0.1, 0.2, 0.3] + + def test_roundtrip_preserves_graph_embeddings( + self, translator, graph_embeddings_request, + ): + """encode -> decode must be lossless for the GE branch.""" + encoded = translator.encode(graph_embeddings_request) + decoded = translator.decode(encoded) + + assert isinstance(decoded, KnowledgeRequest) + assert decoded.operation == "put-kg-core" + assert decoded.user == "alice" + assert decoded.id == "doc-1" + assert decoded.flow == "default" + assert decoded.collection == "testcoll" + + assert decoded.graph_embeddings is not None + ge = decoded.graph_embeddings + assert isinstance(ge, GraphEmbeddings) + assert isinstance(ge.metadata, Metadata) + assert ge.metadata.id == "doc-1" + assert ge.metadata.user == "alice" + assert ge.metadata.collection == "testcoll" + + assert len(ge.entities) == 2 + assert ge.entities[0].vector == [0.1, 0.2, 0.3] + assert ge.entities[1].vector == [0.4, 0.5, 0.6] + assert ge.entities[0].entity.iri == "http://example.org/alice" + assert ge.entities[1].entity.iri == "http://example.org/bob" + + +class TestKnowledgeRequestTranslatorTriples: + + def test_roundtrip_preserves_triples(self, translator, triples_request): + encoded = translator.encode(triples_request) + decoded = translator.decode(encoded) + + assert isinstance(decoded, KnowledgeRequest) + assert decoded.triples is not None + assert isinstance(decoded.triples.metadata, Metadata) + assert decoded.triples.metadata.id == "doc-1" + assert decoded.triples.metadata.user == "alice" + assert decoded.triples.metadata.collection == "testcoll" + + assert len(decoded.triples.triples) == 1 + t = decoded.triples.triples[0] + assert t.s.iri == "http://example.org/alice" + assert t.p.iri == "http://example.org/knows" + assert t.o.iri == "http://example.org/bob" diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index 3e7062e2..df2aa3ba 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator): chunks = [ ChunkEmbeddings( chunk_id=chunk["chunk_id"], - vectors=chunk["vectors"] + vector=chunk["vector"] ) for chunk in data.get("chunks", []) ] diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index 2f11d75a..f819dc9c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator): entities=[ EntityEmbeddings( entity=self.value_translator.decode(ent["entity"]), - vectors=ent["vectors"], + vector=ent["vector"], ) for ent in data["graph-embeddings"]["entities"] ] diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 62626046..3a37c4e3 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -40,15 +40,14 @@ class CoreExport: "ge", { "m": { - "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "i": data["metadata"]["id"], "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ { "e": ent["entity"], - "v": ent["vectors"], + "v": ent["vector"], } for ent in data["entities"] ] @@ -65,8 +64,7 @@ class CoreExport: "t", { "m": { - "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "i": data["metadata"]["id"], "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index af22a5b0..0ca07319 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -48,7 +48,6 @@ class CoreImport: "triples": { "metadata": { "id": id, - "metadata": msg["m"]["m"], "user": user, "collection": "default", # Not used? }, @@ -57,7 +56,7 @@ class CoreImport: } await kr.process(msg) - + elif unpacked[0] == "ge": msg = unpacked[1] msg = { @@ -67,14 +66,13 @@ class CoreImport: "graph-embeddings": { "metadata": { "id": id, - "metadata": msg["m"]["m"], "user": user, "collection": "default", # Not used? }, "entities": [ { "entity": ent["e"], - "vectors": ent["v"], + "vector": ent["v"], } for ent in msg["e"] ] diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index 6e01a5ca..de0fe52d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -8,7 +8,7 @@ from ... schema import Metadata from ... schema import EntityContexts, EntityContext from ... base import Publisher -from . serialize import to_subgraph, to_value +from . serialize import to_value # Module logger logger = logging.getLogger(__name__) @@ -48,7 +48,6 @@ class EntityContextsImport: elt = EntityContexts( metadata=Metadata( id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 8abf5e9c..7c7dc915 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -8,7 +8,7 @@ from ... schema import Metadata from ... schema import GraphEmbeddings, EntityEmbeddings from ... base import Publisher -from . serialize import to_subgraph, to_value +from . serialize import to_value # Module logger logger = logging.getLogger(__name__) @@ -48,14 +48,13 @@ class GraphEmbeddingsImport: elt = GraphEmbeddings( metadata=Metadata( id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ EntityEmbeddings( entity=to_value(ent["entity"]), - vectors=ent["vectors"], + vector=ent["vector"], ) for ent in data["entities"] ] diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 430dc3c9..2bdb6bd8 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -443,7 +443,7 @@ class KnowledgeTableStore: entities = [ EntityEmbeddings( entity = tuple_to_term(ent[0][0], ent[0][1]), - vectors = ent[1] + vector = ent[1] ) for ent in row[3] ]