mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix Metadata/EntityEmbeddings schema migration tail and add regression tests (#776)
The Metadata dataclass dropped its `metadata: list[Triple]` field
and EntityEmbeddings/ChunkEmbeddings settled on a singular
`vector: list[float]` field, but several call sites kept passing
`Metadata(metadata=...)` and `EntityEmbeddings(vectors=...)`. The
bugs were latent until a websocket client first hit
`/api/v1/flow/default/import/entity-contexts`, at which point the
dispatcher TypeError'd on construction.
Production fixes (5 call sites on the same migration tail):
* trustgraph-flow gateway dispatchers entity_contexts_import.py
and graph_embeddings_import.py — drop the stale
Metadata(metadata=...) kwarg; switch graph_embeddings_import
to the singular `vector` wire key.
* trustgraph-base messaging translators knowledge.py and
document_loading.py — fix decode side to read the singular
`"vector"` key, matching what their own encode sides have
always written.
* trustgraph-flow tables/knowledge.py — fix Cassandra row
deserialiser to construct EntityEmbeddings(vector=...)
instead of vectors=.
* trustgraph-flow gateway core_import/core_export — switch the
kg-core msgpack wire format to the singular `"v"`/`"vector"`
key and drop the dead `m["m"]` envelope field that referenced
the removed Metadata.metadata triples list (it was a
guaranteed KeyError on the export side).
Defense-in-depth regression coverage (32 new tests across 7 files):
* tests/contract/test_schema_field_contracts.py — pin the field
set of Metadata, EntityEmbeddings, ChunkEmbeddings,
EntityContext so any future schema rename fails CI loudly
with a clear diff.
* tests/unit/test_translators/test_knowledge_translator_roundtrip.py
and test_document_embeddings_translator_roundtrip.py —
encode→decode round-trip the affected translators end to end,
locking in the singular `"vector"` wire key.
* tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py
and test_graph_embeddings_import_dispatcher.py — exercise the
websocket dispatchers' receive() path with realistic
payloads, the direct regression test for the original
production crash.
* tests/unit/test_gateway/test_core_import_export_roundtrip.py
— pack/unpack the kg-core msgpack format through the real
dispatcher classes (with KnowledgeRequestor mocked),
including a full export→import round-trip.
* tests/unit/test_tables/test_knowledge_table_store.py —
exercise the Cassandra row → schema conversion via __new__ to
bypass the live cluster connection.
Also fixes an unrelated leaked-coroutine RuntimeWarning in
test_gateway/test_service.py::test_run_method_calls_web_run_app: the
mocked aiohttp.web.run_app now closes the coroutine that Api.run() hands
it, mirroring what the real run_app would do, instead of leaving it for
the GC to complain about.
This commit is contained in:
parent
feeb92b33f
commit
7f5f2f955d
17 changed files with 1415 additions and 17 deletions
73
tests/contract/test_schema_field_contracts.py
Normal file
73
tests/contract/test_schema_field_contracts.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""
|
||||
Contract tests for schema dataclass field sets.
|
||||
|
||||
These pin the *field names* of small, widely-constructed schema dataclasses
|
||||
so that any rename, removal, or accidental addition fails CI loudly instead
|
||||
of waiting for a runtime TypeError on the next websocket message.
|
||||
|
||||
Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
|
||||
field but several call sites kept passing `Metadata(metadata=...)`. The bug
|
||||
was only discovered when a websocket import dispatcher received its first
|
||||
real message in production. A trivial structural assertion of the kind
|
||||
below would have caught it at unit-test time.
|
||||
|
||||
Add to this file whenever a schema rename burns you. The cost of a frozen
|
||||
field set is a one-line update when you intentionally evolve the schema; the
|
||||
benefit is that every call site is forced to come along for the ride.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import (
|
||||
Metadata,
|
||||
EntityContext,
|
||||
EntityEmbeddings,
|
||||
ChunkEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
def _field_names(dc):
    """Return the set of field names declared on the dataclass *dc*.

    *dc* is handed straight to ``dataclasses.fields``, so it may be a
    dataclass type or an instance of one.
    """
    return {field.name for field in dataclasses.fields(dc)}
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestSchemaFieldContracts:
    """Pin the field set of dataclasses that get constructed all over the
    codebase. If you intentionally change one of these, update the
    expected set in the same commit — that diff will surface every call
    site that needs to come along."""

    def test_metadata_fields(self):
        """Metadata exposes exactly id / root / user / collection."""
        # NOTE: there is no `metadata` field. A previous regression
        # constructed Metadata(metadata=...) and crashed at runtime.
        assert _field_names(Metadata) == {
            "id",
            "root",
            "user",
            "collection",
        }

    def test_entity_embeddings_fields(self):
        """EntityEmbeddings exposes exactly entity / vector / chunk_id."""
        # NOTE: the embedding field is `vector` (singular, list[float]).
        # There is no `vectors` field. Several call sites historically
        # passed `vectors=` and crashed at runtime.
        assert _field_names(EntityEmbeddings) == {
            "entity",
            "vector",
            "chunk_id",
        }

    def test_chunk_embeddings_fields(self):
        """ChunkEmbeddings exposes exactly chunk_id / vector."""
        # Same `vector` (singular) convention as EntityEmbeddings.
        assert _field_names(ChunkEmbeddings) == {
            "chunk_id",
            "vector",
        }

    def test_entity_context_fields(self):
        """EntityContext exposes exactly entity / context / chunk_id."""
        assert _field_names(EntityContext) == {
            "entity",
            "context",
            "chunk_id",
        }
|
||||
418
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
418
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
"""
|
||||
Round-trip unit tests for the core msgpack import/export gateway endpoints.
|
||||
|
||||
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
|
||||
the responder callback and packs them into msgpack tuples. The kg-core
|
||||
import endpoint takes msgpack tuples back off the wire and rebuilds
|
||||
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
|
||||
(whose translator decodes them into real dataclasses).
|
||||
|
||||
Regression coverage: the previous wire format used `"vectors"` (plural)
|
||||
in the entity blobs and embedded a stale `"m"` field that referenced the
|
||||
removed `Metadata.metadata` triples-list field. The export side hit a
|
||||
KeyError on first message; the import side built dicts that the
|
||||
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
|
||||
tests pin both halves of the wire protocol.
|
||||
"""
|
||||
|
||||
import msgpack
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.core_export import CoreExport
|
||||
from trustgraph.gateway.dispatch.core_import import CoreImport
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
|
||||
# would emit). The vector wire key is *singular* on purpose; the export
|
||||
# side previously read the wrong key and crashed.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ge_response_dict():
    """Return a fresh sample graph-embeddings response dict.

    The entity blobs use the *singular* ``"vector"`` key. A new dict is
    built on every call, so callers may mutate the result freely.
    """
    return {
        "graph-embeddings": {
            "metadata": {
                "id": "doc-1",
                "root": "",
                "user": "alice",
                "collection": "testcoll",
            },
            "entities": [
                {
                    "entity": {"t": "i", "i": "http://example.org/alice"},
                    "vector": [0.1, 0.2, 0.3],
                },
                {
                    "entity": {"t": "i", "i": "http://example.org/bob"},
                    "vector": [0.4, 0.5, 0.6],
                },
            ],
        }
    }
|
||||
|
||||
|
||||
def _triples_response_dict():
    """Return a fresh sample triples response dict holding a single triple.

    A new dict is built on every call, so callers may mutate the result
    freely.
    """
    return {
        "triples": {
            "metadata": {
                "id": "doc-1",
                "root": "",
                "user": "alice",
                "collection": "testcoll",
            },
            "triples": [
                {
                    "s": {"t": "i", "i": "http://example.org/alice"},
                    "p": {"t": "i", "i": "http://example.org/knows"},
                    "o": {"t": "i", "i": "http://example.org/bob"},
                },
            ],
        }
    }
|
||||
|
||||
|
||||
def _make_request(id_="doc-1", user="alice"):
    """Build a ``Mock`` request whose ``query`` mapping carries ``id`` and ``user``.

    Only the ``query`` attribute is set explicitly; every other attribute
    is whatever ``Mock`` auto-creates.
    """
    request = Mock()
    request.query = {"id": id_, "user": user}
    return request
|
||||
|
||||
|
||||
def _make_data_reader(payload: bytes):
    """Mock the aiohttp StreamReader: returns payload once, then EOF.

    The returned ``Mock`` has an async ``read(n)`` that yields *payload* on
    the first call and ``b""`` on every later call.
    """
    pending = iter((payload,))

    async def fake_read(n):
        # `n` (the requested size) is accepted for call compatibility but
        # ignored: the whole payload is delivered in one read.
        return next(pending, b"")

    data = Mock()
    data.read = fake_read
    return data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export side: translator-shaped dict -> msgpack bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreExportWireFormat:
    """Pin the msgpack tuples CoreExport writes for translator-shaped dicts."""

    @staticmethod
    async def _run_export(mock_kr_class, response_dict):
        """Drive ``CoreExport.process`` with a mocked KnowledgeRequestor.

        The fake requestor hands *response_dict* to the responder once,
        flagged as final. Returns ``(error, captured)``: the error
        callback mock and the list of byte chunks passed to
        ``response.write``.
        """
        captured = []

        async def fake_kr_process(req_dict, responder):
            await responder(response_dict, True)

        mock_kr = AsyncMock()
        mock_kr.start = AsyncMock()
        mock_kr.stop = AsyncMock()
        mock_kr.process = fake_kr_process
        mock_kr_class.return_value = mock_kr

        async def fake_write(b):
            captured.append(b)

        response = AsyncMock()
        response.write = fake_write
        response.write_eof = AsyncMock()

        ok = AsyncMock(return_value=response)
        error = AsyncMock()

        exporter = CoreExport(backend=Mock())
        await exporter.process(
            data=Mock(),
            error=error,
            ok=ok,
            request=_make_request(),
        )
        return error, captured

    @staticmethod
    def _unpack_single(blob):
        """Unpack *blob* and return its one msgpack item, asserting there is exactly one."""
        unpacker = msgpack.Unpacker()
        unpacker.feed(blob)
        items = list(unpacker)
        assert len(items) == 1
        return items[0]

    @pytest.mark.asyncio
    @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
    async def test_export_packs_graph_embeddings_with_singular_vector(
        self, mock_kr_class,
    ):
        """The export side must read `ent["vector"]` and emit `v`. The
        previous bug was reading `ent["vectors"]` which KeyErrored against
        the translator output."""
        error, captured = await self._run_export(
            mock_kr_class, _ge_response_dict(),
        )

        # Did not raise, did not call error()
        error.assert_not_called()
        assert len(captured) == 1

        msg_type, payload = self._unpack_single(captured[0])
        assert msg_type == "ge"

        # Metadata envelope: only id/user/collection — no stale `m["m"]`.
        assert payload["m"] == {
            "i": "doc-1",
            "u": "alice",
            "c": "testcoll",
        }

        # Entities: each carries the *singular* `v` and the term envelope
        assert len(payload["e"]) == 2
        assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
        assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
        assert payload["e"][0]["e"]["i"] == "http://example.org/alice"

    @pytest.mark.asyncio
    @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
    async def test_export_packs_triples(self, mock_kr_class):
        """A triples response is packed as a single `t` message."""
        error, captured = await self._run_export(
            mock_kr_class, _triples_response_dict(),
        )

        error.assert_not_called()
        assert len(captured) == 1

        msg_type, payload = self._unpack_single(captured[0])
        assert msg_type == "t"
        assert payload["m"] == {
            "i": "doc-1",
            "u": "alice",
            "c": "testcoll",
        }
        assert len(payload["t"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import side: msgpack bytes -> translator-shaped dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportWireFormat:
    """Pin the request dicts CoreImport builds from kg-core msgpack tuples."""

    @staticmethod
    async def _run_import(mock_kr_class, payload):
        """Drive ``CoreImport.process`` with a mocked KnowledgeRequestor.

        *payload* is the raw msgpack byte string served by the fake data
        reader. Returns ``(error, captured)``: the error callback mock and
        the list of request dicts handed to the requestor's ``process``.
        """
        captured = []

        async def fake_kr_process(req_dict):
            captured.append(req_dict)

        mock_kr = AsyncMock()
        mock_kr.start = AsyncMock()
        mock_kr.stop = AsyncMock()
        mock_kr.process = fake_kr_process
        mock_kr_class.return_value = mock_kr

        ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
        error = AsyncMock()

        importer = CoreImport(backend=Mock())
        await importer.process(
            data=_make_data_reader(payload),
            error=error,
            ok=ok,
            request=_make_request(),
        )
        return error, captured

    @pytest.mark.asyncio
    @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
    async def test_import_unpacks_graph_embeddings_to_singular_vector(
        self, mock_kr_class,
    ):
        """The import side must build dicts whose entity blobs have the
        singular `vector` key — that's what the KnowledgeRequestTranslator
        decode side reads. Previous bug emitted `vectors`."""
        # Build a msgpack tuple matching the new wire format
        payload = msgpack.packb((
            "ge",
            {
                "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
                "e": [
                    {
                        "e": {"t": "i", "i": "http://example.org/alice"},
                        "v": [0.1, 0.2, 0.3],
                    },
                ],
            },
        ))

        error, captured = await self._run_import(mock_kr_class, payload)

        error.assert_not_called()
        assert len(captured) == 1

        req = captured[0]
        assert req["operation"] == "put-kg-core"
        assert req["user"] == "alice"
        assert req["id"] == "doc-1"

        ge = req["graph-embeddings"]
        # Metadata envelope must NOT contain a stale `metadata` key
        # referencing the removed Metadata.metadata field.
        assert "metadata" not in ge["metadata"]
        assert ge["metadata"] == {
            "id": "doc-1",
            "user": "alice",
            "collection": "default",
        }

        # Entity blob carries the singular `vector` key
        assert len(ge["entities"]) == 1
        ent = ge["entities"][0]
        assert ent["vector"] == [0.1, 0.2, 0.3]
        assert "vectors" not in ent

    @pytest.mark.asyncio
    @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
    async def test_import_unpacks_triples(self, mock_kr_class):
        """A `t` message is rebuilt into a request dict with one triple."""
        payload = msgpack.packb((
            "t",
            {
                "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
                "t": [
                    {
                        "s": {"t": "i", "i": "http://example.org/alice"},
                        "p": {"t": "i", "i": "http://example.org/knows"},
                        "o": {"t": "i", "i": "http://example.org/bob"},
                    },
                ],
            },
        ))

        error, captured = await self._run_import(mock_kr_class, payload)

        error.assert_not_called()
        assert len(captured) == 1

        req = captured[0]
        triples = req["triples"]
        assert "metadata" not in triples["metadata"]  # no stale field
        assert len(triples["triples"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full round-trip: export bytes feed directly into import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportExportRoundTrip:
    """End-to-end: produce bytes via core_export, consume them via
    core_import, and verify the dict that lands at the import-side
    translator is structurally equivalent to what went in. This is the
    test that catches asymmetries between the two halves."""

    @pytest.mark.asyncio
    @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
    @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
    async def test_graph_embeddings_round_trip(
        self, mock_export_kr_class, mock_import_kr_class,
    ):
        """Export a graph-embeddings response, re-import the bytes, compare."""
        # ----- export side: capture bytes -----
        export_bytes = []

        async def fake_export_process(req_dict, responder):
            await responder(_ge_response_dict(), True)

        export_kr = AsyncMock()
        export_kr.start = AsyncMock()
        export_kr.stop = AsyncMock()
        export_kr.process = fake_export_process
        mock_export_kr_class.return_value = export_kr

        async def fake_write(b):
            export_bytes.append(b)

        response = AsyncMock()
        response.write = fake_write
        response.write_eof = AsyncMock()

        export_error = AsyncMock()

        exporter = CoreExport(backend=Mock())
        await exporter.process(
            data=Mock(),
            error=export_error,
            ok=AsyncMock(return_value=response),
            request=_make_request(),
        )

        # Keep a handle on the error callback so a failure reported through
        # it is not silently ignored.
        export_error.assert_not_called()
        assert len(export_bytes) == 1

        # ----- import side: feed those bytes back in -----
        import_captured = []

        async def fake_import_process(req_dict):
            import_captured.append(req_dict)

        import_kr = AsyncMock()
        import_kr.start = AsyncMock()
        import_kr.stop = AsyncMock()
        import_kr.process = fake_import_process
        mock_import_kr_class.return_value = import_kr

        import_error = AsyncMock()

        importer = CoreImport(backend=Mock())
        await importer.process(
            data=_make_data_reader(export_bytes[0]),
            error=import_error,
            ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
            request=_make_request(),
        )

        import_error.assert_not_called()

        # ----- verify the dict the importer would hand to the translator -----
        assert len(import_captured) == 1
        req = import_captured[0]

        original = _ge_response_dict()["graph-embeddings"]

        ge = req["graph-embeddings"]
        # The import side overrides id/user from the URL query (intentional),
        # so we only round-trip the entity payload itself.
        assert ge["metadata"]["id"] == original["metadata"]["id"]
        assert ge["metadata"]["user"] == original["metadata"]["user"]

        assert len(ge["entities"]) == len(original["entities"])
        for got, want in zip(ge["entities"], original["entities"]):
            assert got["vector"] == want["vector"]
            assert got["entity"] == want["entity"]
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
Unit tests for entity contexts import dispatcher.
|
||||
|
||||
Tests the business logic of EntityContextsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version constructed Metadata(metadata=...)
|
||||
which raised TypeError at runtime as soon as a message was received. These
|
||||
tests exercise receive() end-to-end so any future schema/kwarg drift in
|
||||
the Metadata or EntityContexts construction is caught immediately.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
|
||||
from trustgraph.schema import EntityContexts, EntityContext, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
def mock_backend():
    """Provide a bare ``Mock`` standing in for the messaging backend."""
    return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_running():
    """Provide a ``Mock`` run-state object whose ``get()`` returns True."""
    running = Mock()
    running.get.return_value = True
    running.stop = Mock()
    return running
|
||||
|
||||
|
||||
@pytest.fixture
def mock_websocket():
    """Provide a ``Mock`` websocket with an awaitable ``close``."""
    ws = Mock()
    ws.close = AsyncMock()
    return ws
|
||||
|
||||
|
||||
@pytest.fixture
def sample_message():
    """Sample entity-contexts websocket message."""
    return {
        "metadata": {
            "id": "doc-123",
            "user": "testuser",
            "collection": "testcollection",
        },
        "entities": [
            {
                "entity": {"v": "http://example.org/alice", "e": True},
                "context": "Alice is a person.",
            },
            {
                "entity": {"v": "http://example.org/bob", "e": True},
                "context": "Bob is a person.",
            },
        ],
    }
|
||||
|
||||
|
||||
@pytest.fixture
def empty_entities_message():
    """Entity-contexts websocket message with an empty ``entities`` list."""
    return {
        "metadata": {
            "id": "doc-empty",
            "user": "u",
            "collection": "c",
        },
        "entities": [],
    }
|
||||
|
||||
|
||||
class TestEntityContextsImportInitialization:
    """Constructor wiring of EntityContextsImport."""

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    def test_init_creates_publisher_with_correct_params(
        self, mock_publisher_class, mock_backend, mock_websocket, mock_running
    ):
        """The dispatcher builds one Publisher for its queue and keeps its collaborators."""
        instance = Mock()
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=mock_websocket,
            running=mock_running,
            backend=mock_backend,
            queue="ec-queue",
        )

        mock_publisher_class.assert_called_once_with(
            mock_backend,
            topic="ec-queue",
            schema=EntityContexts,
        )
        assert dispatcher.ws is mock_websocket
        assert dispatcher.running is mock_running
        assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestEntityContextsImportLifecycle:
    """start() / destroy() behaviour of EntityContextsImport."""

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    @pytest.mark.asyncio
    async def test_start_calls_publisher_start(
        self, mock_publisher_class, mock_backend, mock_websocket, mock_running
    ):
        """start() starts the publisher exactly once."""
        instance = Mock()
        instance.start = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )
        await dispatcher.start()
        instance.start.assert_called_once()

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    @pytest.mark.asyncio
    async def test_destroy_stops_and_closes_properly(
        self, mock_publisher_class, mock_backend, mock_websocket, mock_running
    ):
        """destroy() stops the run flag and publisher and closes the websocket."""
        instance = Mock()
        instance.stop = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )
        await dispatcher.destroy()

        mock_running.stop.assert_called_once()
        instance.stop.assert_called_once()
        mock_websocket.close.assert_called_once()

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    @pytest.mark.asyncio
    async def test_destroy_handles_none_websocket(
        self, mock_publisher_class, mock_backend, mock_running
    ):
        """destroy() still stops the run flag and publisher when ws is None."""
        instance = Mock()
        instance.stop = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=None, running=mock_running,
            backend=mock_backend, queue="q",
        )
        await dispatcher.destroy()

        mock_running.stop.assert_called_once()
        instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestEntityContextsImportMessageProcessing:
    """Regression coverage for receive(): catches Metadata/schema drift."""

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_constructs_entity_contexts_correctly(
        self, mock_publisher_class, mock_backend, mock_websocket,
        mock_running, sample_message,
    ):
        """receive() publishes one EntityContexts built from the message."""
        instance = Mock()
        instance.send = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )

        mock_msg = Mock()
        mock_msg.json.return_value = sample_message

        # If Metadata or EntityContexts gain/lose kwargs, this raises
        # TypeError — exactly the regression we want to catch.
        await dispatcher.receive(mock_msg)

        instance.send.assert_called_once()
        call_args = instance.send.call_args
        assert call_args[0][0] is None

        sent = call_args[0][1]
        assert isinstance(sent, EntityContexts)
        assert isinstance(sent.metadata, Metadata)
        assert sent.metadata.id == "doc-123"
        assert sent.metadata.user == "testuser"
        assert sent.metadata.collection == "testcollection"

        assert len(sent.entities) == 2
        assert all(isinstance(e, EntityContext) for e in sent.entities)
        assert sent.entities[0].context == "Alice is a person."
        assert sent.entities[1].context == "Bob is a person."

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_handles_empty_entities(
        self, mock_publisher_class, mock_backend, mock_websocket,
        mock_running, empty_entities_message,
    ):
        """A message with no entities is still published, with an empty list."""
        instance = Mock()
        instance.send = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )

        mock_msg = Mock()
        mock_msg.json.return_value = empty_entities_message

        await dispatcher.receive(mock_msg)

        instance.send.assert_called_once()
        sent = instance.send.call_args[0][1]
        assert isinstance(sent, EntityContexts)
        assert sent.entities == []
        assert sent.metadata.id == "doc-empty"

    @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_propagates_publisher_errors(
        self, mock_publisher_class, mock_backend, mock_websocket,
        mock_running, sample_message,
    ):
        """An exception raised by publisher.send() escapes receive()."""
        instance = Mock()
        instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
        mock_publisher_class.return_value = instance

        dispatcher = EntityContextsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )

        mock_msg = Mock()
        mock_msg.json.return_value = sample_message

        with pytest.raises(RuntimeError, match="publish failed"):
            await dispatcher.receive(mock_msg)
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Unit tests for graph embeddings import dispatcher.
|
||||
|
||||
Tests the business logic of GraphEmbeddingsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version of EntityContextsImport
|
||||
constructed Metadata(metadata=...) which raised TypeError at runtime as
|
||||
soon as a message was received. The same shape of bug can occur here, so
|
||||
these tests exercise receive() end-to-end to catch any future schema or
|
||||
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
|
||||
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
def mock_backend():
    """Provide a bare ``Mock`` standing in for the messaging backend."""
    return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_running():
    """Provide a ``Mock`` run-state object whose ``get()`` returns True."""
    running = Mock()
    running.get.return_value = True
    running.stop = Mock()
    return running
|
||||
|
||||
|
||||
@pytest.fixture
def mock_websocket():
    """Provide a ``Mock`` websocket with an awaitable ``close``."""
    ws = Mock()
    ws.close = AsyncMock()
    return ws
|
||||
|
||||
|
||||
@pytest.fixture
def sample_message():
    """Sample graph-embeddings websocket message."""
    return {
        "metadata": {
            "id": "doc-123",
            "user": "testuser",
            "collection": "testcollection",
        },
        "entities": [
            {
                "entity": {"v": "http://example.org/alice", "e": True},
                "vector": [0.1, 0.2, 0.3],
            },
            {
                "entity": {"v": "http://example.org/bob", "e": True},
                "vector": [0.4, 0.5, 0.6],
            },
        ],
    }
|
||||
|
||||
|
||||
@pytest.fixture
def empty_entities_message():
    """Graph-embeddings websocket message with an empty ``entities`` list."""
    return {
        "metadata": {
            "id": "doc-empty",
            "user": "u",
            "collection": "c",
        },
        "entities": [],
    }
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportInitialization:
    """Constructor wiring of GraphEmbeddingsImport."""

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    def test_init_creates_publisher_with_correct_params(
        self, mock_publisher_class, mock_backend, mock_websocket, mock_running
    ):
        """The dispatcher builds one Publisher for its queue and keeps its collaborators."""
        instance = Mock()
        mock_publisher_class.return_value = instance

        dispatcher = GraphEmbeddingsImport(
            ws=mock_websocket,
            running=mock_running,
            backend=mock_backend,
            queue="ge-queue",
        )

        mock_publisher_class.assert_called_once_with(
            mock_backend,
            topic="ge-queue",
            schema=GraphEmbeddings,
        )
        assert dispatcher.ws is mock_websocket
        assert dispatcher.running is mock_running
        assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportLifecycle:
    """start() / destroy() behaviour of GraphEmbeddingsImport."""

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    @pytest.mark.asyncio
    async def test_start_calls_publisher_start(
        self, mock_publisher_class, mock_backend, mock_websocket, mock_running
    ):
        """start() starts the publisher exactly once."""
        instance = Mock()
        instance.start = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = GraphEmbeddingsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )
        await dispatcher.start()
        instance.start.assert_called_once()

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    @pytest.mark.asyncio
    async def test_destroy_stops_and_closes_properly(
        self, mock_publisher_class, mock_backend, mock_websocket, mock_running
    ):
        """destroy() stops the run flag and publisher and closes the websocket."""
        instance = Mock()
        instance.stop = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = GraphEmbeddingsImport(
            ws=mock_websocket, running=mock_running,
            backend=mock_backend, queue="q",
        )
        await dispatcher.destroy()

        mock_running.stop.assert_called_once()
        instance.stop.assert_called_once()
        mock_websocket.close.assert_called_once()

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    @pytest.mark.asyncio
    async def test_destroy_handles_none_websocket(
        self, mock_publisher_class, mock_backend, mock_running
    ):
        """destroy() still stops the run flag and publisher when ws is None."""
        instance = Mock()
        instance.stop = AsyncMock()
        mock_publisher_class.return_value = instance

        dispatcher = GraphEmbeddingsImport(
            ws=None, running=mock_running,
            backend=mock_backend, queue="q",
        )
        await dispatcher.destroy()

        mock_running.stop.assert_called_once()
        instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportMessageProcessing:
    """Regression tests for receive(): guards against Metadata/schema drift."""

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_constructs_graph_embeddings_correctly(
        self, mock_publisher_class, mock_backend, mock_websocket,
        mock_running, sample_message,
    ):
        """receive() publishes a GraphEmbeddings built from the JSON payload."""
        publisher = Mock(send=AsyncMock())
        mock_publisher_class.return_value = publisher

        importer = GraphEmbeddingsImport(
            ws=mock_websocket,
            running=mock_running,
            backend=mock_backend,
            queue="q",
        )

        incoming = Mock()
        incoming.json.return_value = sample_message

        # A kwarg added to or removed from Metadata, GraphEmbeddings or
        # EntityEmbeddings surfaces here as a TypeError, which is precisely
        # the regression this test exists to catch.
        await importer.receive(incoming)

        publisher.send.assert_called_once()
        positional = publisher.send.call_args.args
        assert positional[0] is None

        sent = positional[1]
        assert isinstance(sent, GraphEmbeddings)
        assert isinstance(sent.metadata, Metadata)
        expected_metadata = {
            "id": "doc-123",
            "user": "testuser",
            "collection": "testcollection",
        }
        for field, expected in expected_metadata.items():
            assert getattr(sent.metadata, field) == expected

        assert len(sent.entities) == 2
        for entity in sent.entities:
            assert isinstance(entity, EntityEmbeddings)

        # Pin the wire format: the incoming singular "vector" key
        # (list[float]) lands in EntityEmbeddings.vector, mirroring
        # serialize_graph_embeddings() on the export side.
        expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        for position, expected in enumerate(expected_vectors):
            assert sent.entities[position].vector == expected

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_handles_empty_entities(
        self, mock_publisher_class, mock_backend, mock_websocket,
        mock_running, empty_entities_message,
    ):
        """A payload without entities is still published, with an empty list."""
        publisher = Mock(send=AsyncMock())
        mock_publisher_class.return_value = publisher

        importer = GraphEmbeddingsImport(
            ws=mock_websocket,
            running=mock_running,
            backend=mock_backend,
            queue="q",
        )

        incoming = Mock()
        incoming.json.return_value = empty_entities_message

        await importer.receive(incoming)

        publisher.send.assert_called_once()
        sent = publisher.send.call_args.args[1]
        assert isinstance(sent, GraphEmbeddings)
        assert sent.entities == []
        assert sent.metadata.id == "doc-empty"

    @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
    @pytest.mark.asyncio
    async def test_receive_propagates_publisher_errors(
        self, mock_publisher_class, mock_backend, mock_websocket,
        mock_running, sample_message,
    ):
        """An exception raised by publisher.send escapes from receive()."""
        failure = RuntimeError("publish failed")
        publisher = Mock(send=AsyncMock(side_effect=failure))
        mock_publisher_class.return_value = publisher

        importer = GraphEmbeddingsImport(
            ws=mock_websocket,
            running=mock_running,
            backend=mock_backend,
            queue="q",
        )

        incoming = Mock()
        incoming.json.return_value = sample_message

        with pytest.raises(RuntimeError, match="publish failed"):
            await importer.receive(incoming)
|
||||
|
|
@ -171,6 +171,14 @@ class TestApi:
|
|||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
# Api.run() passes self.app_factory() — a coroutine — to
|
||||
# web.run_app, which would normally consume it inside its own
|
||||
# event loop. Since we mock run_app, close the coroutine here
|
||||
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
|
||||
def _consume_coro(coro, **kwargs):
|
||||
coro.close()
|
||||
mock_run_app.side_effect = _consume_coro
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
|
|
|
|||
0
tests/unit/test_tables/__init__.py
Normal file
0
tests/unit/test_tables/__init__.py
Normal file
197
tests/unit/test_tables/test_knowledge_table_store.py
Normal file
197
tests/unit/test_tables/test_knowledge_table_store.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""
|
||||
Unit tests for KnowledgeTableStore row deserialization.
|
||||
|
||||
Regression coverage: a previous version of get_graph_embeddings constructed
|
||||
EntityEmbeddings(vectors=ent[1]) — the schema field is `vector` (singular),
|
||||
so any real Cassandra row would crash on read. These tests bypass the live
|
||||
Cassandra connection entirely and exercise the row -> schema conversion
|
||||
with hand-built fake rows.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from trustgraph.tables.knowledge import KnowledgeTableStore
|
||||
from trustgraph.schema import (
|
||||
EntityEmbeddings,
|
||||
GraphEmbeddings,
|
||||
Triples,
|
||||
Triple,
|
||||
Metadata,
|
||||
IRI,
|
||||
LITERAL,
|
||||
)
|
||||
|
||||
|
||||
def _make_store():
    """
    Allocate a KnowledgeTableStore while skipping __init__, which would
    otherwise open a Cassandra connection. Callers attach just the
    attributes that the method under test reads.
    """
    store_cls = KnowledgeTableStore
    return store_cls.__new__(store_cls)
|
||||
|
||||
|
||||
class TestGetGraphEmbeddings:
    """Row -> GraphEmbeddings conversion performed by get_graph_embeddings."""

    @staticmethod
    def _store_with_rows(rows):
        """Return a store whose mocked cassandra.execute returns *rows*."""
        store = _make_store()
        store.cassandra = Mock()
        store.cassandra.execute = Mock(return_value=rows)
        store.get_graph_embeddings_stmt = Mock()
        return store

    @staticmethod
    def _collector():
        """Return (messages, receiver); receiver appends each message it gets."""
        messages = []

        async def receiver(msg):
            messages.append(msg)

        return messages, receiver

    @pytest.mark.asyncio
    async def test_row_converts_to_entity_embeddings_with_singular_vector(self):
        """
        row[3] holds the entities as ((value, is_uri), vector) pairs. Each
        pair must become an EntityEmbeddings whose embedding sits in the
        singular `vector` field, which is the schema field name. An earlier
        revision passed `vectors=` and raised TypeError at runtime.
        """
        # Arrange: a fake row for the get_graph_embeddings_stmt query. The
        # first three columns are left as None; row[3] is the entities blob.
        entity_pairs = [
            # ((value, is_uri), vector)
            (("http://example.org/alice", True), [0.1, 0.2, 0.3]),
            (("http://example.org/bob", True), [0.4, 0.5, 0.6]),
            (("a literal entity", False), [0.7, 0.8, 0.9]),
        ]
        store = self._store_with_rows([(None, None, None, entity_pairs)])
        received, receiver = self._collector()

        # Act
        await store.get_graph_embeddings(
            user="alice",
            document_id="doc-1",
            receiver=receiver,
        )

        # Assert
        store.cassandra.execute.assert_called_once_with(
            store.get_graph_embeddings_stmt,
            ("alice", "doc-1"),
        )

        assert len(received) == 1
        ge = received[0]
        assert isinstance(ge, GraphEmbeddings)
        assert isinstance(ge.metadata, Metadata)
        assert ge.metadata.id == "doc-1"
        assert ge.metadata.user == "alice"

        assert len(ge.entities) == 3
        for embedded in ge.entities:
            assert isinstance(embedded, EntityEmbeddings)

        # The explicit regression check for the original bug: embeddings
        # are exposed through the singular `vector` field.
        for embedded, (_, expected_vector) in zip(ge.entities, entity_pairs):
            assert embedded.vector == expected_vector

        # Term type survives the trip through tuple_to_term.
        expected_terms = [
            (IRI, "iri", "http://example.org/alice"),
            (IRI, "iri", "http://example.org/bob"),
            (LITERAL, "value", "a literal entity"),
        ]
        for embedded, (term_type, attr, text) in zip(ge.entities, expected_terms):
            assert embedded.entity.type == term_type
            assert getattr(embedded.entity, attr) == text

    @pytest.mark.asyncio
    async def test_empty_entities_blob_yields_empty_list(self):
        """A row whose row[3] is None must yield a GraphEmbeddings with no
        entities instead of raising."""
        store = self._store_with_rows([(None, None, None, None)])
        received, receiver = self._collector()

        await store.get_graph_embeddings("u", "d", receiver)

        assert len(received) == 1
        assert received[0].entities == []

    @pytest.mark.asyncio
    async def test_multiple_rows_each_emit_one_message(self):
        """Each returned row results in exactly one message to the receiver."""
        expected = [
            ("http://example.org/a", [1.0]),
            ("http://example.org/b", [2.0]),
        ]
        rows = [
            (None, None, None, [((uri, True), vector)])
            for uri, vector in expected
        ]
        store = self._store_with_rows(rows)
        received, receiver = self._collector()

        await store.get_graph_embeddings("u", "d", receiver)

        assert len(received) == 2
        for message, (uri, vector) in zip(received, expected):
            first = message.entities[0]
            assert first.entity.iri == uri
            assert first.vector == vector
|
||||
|
||||
|
||||
class TestGetTriples:
    """Parity coverage for the sibling get_triples path, which reads the
    same row[3] column and builds Metadata the same way."""

    @pytest.mark.asyncio
    async def test_row_converts_to_triples(self):
        """A row with one flattened triple yields one Triples message."""
        # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
        flattened_triple = (
            "http://example.org/alice", True,
            "http://example.org/knows", True,
            "http://example.org/bob", True,
        )
        rows = [(None, None, None, [flattened_triple])]

        store = _make_store()
        store.cassandra = Mock()
        store.cassandra.execute = Mock(return_value=rows)
        store.get_triples_stmt = Mock()

        collected = []

        async def receiver(msg):
            collected.append(msg)

        await store.get_triples("alice", "doc-1", receiver)

        assert len(collected) == 1
        message = collected[0]
        assert isinstance(message, Triples)
        assert isinstance(message.metadata, Metadata)
        assert message.metadata.id == "doc-1"
        assert message.metadata.user == "alice"

        assert len(message.triples) == 1
        triple = message.triples[0]
        assert isinstance(triple, Triple)
        expected_iris = {
            "s": "http://example.org/alice",
            "p": "http://example.org/knows",
            "o": "http://example.org/bob",
        }
        for position, iri in expected_iris.items():
            assert getattr(triple, position).iri == iri
|
||||
0
tests/unit/test_translators/__init__.py
Normal file
0
tests/unit/test_translators/__init__.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""
|
||||
Round-trip unit tests for DocumentEmbeddingsTranslator.
|
||||
|
||||
Regression coverage: a previous version of the decode side constructed
|
||||
ChunkEmbeddings(vectors=...) — the schema field is `vector` (singular),
|
||||
so any real DocumentEmbeddings message would crash on decode. The encode
|
||||
side already wrote `"vector"`, so encode→decode was asymmetric.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.messaging.translators.document_loading import (
|
||||
DocumentEmbeddingsTranslator,
|
||||
)
|
||||
from trustgraph.schema import (
|
||||
DocumentEmbeddings,
|
||||
ChunkEmbeddings,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def translator():
    """Provide a fresh DocumentEmbeddingsTranslator for each test."""
    instance = DocumentEmbeddingsTranslator()
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def sample():
    """Provide a DocumentEmbeddings with metadata and two embedded chunks."""
    metadata = Metadata(
        id="doc-1",
        root="",
        user="alice",
        collection="testcoll",
    )
    chunk_vectors = [
        ("c1", [0.1, 0.2, 0.3]),
        ("c2", [0.4, 0.5, 0.6]),
    ]
    chunks = [
        ChunkEmbeddings(chunk_id=chunk_id, vector=vector)
        for chunk_id, vector in chunk_vectors
    ]
    return DocumentEmbeddings(metadata=metadata, chunks=chunks)
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsTranslator:
    """Encode and encode->decode checks for DocumentEmbeddingsTranslator."""

    def test_encode_uses_singular_vector_key(self, translator, sample):
        """Encoded chunks carry a `vector` key and never a `vectors` key."""
        chunks = translator.encode(sample)["chunks"]

        for chunk in chunks:
            assert "vector" in chunk
        for chunk in chunks:
            assert "vectors" not in chunk
        assert chunks[0]["vector"] == [0.1, 0.2, 0.3]

    def test_roundtrip_preserves_document_embeddings(self, translator, sample):
        """Decoding the encoded sample restores metadata and chunk contents."""
        decoded = translator.decode(translator.encode(sample))

        assert isinstance(decoded, DocumentEmbeddings)
        assert isinstance(decoded.metadata, Metadata)
        expected_metadata = {
            "id": "doc-1",
            "user": "alice",
            "collection": "testcoll",
        }
        for field, expected in expected_metadata.items():
            assert getattr(decoded.metadata, field) == expected

        assert len(decoded.chunks) == 2
        expected_chunks = [
            ("c1", [0.1, 0.2, 0.3]),
            ("c2", [0.4, 0.5, 0.6]),
        ]
        for position, (chunk_id, vector) in enumerate(expected_chunks):
            assert decoded.chunks[position].chunk_id == chunk_id
            assert decoded.chunks[position].vector == vector
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
"""
|
||||
Round-trip unit tests for KnowledgeRequestTranslator.
|
||||
|
||||
Regression coverage: a previous version of the decode side constructed
|
||||
EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
|
||||
so any real graph-embeddings KnowledgeRequest would crash on first
|
||||
message. The encode side already wrote `"vector"`, so encode→decode was
|
||||
asymmetric.
|
||||
|
||||
These tests build a real KnowledgeRequest with graph-embeddings, encode
|
||||
it, decode the result, and assert the round-trip is lossless. They also
|
||||
exercise the triples path so any future schema drift in Metadata or
|
||||
Triples breaks the test.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
|
||||
from trustgraph.schema import (
|
||||
KnowledgeRequest,
|
||||
GraphEmbeddings,
|
||||
EntityEmbeddings,
|
||||
Triples,
|
||||
Triple,
|
||||
Metadata,
|
||||
Term,
|
||||
IRI,
|
||||
)
|
||||
|
||||
|
||||
def _term_iri(uri):
    """Build a Term of type IRI whose `iri` field is *uri*."""
    term = Term(iri=uri, type=IRI)
    return term
|
||||
|
||||
|
||||
@pytest.fixture
def translator():
    """Provide a fresh KnowledgeRequestTranslator for each test."""
    instance = KnowledgeRequestTranslator()
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def graph_embeddings_request():
    """Provide a put-kg-core KnowledgeRequest carrying two entity embeddings."""
    metadata = Metadata(
        id="doc-1",
        root="",
        user="alice",
        collection="testcoll",
    )
    entity_vectors = [
        ("http://example.org/alice", [0.1, 0.2, 0.3]),
        ("http://example.org/bob", [0.4, 0.5, 0.6]),
    ]
    entities = [
        EntityEmbeddings(entity=_term_iri(uri), vector=vector)
        for uri, vector in entity_vectors
    ]
    embeddings = GraphEmbeddings(metadata=metadata, entities=entities)

    return KnowledgeRequest(
        operation="put-kg-core",
        user="alice",
        id="doc-1",
        flow="default",
        collection="testcoll",
        graph_embeddings=embeddings,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def triples_request():
    """Provide a put-kg-core KnowledgeRequest carrying a single triple."""
    metadata = Metadata(
        id="doc-1",
        root="",
        user="alice",
        collection="testcoll",
    )
    knows = Triple(
        s=_term_iri("http://example.org/alice"),
        p=_term_iri("http://example.org/knows"),
        o=_term_iri("http://example.org/bob"),
    )
    payload = Triples(metadata=metadata, triples=[knows])

    return KnowledgeRequest(
        operation="put-kg-core",
        user="alice",
        id="doc-1",
        flow="default",
        collection="testcoll",
        triples=payload,
    )
|
||||
|
||||
|
||||
class TestKnowledgeRequestTranslatorGraphEmbeddings:
    """Graph-embeddings branch of KnowledgeRequestTranslator."""

    def test_encode_produces_singular_vector_key(
        self, translator, graph_embeddings_request,
    ):
        """Encoded entities use the `vector` wire key, never `vectors`."""
        wire = translator.encode(graph_embeddings_request)
        entities = wire["graph-embeddings"]["entities"]

        for entity in entities:
            assert "vector" in entity
        for entity in entities:
            assert "vectors" not in entity
        assert entities[0]["vector"] == [0.1, 0.2, 0.3]

    def test_roundtrip_preserves_graph_embeddings(
        self, translator, graph_embeddings_request,
    ):
        """Decoding the encoded request restores every graph-embeddings field."""
        decoded = translator.decode(translator.encode(graph_embeddings_request))

        assert isinstance(decoded, KnowledgeRequest)
        expected_request = {
            "operation": "put-kg-core",
            "user": "alice",
            "id": "doc-1",
            "flow": "default",
            "collection": "testcoll",
        }
        for field, expected in expected_request.items():
            assert getattr(decoded, field) == expected

        assert decoded.graph_embeddings is not None
        ge = decoded.graph_embeddings
        assert isinstance(ge, GraphEmbeddings)
        assert isinstance(ge.metadata, Metadata)
        expected_metadata = {
            "id": "doc-1",
            "user": "alice",
            "collection": "testcoll",
        }
        for field, expected in expected_metadata.items():
            assert getattr(ge.metadata, field) == expected

        assert len(ge.entities) == 2
        expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        for position, vector in enumerate(expected_vectors):
            assert ge.entities[position].vector == vector
        expected_iris = ["http://example.org/alice", "http://example.org/bob"]
        for position, iri in enumerate(expected_iris):
            assert ge.entities[position].entity.iri == iri
|
||||
|
||||
|
||||
class TestKnowledgeRequestTranslatorTriples:
    """Triples branch of KnowledgeRequestTranslator."""

    def test_roundtrip_preserves_triples(self, translator, triples_request):
        """Decoding the encoded request restores the metadata and the triple."""
        decoded = translator.decode(translator.encode(triples_request))

        assert isinstance(decoded, KnowledgeRequest)
        assert decoded.triples is not None
        restored = decoded.triples
        assert isinstance(restored.metadata, Metadata)
        expected_metadata = {
            "id": "doc-1",
            "user": "alice",
            "collection": "testcoll",
        }
        for field, expected in expected_metadata.items():
            assert getattr(restored.metadata, field) == expected

        assert len(restored.triples) == 1
        triple = restored.triples[0]
        expected_iris = {
            "s": "http://example.org/alice",
            "p": "http://example.org/knows",
            "o": "http://example.org/bob",
        }
        for position, iri in expected_iris.items():
            assert getattr(triple, position).iri == iri
|
||||
|
|
@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
|
|||
chunks = [
|
||||
ChunkEmbeddings(
|
||||
chunk_id=chunk["chunk_id"],
|
||||
vectors=chunk["vectors"]
|
||||
vector=chunk["vector"]
|
||||
)
|
||||
for chunk in data.get("chunks", [])
|
||||
]
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
|||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=self.value_translator.decode(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
vector=ent["vector"],
|
||||
)
|
||||
for ent in data["graph-embeddings"]["entities"]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -41,14 +41,13 @@ class CoreExport:
|
|||
{
|
||||
"m": {
|
||||
"i": data["metadata"]["id"],
|
||||
"m": data["metadata"]["metadata"],
|
||||
"u": data["metadata"]["user"],
|
||||
"c": data["metadata"]["collection"],
|
||||
},
|
||||
"e": [
|
||||
{
|
||||
"e": ent["entity"],
|
||||
"v": ent["vectors"],
|
||||
"v": ent["vector"],
|
||||
}
|
||||
for ent in data["entities"]
|
||||
]
|
||||
|
|
@ -66,7 +65,6 @@ class CoreExport:
|
|||
{
|
||||
"m": {
|
||||
"i": data["metadata"]["id"],
|
||||
"m": data["metadata"]["metadata"],
|
||||
"u": data["metadata"]["user"],
|
||||
"c": data["metadata"]["collection"],
|
||||
},
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ class CoreImport:
|
|||
"triples": {
|
||||
"metadata": {
|
||||
"id": id,
|
||||
"metadata": msg["m"]["m"],
|
||||
"user": user,
|
||||
"collection": "default", # Not used?
|
||||
},
|
||||
|
|
@ -67,14 +66,13 @@ class CoreImport:
|
|||
"graph-embeddings": {
|
||||
"metadata": {
|
||||
"id": id,
|
||||
"metadata": msg["m"]["m"],
|
||||
"user": user,
|
||||
"collection": "default", # Not used?
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": ent["e"],
|
||||
"vectors": ent["v"],
|
||||
"vector": ent["v"],
|
||||
}
|
||||
for ent in msg["e"]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from ... schema import Metadata
|
|||
from ... schema import EntityContexts, EntityContext
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph, to_value
|
||||
from . serialize import to_value
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -48,7 +48,6 @@ class EntityContextsImport:
|
|||
elt = EntityContexts(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from ... schema import Metadata
|
|||
from ... schema import GraphEmbeddings, EntityEmbeddings
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph, to_value
|
||||
from . serialize import to_value
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -48,14 +48,13 @@ class GraphEmbeddingsImport:
|
|||
elt = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=to_value(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
vector=ent["vector"],
|
||||
)
|
||||
for ent in data["entities"]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -443,7 +443,7 @@ class KnowledgeTableStore:
|
|||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity = tuple_to_term(ent[0][0], ent[0][1]),
|
||||
vectors = ent[1]
|
||||
vector = ent[1]
|
||||
)
|
||||
for ent in row[3]
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue