Fix Metadata/EntityEmbeddings schema migration tail and add regression tests (#777)

The Metadata dataclass dropped its `metadata: list[Triple]` field
and EntityEmbeddings/ChunkEmbeddings settled on a singular
`vector: list[float]` field, but several call sites kept passing
`Metadata(metadata=...)` and `EntityEmbeddings(vectors=...)`. The
bugs were latent until a websocket client first hit
`/api/v1/flow/default/import/entity-contexts`, at which point the
dispatcher TypeError'd on construction.

Production fixes (5 call sites on the same migration tail):

  * trustgraph-flow gateway dispatchers entity_contexts_import.py
    and graph_embeddings_import.py — drop the stale
    Metadata(metadata=...) kwarg; switch graph_embeddings_import
    to the singular `vector` wire key.
  * trustgraph-base messaging translators knowledge.py and
    document_loading.py — fix decode side to read the singular
    `"vector"` key, matching what their own encode sides have
    always written.
  * trustgraph-flow tables/knowledge.py — fix Cassandra row
    deserialiser to construct EntityEmbeddings(vector=...)
    instead of vectors=.
  * trustgraph-flow gateway core_import/core_export — switch the
    kg-core msgpack wire format to the singular `"v"`/`"vector"`
    key and drop the dead `m["m"]` envelope field that referenced
    the removed Metadata.metadata triples list (it was a
    guaranteed KeyError on the export side).
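All five call sites reduce to the same before/after shape. A minimal standalone sketch — the dataclasses below are local stand-ins mirroring the post-migration field sets pinned by tests/contract/test_schema_field_contracts.py, not the real trustgraph.schema classes:

```python
from dataclasses import dataclass, field

# Stand-ins for trustgraph.schema.Metadata / EntityEmbeddings with the
# post-migration field sets (id/root/user/collection and
# entity/vector/chunk_id respectively).
@dataclass
class Metadata:
    id: str = ""
    root: str = ""
    user: str = ""
    collection: str = ""

@dataclass
class EntityEmbeddings:
    entity: object = None
    vector: list = field(default_factory=list)  # singular list[float]
    chunk_id: object = None

# The migration-tail bug: stale kwargs only fail at construction time,
# i.e. when the first real message arrives.
for bad_call in (lambda: Metadata(metadata=[]),
                 lambda: EntityEmbeddings(vectors=[0.1])):
    try:
        bad_call()
    except TypeError:
        pass  # unexpected keyword argument

# The fixed call shape used at the five call sites:
md = Metadata(id="doc-1", user="alice", collection="default")
ee = EntityEmbeddings(vector=[0.1, 0.2, 0.3])
```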

Defense-in-depth regression coverage (32 new tests across 7 files):

  * tests/contract/test_schema_field_contracts.py — pin the field
    set of Metadata, EntityEmbeddings, ChunkEmbeddings,
    EntityContext so any future schema rename fails CI loudly
    with a clear diff.
  * tests/unit/test_translators/test_knowledge_translator_roundtrip.py
    and test_document_embeddings_translator_roundtrip.py —
    encode→decode round-trip the affected translators end to end,
    locking in the singular `"vector"` wire key.
  * tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py
    and test_graph_embeddings_import_dispatcher.py — exercise the
    websocket dispatchers' receive() path with realistic
    payloads, the direct regression test for the original
    production crash.
  * tests/unit/test_gateway/test_core_import_export_roundtrip.py
    — pack/unpack the kg-core msgpack format through the real
    dispatcher classes (with KnowledgeRequestor mocked),
    including a full export→import round-trip.
  * tests/unit/test_tables/test_knowledge_table_store.py —
    exercise the Cassandra row → schema conversion via __new__ to
    bypass the live cluster connection.

Also fixes an unrelated leaked-coroutine RuntimeWarning in
test_gateway/test_service.py::test_run_method_calls_web_run_app: the
mocked aiohttp.web.run_app now closes the coroutine that Api.run() hands
it, mirroring what the real run_app would do, instead of leaving it for
the GC to complain about.
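The pattern in miniature, with a plain Mock standing in for the patched aiohttp.web.run_app and a hypothetical make_app coroutine in place of Api.app_factory:

```python
from unittest.mock import Mock

async def make_app():
    # hypothetical stand-in for Api.app_factory()
    return "app"

run_app = Mock()  # stands in for the patched aiohttp.web.run_app
# The real run_app consumes the coroutine in its own event loop; the
# mock must close it instead, or garbage collection later emits
# "coroutine 'make_app' was never awaited".
run_app.side_effect = lambda coro, **kwargs: coro.close()

run_app(make_app(), port=8080)
run_app.assert_called_once()
```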
cybermaggedon 2026-04-10 20:43:45 +01:00 committed by GitHub
parent 0994d4b05f
commit c23e28aa66
17 changed files with 1415 additions and 17 deletions

tests/contract/test_schema_field_contracts.py

@@ -0,0 +1,73 @@
"""
Contract tests for schema dataclass field sets.
These pin the *field names* of small, widely-constructed schema dataclasses
so that any rename, removal, or accidental addition fails CI loudly instead
of waiting for a runtime TypeError on the next websocket message.
Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
field but several call sites kept passing `Metadata(metadata=...)`. The bug
was only discovered when a websocket import dispatcher received its first
real message in production. A trivial structural assertion of the kind
below would have caught it at unit-test time.
Add to this file whenever a schema rename burns you. The cost of a frozen
field set is a one-line update when you intentionally evolve the schema; the
benefit is that every call site is forced to come along for the ride.
"""
import dataclasses
import pytest
from trustgraph.schema import (
Metadata,
EntityContext,
EntityEmbeddings,
ChunkEmbeddings,
)
def _field_names(dc):
return {f.name for f in dataclasses.fields(dc)}
@pytest.mark.contract
class TestSchemaFieldContracts:
"""Pin the field set of dataclasses that get constructed all over the
codebase. If you intentionally change one of these, update the
expected set in the same commit; that diff will surface every call
site that needs to come along."""
def test_metadata_fields(self):
# NOTE: there is no `metadata` field. A previous regression
# constructed Metadata(metadata=...) and crashed at runtime.
assert _field_names(Metadata) == {
"id",
"root",
"user",
"collection",
}
def test_entity_embeddings_fields(self):
# NOTE: the embedding field is `vector` (singular, list[float]).
# There is no `vectors` field. Several call sites historically
# passed `vectors=` and crashed at runtime.
assert _field_names(EntityEmbeddings) == {
"entity",
"vector",
"chunk_id",
}
def test_chunk_embeddings_fields(self):
# Same `vector` (singular) convention as EntityEmbeddings.
assert _field_names(ChunkEmbeddings) == {
"chunk_id",
"vector",
}
def test_entity_context_fields(self):
assert _field_names(EntityContext) == {
"entity",
"context",
"chunk_id",
}

tests/unit/test_gateway/test_core_import_export_roundtrip.py

@@ -0,0 +1,418 @@
"""
Round-trip unit tests for the core msgpack import/export gateway endpoints.
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
the responder callback and packs them into msgpack tuples. The kg-core
import endpoint takes msgpack tuples back off the wire and rebuilds
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
(whose translator decodes them into real dataclasses).
Regression coverage: the previous wire format used `"vectors"` (plural)
in the entity blobs and embedded a stale `"m"` field that referenced the
removed `Metadata.metadata` triples-list field. The export side hit a
KeyError on first message; the import side built dicts that the
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
tests pin both halves of the wire protocol.
"""
import msgpack
import pytest
from unittest.mock import AsyncMock, Mock, patch
from trustgraph.gateway.dispatch.core_export import CoreExport
from trustgraph.gateway.dispatch.core_import import CoreImport
# ---------------------------------------------------------------------------
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
# would emit). The vector wire key is *singular* on purpose; the export
# side previously read the wrong key and crashed.
# ---------------------------------------------------------------------------
def _ge_response_dict():
return {
"graph-embeddings": {
"metadata": {
"id": "doc-1",
"root": "",
"user": "alice",
"collection": "testcoll",
},
"entities": [
{
"entity": {"t": "i", "i": "http://example.org/alice"},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"t": "i", "i": "http://example.org/bob"},
"vector": [0.4, 0.5, 0.6],
},
],
}
}
def _triples_response_dict():
return {
"triples": {
"metadata": {
"id": "doc-1",
"root": "",
"user": "alice",
"collection": "testcoll",
},
"triples": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
}
}
def _make_request(id_="doc-1", user="alice"):
request = Mock()
request.query = {"id": id_, "user": user}
return request
def _make_data_reader(payload: bytes):
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
chunks = [payload, b""]
data = Mock()
async def fake_read(n):
return chunks.pop(0) if chunks else b""
data.read = fake_read
return data
# ---------------------------------------------------------------------------
# Export side: translator-shaped dict -> msgpack bytes
# ---------------------------------------------------------------------------
class TestCoreExportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_graph_embeddings_with_singular_vector(
self, mock_kr_class,
):
"""The export side must read `ent["vector"]` and emit `v`. The
previous bug was reading `ent["vectors"]`, which KeyError'd against
the translator output."""
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_ge_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=error,
ok=ok,
request=_make_request(),
)
# Did not raise, did not call error()
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "ge"
# Metadata envelope: only id/user/collection — no stale `m["m"]`.
assert payload["m"] == {
"i": "doc-1",
"u": "alice",
"c": "testcoll",
}
# Entities: each carries the *singular* `v` and the term envelope
assert len(payload["e"]) == 2
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_triples_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(), error=error, ok=ok, request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "t"
assert payload["m"] == {
"i": "doc-1",
"u": "alice",
"c": "testcoll",
}
assert len(payload["t"]) == 1
# ---------------------------------------------------------------------------
# Import side: msgpack bytes -> translator-shaped dict
# ---------------------------------------------------------------------------
class TestCoreImportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_graph_embeddings_to_singular_vector(
self, mock_kr_class,
):
"""The import side must build dicts whose entity blobs have the
singular `vector` key that's what the KnowledgeRequestTranslator
decode side reads. Previous bug emitted `vectors`."""
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
# Build a msgpack tuple matching the new wire format
payload = msgpack.packb((
"ge",
{
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
"e": [
{
"e": {"t": "i", "i": "http://example.org/alice"},
"v": [0.1, 0.2, 0.3],
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
assert req["operation"] == "put-kg-core"
assert req["user"] == "alice"
assert req["id"] == "doc-1"
ge = req["graph-embeddings"]
# Metadata envelope must NOT contain a stale `metadata` key
# referencing the removed Metadata.metadata field.
assert "metadata" not in ge["metadata"]
assert ge["metadata"] == {
"id": "doc-1",
"user": "alice",
"collection": "default",
}
# Entity blob carries the singular `vector` key
assert len(ge["entities"]) == 1
ent = ge["entities"][0]
assert ent["vector"] == [0.1, 0.2, 0.3]
assert "vectors" not in ent
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
payload = msgpack.packb((
"t",
{
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
"t": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
triples = req["triples"]
assert "metadata" not in triples["metadata"] # no stale field
assert len(triples["triples"]) == 1
# ---------------------------------------------------------------------------
# Full round-trip: export bytes feed directly into import
# ---------------------------------------------------------------------------
class TestCoreImportExportRoundTrip:
"""End-to-end: produce bytes via core_export, consume them via
core_import, and verify the dict that lands at the import-side
translator is structurally equivalent to what went in. This is the
test that catches asymmetries between the two halves."""
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_graph_embeddings_round_trip(
self, mock_export_kr_class, mock_import_kr_class,
):
# ----- export side: capture bytes -----
export_bytes = []
async def fake_export_process(req_dict, responder):
await responder(_ge_response_dict(), True)
export_kr = AsyncMock()
export_kr.start = AsyncMock()
export_kr.stop = AsyncMock()
export_kr.process = fake_export_process
mock_export_kr_class.return_value = export_kr
response = AsyncMock()
async def fake_write(b):
export_bytes.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=AsyncMock(),
ok=AsyncMock(return_value=response),
request=_make_request(),
)
assert len(export_bytes) == 1
# ----- import side: feed those bytes back in -----
import_captured = []
async def fake_import_process(req_dict):
import_captured.append(req_dict)
import_kr = AsyncMock()
import_kr.start = AsyncMock()
import_kr.stop = AsyncMock()
import_kr.process = fake_import_process
mock_import_kr_class.return_value = import_kr
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(export_bytes[0]),
error=AsyncMock(),
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
request=_make_request(),
)
# ----- verify the dict the importer would hand to the translator -----
assert len(import_captured) == 1
req = import_captured[0]
original = _ge_response_dict()["graph-embeddings"]
ge = req["graph-embeddings"]
# The import side overrides id/user from the URL query (intentional),
# so we only round-trip the entity payload itself.
assert ge["metadata"]["id"] == original["metadata"]["id"]
assert ge["metadata"]["user"] == original["metadata"]["user"]
assert len(ge["entities"]) == len(original["entities"])
for got, want in zip(ge["entities"], original["entities"]):
assert got["vector"] == want["vector"]
assert got["entity"] == want["entity"]

tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py

@@ -0,0 +1,242 @@
"""
Unit tests for entity contexts import dispatcher.
Tests the business logic of EntityContextsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version constructed Metadata(metadata=...)
which raised TypeError at runtime as soon as a message was received. These
tests exercise receive() end-to-end so any future schema/kwarg drift in
the Metadata or EntityContexts construction is caught immediately.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
from trustgraph.schema import EntityContexts, EntityContext, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample entity-contexts websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"context": "Alice is a person.",
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"context": "Bob is a person.",
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestEntityContextsImportInitialization:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ec-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ec-queue",
schema=EntityContexts,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestEntityContextsImportLifecycle:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestEntityContextsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_entity_contexts_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata or EntityContexts gain/lose kwargs, this raises
# TypeError — exactly the regression we want to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityContext) for e in sent.entities)
assert sent.entities[0].context == "Alice is a person."
assert sent.entities[1].context == "Bob is a person."
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, EntityContexts)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py

@@ -0,0 +1,247 @@
"""
Unit tests for graph embeddings import dispatcher.
Tests the business logic of GraphEmbeddingsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version of EntityContextsImport
constructed Metadata(metadata=...) which raised TypeError at runtime as
soon as a message was received. The same shape of bug can occur here, so
these tests exercise receive() end-to-end to catch any future schema or
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample graph-embeddings websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"vector": [0.4, 0.5, 0.6],
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestGraphEmbeddingsImportInitialization:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ge-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ge-queue",
schema=GraphEmbeddings,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestGraphEmbeddingsImportLifecycle:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestGraphEmbeddingsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_graph_embeddings_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
# kwargs, this raises TypeError — exactly the regression we want
# to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
# Lock in the wire format: incoming "vector" key (singular,
# list[float]) maps to EntityEmbeddings.vector. This mirrors
# serialize_graph_embeddings() on the export side.
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

tests/unit/test_gateway/test_service.py

@@ -171,6 +171,14 @@ class TestApi:
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
# Api.run() passes self.app_factory() — a coroutine — to
# web.run_app, which would normally consume it inside its own
# event loop. Since we mock run_app, close the coroutine here
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
def _consume_coro(coro, **kwargs):
coro.close()
mock_run_app.side_effect = _consume_coro
api = Api(port=8080)
api.run()

tests/unit/test_tables/test_knowledge_table_store.py

@@ -0,0 +1,197 @@
"""
Unit tests for KnowledgeTableStore row deserialization.
Regression coverage: a previous version of get_graph_embeddings constructed
EntityEmbeddings(vectors=ent[1]), but the schema field is `vector` (singular),
so any real Cassandra row would crash on read. These tests bypass the live
Cassandra connection entirely and exercise the row -> schema conversion
with hand-built fake rows.
"""
import pytest
from unittest.mock import Mock
from trustgraph.tables.knowledge import KnowledgeTableStore
from trustgraph.schema import (
EntityEmbeddings,
GraphEmbeddings,
Triples,
Triple,
Metadata,
IRI,
LITERAL,
)
def _make_store():
"""
Build a KnowledgeTableStore without invoking __init__ (which connects
to Cassandra). Tests inject only the attributes the method under test
actually touches.
"""
return KnowledgeTableStore.__new__(KnowledgeTableStore)
class TestGetGraphEmbeddings:
@pytest.mark.asyncio
async def test_row_converts_to_entity_embeddings_with_singular_vector(self):
"""
Cassandra rows return entities as a list of [entity_tuple, vector]
pairs in row[3]. The deserializer must construct EntityEmbeddings
with `vector=` (singular), the schema field name. A previous
version used `vectors=` and TypeError'd at runtime.
"""
# Arrange — fake row matching the get_graph_embeddings_stmt result shape:
# row[0..2] are unused by the method, row[3] is the entities blob
fake_row = (
None, None, None,
[
# ((value, is_uri), vector)
(("http://example.org/alice", True), [0.1, 0.2, 0.3]),
(("http://example.org/bob", True), [0.4, 0.5, 0.6]),
(("a literal entity", False), [0.7, 0.8, 0.9]),
],
)
store = _make_store()
store.cassandra = Mock()
store.cassandra.execute = Mock(return_value=[fake_row])
store.get_graph_embeddings_stmt = Mock()
received = []
async def receiver(msg):
received.append(msg)
# Act
await store.get_graph_embeddings(
user="alice",
document_id="doc-1",
receiver=receiver,
)
# Assert
store.cassandra.execute.assert_called_once_with(
store.get_graph_embeddings_stmt,
("alice", "doc-1"),
)
assert len(received) == 1
ge = received[0]
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert len(ge.entities) == 3
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
# Vectors land in the singular `vector` field — this is the
# explicit regression assertion for the original bug.
assert ge.entities[0].vector == [0.1, 0.2, 0.3]
assert ge.entities[1].vector == [0.4, 0.5, 0.6]
assert ge.entities[2].vector == [0.7, 0.8, 0.9]
# Term type round-trips through tuple_to_term
assert ge.entities[0].entity.type == IRI
assert ge.entities[0].entity.iri == "http://example.org/alice"
assert ge.entities[1].entity.type == IRI
assert ge.entities[1].entity.iri == "http://example.org/bob"
assert ge.entities[2].entity.type == LITERAL
assert ge.entities[2].entity.value == "a literal entity"
    @pytest.mark.asyncio
    async def test_empty_entities_blob_yields_empty_list(self):
        """row[3] being None / empty must produce a GraphEmbeddings with
        no entities, not raise."""
        fake_row = (None, None, None, None)

        store = _make_store()
        store.cassandra = Mock()
        store.cassandra.execute = Mock(return_value=[fake_row])
        store.get_graph_embeddings_stmt = Mock()

        received = []

        async def receiver(msg):
            received.append(msg)

        await store.get_graph_embeddings("u", "d", receiver)

        assert len(received) == 1
        assert received[0].entities == []
    @pytest.mark.asyncio
    async def test_multiple_rows_each_emit_one_message(self):
        fake_rows = [
            (None, None, None, [
                (("http://example.org/a", True), [1.0]),
            ]),
            (None, None, None, [
                (("http://example.org/b", True), [2.0]),
            ]),
        ]

        store = _make_store()
        store.cassandra = Mock()
        store.cassandra.execute = Mock(return_value=fake_rows)
        store.get_graph_embeddings_stmt = Mock()

        received = []

        async def receiver(msg):
            received.append(msg)

        await store.get_graph_embeddings("u", "d", receiver)

        assert len(received) == 2
        assert received[0].entities[0].entity.iri == "http://example.org/a"
        assert received[0].entities[0].vector == [1.0]
        assert received[1].entities[0].entity.iri == "http://example.org/b"
        assert received[1].entities[0].vector == [2.0]

class TestGetTriples:
    """Bonus: the sibling get_triples path packs its payload into row[3]
    the same way and uses the same Metadata construction. Cover it for
    parity."""

    @pytest.mark.asyncio
    async def test_row_converts_to_triples(self):
        # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
        fake_row = (
            None, None, None,
            [
                (
                    "http://example.org/alice", True,
                    "http://example.org/knows", True,
                    "http://example.org/bob", True,
                ),
            ],
        )

        store = _make_store()
        store.cassandra = Mock()
        store.cassandra.execute = Mock(return_value=[fake_row])
        store.get_triples_stmt = Mock()

        received = []

        async def receiver(msg):
            received.append(msg)

        await store.get_triples("alice", "doc-1", receiver)

        assert len(received) == 1
        triples_msg = received[0]
        assert isinstance(triples_msg, Triples)
        assert isinstance(triples_msg.metadata, Metadata)
        assert triples_msg.metadata.id == "doc-1"
        assert triples_msg.metadata.user == "alice"

        assert len(triples_msg.triples) == 1
        t = triples_msg.triples[0]
        assert isinstance(t, Triple)
        assert t.s.iri == "http://example.org/alice"
        assert t.p.iri == "http://example.org/knows"
        assert t.o.iri == "http://example.org/bob"
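The failure mode all of these tests pin down is a plain keyword-argument mismatch at dataclass construction time. A minimal, self-contained sketch of the mechanics, using a stand-in dataclass (an assumption — the real `trustgraph.schema.EntityEmbeddings` is only assumed here to accept keyword construction with a singular `vector` field):

```python
from dataclasses import dataclass, field

# Stand-in for the real EntityEmbeddings schema class (assumption:
# keyword-constructible, with a singular `vector` field).
@dataclass
class EntityEmbeddings:
    entity: str
    vector: list = field(default_factory=list)

# The fixed call sites use the singular field name and work:
ok = EntityEmbeddings(entity="http://example.org/alice", vector=[0.1, 0.2])
print(ok.vector)  # [0.1, 0.2]

# The pre-fix call sites passed the removed plural kwarg, which fails
# at construction time - the latent TypeError described above:
try:
    EntityEmbeddings(entity="http://example.org/alice", vectors=[0.1, 0.2])
except TypeError as exc:
    print("construction failed:", exc)
```

This is why the bug stayed latent: nothing breaks at import time, only when the first real message reaches the constructor.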

@@ -0,0 +1,66 @@
"""
Round-trip unit tests for DocumentEmbeddingsTranslator.
Regression coverage: a previous version of the decode side constructed
ChunkEmbeddings(vectors=...) the schema field is `vector` (singular),
so any real DocumentEmbeddings message would crash on decode. The encode
side already wrote `"vector"`, so encodedecode was asymmetric.
"""
import pytest

from trustgraph.messaging.translators.document_loading import (
    DocumentEmbeddingsTranslator,
)
from trustgraph.schema import (
    DocumentEmbeddings,
    ChunkEmbeddings,
    Metadata,
)

@pytest.fixture
def translator():
    return DocumentEmbeddingsTranslator()


@pytest.fixture
def sample():
    return DocumentEmbeddings(
        metadata=Metadata(
            id="doc-1",
            root="",
            user="alice",
            collection="testcoll",
        ),
        chunks=[
            ChunkEmbeddings(chunk_id="c1", vector=[0.1, 0.2, 0.3]),
            ChunkEmbeddings(chunk_id="c2", vector=[0.4, 0.5, 0.6]),
        ],
    )

class TestDocumentEmbeddingsTranslator:

    def test_encode_uses_singular_vector_key(self, translator, sample):
        encoded = translator.encode(sample)
        chunks = encoded["chunks"]
        assert all("vector" in c for c in chunks)
        assert all("vectors" not in c for c in chunks)
        assert chunks[0]["vector"] == [0.1, 0.2, 0.3]

    def test_roundtrip_preserves_document_embeddings(self, translator, sample):
        encoded = translator.encode(sample)
        decoded = translator.decode(encoded)

        assert isinstance(decoded, DocumentEmbeddings)
        assert isinstance(decoded.metadata, Metadata)
        assert decoded.metadata.id == "doc-1"
        assert decoded.metadata.user == "alice"
        assert decoded.metadata.collection == "testcoll"

        assert len(decoded.chunks) == 2
        assert decoded.chunks[0].chunk_id == "c1"
        assert decoded.chunks[0].vector == [0.1, 0.2, 0.3]
        assert decoded.chunks[1].chunk_id == "c2"
        assert decoded.chunks[1].vector == [0.4, 0.5, 0.6]

@@ -0,0 +1,153 @@
"""
Round-trip unit tests for KnowledgeRequestTranslator.
Regression coverage: a previous version of the decode side constructed
EntityEmbeddings(vectors=...) the schema field is `vector` (singular),
so any real graph-embeddings KnowledgeRequest would crash on first
message. The encode side already wrote `"vector"`, so encodedecode was
asymmetric.
These tests build a real KnowledgeRequest with graph-embeddings, encode
it, decode the result, and assert the round-trip is lossless. They also
exercise the triples path so any future schema drift in Metadata or
Triples breaks the test.
"""
import pytest

from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
from trustgraph.schema import (
    KnowledgeRequest,
    GraphEmbeddings,
    EntityEmbeddings,
    Triples,
    Triple,
    Metadata,
    Term,
    IRI,
)


def _term_iri(uri):
    return Term(type=IRI, iri=uri)

@pytest.fixture
def translator():
    return KnowledgeRequestTranslator()


@pytest.fixture
def graph_embeddings_request():
    return KnowledgeRequest(
        operation="put-kg-core",
        user="alice",
        id="doc-1",
        flow="default",
        collection="testcoll",
        graph_embeddings=GraphEmbeddings(
            metadata=Metadata(
                id="doc-1",
                root="",
                user="alice",
                collection="testcoll",
            ),
            entities=[
                EntityEmbeddings(
                    entity=_term_iri("http://example.org/alice"),
                    vector=[0.1, 0.2, 0.3],
                ),
                EntityEmbeddings(
                    entity=_term_iri("http://example.org/bob"),
                    vector=[0.4, 0.5, 0.6],
                ),
            ],
        ),
    )


@pytest.fixture
def triples_request():
    return KnowledgeRequest(
        operation="put-kg-core",
        user="alice",
        id="doc-1",
        flow="default",
        collection="testcoll",
        triples=Triples(
            metadata=Metadata(
                id="doc-1",
                root="",
                user="alice",
                collection="testcoll",
            ),
            triples=[
                Triple(
                    s=_term_iri("http://example.org/alice"),
                    p=_term_iri("http://example.org/knows"),
                    o=_term_iri("http://example.org/bob"),
                ),
            ],
        ),
    )

class TestKnowledgeRequestTranslatorGraphEmbeddings:

    def test_encode_produces_singular_vector_key(
        self, translator, graph_embeddings_request,
    ):
        """The wire key must be `vector`, never `vectors`."""
        encoded = translator.encode(graph_embeddings_request)
        entities = encoded["graph-embeddings"]["entities"]
        assert all("vector" in e for e in entities)
        assert all("vectors" not in e for e in entities)
        assert entities[0]["vector"] == [0.1, 0.2, 0.3]

    def test_roundtrip_preserves_graph_embeddings(
        self, translator, graph_embeddings_request,
    ):
        """encode -> decode must be lossless for the GE branch."""
        encoded = translator.encode(graph_embeddings_request)
        decoded = translator.decode(encoded)

        assert isinstance(decoded, KnowledgeRequest)
        assert decoded.operation == "put-kg-core"
        assert decoded.user == "alice"
        assert decoded.id == "doc-1"
        assert decoded.flow == "default"
        assert decoded.collection == "testcoll"
        assert decoded.graph_embeddings is not None

        ge = decoded.graph_embeddings
        assert isinstance(ge, GraphEmbeddings)
        assert isinstance(ge.metadata, Metadata)
        assert ge.metadata.id == "doc-1"
        assert ge.metadata.user == "alice"
        assert ge.metadata.collection == "testcoll"

        assert len(ge.entities) == 2
        assert ge.entities[0].vector == [0.1, 0.2, 0.3]
        assert ge.entities[1].vector == [0.4, 0.5, 0.6]
        assert ge.entities[0].entity.iri == "http://example.org/alice"
        assert ge.entities[1].entity.iri == "http://example.org/bob"


class TestKnowledgeRequestTranslatorTriples:

    def test_roundtrip_preserves_triples(self, translator, triples_request):
        encoded = translator.encode(triples_request)
        decoded = translator.decode(encoded)

        assert isinstance(decoded, KnowledgeRequest)
        assert decoded.triples is not None
        assert isinstance(decoded.triples.metadata, Metadata)
        assert decoded.triples.metadata.id == "doc-1"
        assert decoded.triples.metadata.user == "alice"
        assert decoded.triples.metadata.collection == "testcoll"

        assert len(decoded.triples.triples) == 1
        t = decoded.triples.triples[0]
        assert t.s.iri == "http://example.org/alice"
        assert t.p.iri == "http://example.org/knows"
        assert t.o.iri == "http://example.org/bob"

@@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
         chunks = [
             ChunkEmbeddings(
                 chunk_id=chunk["chunk_id"],
-                vectors=chunk["vectors"]
+                vector=chunk["vector"]
             )
             for chunk in data.get("chunks", [])
         ]

@@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
                 entities=[
                     EntityEmbeddings(
                         entity=self.value_translator.decode(ent["entity"]),
-                        vectors=ent["vectors"],
+                        vector=ent["vector"],
                     )
                     for ent in data["graph-embeddings"]["entities"]
                 ]

@@ -41,14 +41,13 @@ class CoreExport:
             {
                 "m": {
                     "i": data["metadata"]["id"],
-                    "m": data["metadata"]["metadata"],
                     "u": data["metadata"]["user"],
                     "c": data["metadata"]["collection"],
                 },
                 "e": [
                     {
                         "e": ent["entity"],
-                        "v": ent["vectors"],
+                        "v": ent["vector"],
                     }
                     for ent in data["entities"]
                 ]
@@ -66,7 +65,6 @@ class CoreExport:
             {
                 "m": {
                     "i": data["metadata"]["id"],
-                    "m": data["metadata"]["metadata"],
                     "u": data["metadata"]["user"],
                     "c": data["metadata"]["collection"],
                 },

@@ -48,7 +48,6 @@ class CoreImport:
             "triples": {
                 "metadata": {
                     "id": id,
-                    "metadata": msg["m"]["m"],
                     "user": user,
                     "collection": "default",   # Not used?
                 },
@@ -67,14 +66,13 @@ class CoreImport:
             "graph-embeddings": {
                 "metadata": {
                     "id": id,
-                    "metadata": msg["m"]["m"],
                     "user": user,
                     "collection": "default",   # Not used?
                 },
                 "entities": [
                     {
                         "entity": ent["e"],
-                        "vectors": ent["v"],
+                        "vector": ent["v"],
                     }
                     for ent in msg["e"]
                 ]

@@ -8,7 +8,7 @@ from ... schema import Metadata
 from ... schema import EntityContexts, EntityContext
 from ... base import Publisher
-from . serialize import to_subgraph, to_value
+from . serialize import to_value

 # Module logger
 logger = logging.getLogger(__name__)
@@ -48,7 +48,6 @@ class EntityContextsImport:
         elt = EntityContexts(
             metadata=Metadata(
                 id=data["metadata"]["id"],
-                metadata=to_subgraph(data["metadata"]["metadata"]),
                 user=data["metadata"]["user"],
                 collection=data["metadata"]["collection"],
             ),

@@ -8,7 +8,7 @@ from ... schema import Metadata
 from ... schema import GraphEmbeddings, EntityEmbeddings
 from ... base import Publisher
-from . serialize import to_subgraph, to_value
+from . serialize import to_value

 # Module logger
 logger = logging.getLogger(__name__)
@@ -48,14 +48,13 @@ class GraphEmbeddingsImport:
         elt = GraphEmbeddings(
             metadata=Metadata(
                 id=data["metadata"]["id"],
-                metadata=to_subgraph(data["metadata"]["metadata"]),
                 user=data["metadata"]["user"],
                 collection=data["metadata"]["collection"],
             ),
             entities=[
                 EntityEmbeddings(
                     entity=to_value(ent["entity"]),
-                    vectors=ent["vectors"],
+                    vector=ent["vector"],
                 )
                 for ent in data["entities"]
             ]

@@ -443,7 +443,7 @@ class KnowledgeTableStore:
         entities = [
             EntityEmbeddings(
                 entity = tuple_to_term(ent[0][0], ent[0][1]),
-                vectors = ent[1]
+                vector = ent[1]
             )
             for ent in row[3]
         ]