Merge remote-tracking branch 'origin/master' into ts-port-effect-v4

This commit is contained in:
elpresidank 2026-05-30 09:59:12 -05:00
commit 92dae8c374
117 changed files with 7392 additions and 3410 deletions

View file

@ -0,0 +1,296 @@
"""
Tests for the Library API wrapper round-trip behavior.
Covers the get_documents update_document path and edge cases
from issue #893.
"""
import datetime
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.api.library import Library, to_value, from_value
from trustgraph.api.types import DocumentMetadata, Triple
from trustgraph.knowledge import Uri, Literal
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_library(response=None):
api = MagicMock()
api.workspace = "default"
api.request.return_value = response or {}
lib = Library(api)
return lib, api
def _wire_triple(s_iri, p_iri, o_val):
return {
"s": {"t": "i", "i": s_iri},
"p": {"t": "i", "i": p_iri},
"o": {"t": "l", "v": o_val},
}
def _doc_wire(id="doc-1", time=1700000000, title="Test Doc",
kind="text/plain", comments="", tags=None,
metadata=None, parent_id="", document_type="source",
include_title=True):
doc = {
"id": id,
"time": time,
"kind": kind,
"comments": comments,
"metadata": metadata or [],
"tags": tags or [],
"parent-id": parent_id,
"document-type": document_type,
}
if include_title:
doc["title"] = title
return doc
# ---------------------------------------------------------------------------
# Bug 1: get_documents tolerates missing title
# ---------------------------------------------------------------------------
class TestGetDocumentsMissingTitle:
def test_missing_title_defaults_to_empty(self):
doc = _doc_wire(include_title=False)
lib, api = _make_library({"document-metadatas": [doc]})
result = lib.get_documents()
assert len(result) == 1
assert result[0].title == ""
def test_present_title_preserved(self):
doc = _doc_wire(title="My Title")
lib, api = _make_library({"document-metadatas": [doc]})
result = lib.get_documents()
assert result[0].title == "My Title"
# ---------------------------------------------------------------------------
# Bug 2: update_document handles Triple objects (attribute access)
# ---------------------------------------------------------------------------
class TestUpdateDocumentTripleAccess:
def test_triple_objects_serialized_correctly(self):
lib, api = _make_library({})
metadata = DocumentMetadata(
id="doc-1",
time=datetime.datetime.fromtimestamp(1700000000),
kind="text/plain",
title="Test",
comments="",
metadata=[
Triple(
s=Uri("http://example.org/entity/alice"),
p=Uri("http://example.org/rel/knows"),
o=Literal("Bob"),
),
],
tags=["test"],
)
lib.update_document(id="doc-1", metadata=metadata)
call_args = api.request.call_args[0][1]
triples = call_args["document-metadata"]["metadata"]
assert len(triples) == 1
assert triples[0]["s"]["i"] == "http://example.org/entity/alice"
assert triples[0]["p"]["i"] == "http://example.org/rel/knows"
assert triples[0]["o"]["v"] == "Bob"
def test_empty_metadata_list(self):
lib, api = _make_library({})
metadata = DocumentMetadata(
id="doc-1",
time=datetime.datetime.fromtimestamp(1700000000),
kind="text/plain",
title="Test",
comments="",
metadata=[],
tags=[],
)
lib.update_document(id="doc-1", metadata=metadata)
call_args = api.request.call_args[0][1]
assert call_args["document-metadata"]["metadata"] == []
# ---------------------------------------------------------------------------
# Bug 3: update_document serializes datetime to int seconds
# ---------------------------------------------------------------------------
class TestUpdateDocumentTimeSerialization:
def test_datetime_serialized_to_int(self):
lib, api = _make_library({})
ts = 1700000000
metadata = DocumentMetadata(
id="doc-1",
time=datetime.datetime.fromtimestamp(ts),
kind="text/plain",
title="Test",
comments="",
metadata=[],
tags=[],
)
lib.update_document(id="doc-1", metadata=metadata)
call_args = api.request.call_args[0][1]
wire_time = call_args["document-metadata"]["time"]
assert isinstance(wire_time, int)
assert wire_time == ts
def test_int_time_passed_through(self):
lib, api = _make_library({})
metadata = DocumentMetadata(
id="doc-1",
time=1700000000,
kind="text/plain",
title="Test",
comments="",
metadata=[],
tags=[],
)
lib.update_document(id="doc-1", metadata=metadata)
call_args = api.request.call_args[0][1]
assert call_args["document-metadata"]["time"] == 1700000000
# ---------------------------------------------------------------------------
# Bug 4: update_document handles empty server response
# ---------------------------------------------------------------------------
class TestUpdateDocumentEmptyResponse:
def test_empty_response_returns_input_metadata(self):
lib, api = _make_library({})
metadata = DocumentMetadata(
id="doc-1",
time=datetime.datetime.fromtimestamp(1700000000),
kind="text/plain",
title="Updated Title",
comments="notes",
metadata=[],
tags=["a"],
)
result = lib.update_document(id="doc-1", metadata=metadata)
assert result is metadata
def test_full_response_parsed(self):
response_doc = _doc_wire(
id="doc-1", title="Server Title", tags=["b"],
)
lib, api = _make_library({"document-metadata": response_doc})
metadata = DocumentMetadata(
id="doc-1",
time=datetime.datetime.fromtimestamp(1700000000),
kind="text/plain",
title="Client Title",
comments="",
metadata=[],
tags=["a"],
)
result = lib.update_document(id="doc-1", metadata=metadata)
assert result.title == "Server Title"
assert result.tags == ["b"]
# ---------------------------------------------------------------------------
# Bug 5: update_document sends both id and document-id
# ---------------------------------------------------------------------------
class TestUpdateDocumentIdKeys:
def test_both_id_keys_sent(self):
lib, api = _make_library({})
metadata = DocumentMetadata(
id="doc-1",
time=datetime.datetime.fromtimestamp(1700000000),
kind="text/plain",
title="Test",
comments="",
metadata=[],
tags=[],
)
lib.update_document(id="doc-1", metadata=metadata)
call_args = api.request.call_args[0][1]
doc_meta = call_args["document-metadata"]
assert doc_meta["id"] == "doc-1"
assert doc_meta["document-id"] == "doc-1"
# ---------------------------------------------------------------------------
# Round-trip: get_documents → update_document
# ---------------------------------------------------------------------------
class TestGetUpdateRoundTrip:
def test_full_round_trip(self):
wire_doc = _doc_wire(
id="doc-42",
title="Original",
tags=["v1"],
metadata=[_wire_triple(
"http://example.org/e/1",
"http://example.org/r/type",
"report",
)],
)
lib, api = _make_library({"document-metadatas": [wire_doc]})
docs = lib.get_documents()
assert len(docs) == 1
doc = docs[0]
doc.title = "Updated"
doc.tags.append("v2")
# Server returns empty on update
api.request.return_value = {}
result = lib.update_document(id=doc.id, metadata=doc)
# Should not raise, should return the input metadata
assert result.title == "Updated"
assert "v2" in result.tags
# Verify the wire format sent
call_args = api.request.call_args[0][1]
doc_meta = call_args["document-metadata"]
assert doc_meta["id"] == "doc-42"
assert doc_meta["title"] == "Updated"
assert isinstance(doc_meta["time"], int)
assert len(doc_meta["metadata"]) == 1
assert doc_meta["metadata"][0]["o"]["v"] == "report"

View file

@ -272,23 +272,22 @@ class TestMetricsIntegration:
class TestPollTimeout:
@pytest.mark.asyncio
async def test_poll_timeout_is_100ms(self):
"""Consumer receive timeout should be 100ms, not the original 2000ms.
async def test_poll_timeout_is_2000ms(self):
"""Consumer receive timeout should be 2000ms.
A 2000ms poll timeout means every service adds up to 2s of idle
blocking between message bursts. With many sequential hops in a
query pipeline, this compounds into seconds of unnecessary latency.
100ms keeps responsiveness high without significant CPU overhead.
receive() is a blocking call that returns immediately when a
message arrives the timeout only governs how often the loop
checks the shutdown flag during idle periods. Lower values
(e.g. 100ms) generate excessive C++ client WARN logging with
no latency benefit.
"""
consumer = _make_consumer()
# Wire up a mock Pulsar consumer that records the receive kwargs
mock_pulsar_consumer = MagicMock()
received_kwargs = {}
def capture_receive(**kwargs):
received_kwargs.update(kwargs)
# Stop after one call
consumer.running = False
raise type('Timeout', (Exception,), {})("timeout")
@ -296,7 +295,7 @@ class TestPollTimeout:
await consumer.consume_from_queue(mock_pulsar_consumer)
assert received_kwargs.get("timeout_millis") == 100
assert received_kwargs.get("timeout_millis") == 2000
# ---------------------------------------------------------------------------

View file

@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
max_concurrent = 0
processing_event = asyncio.Event()
async def slow_process(message):
async def slow_process(message, sender):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05)
concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process
sender = AsyncMock()
# Launch more tasks than max_workers
messages = [
{"id": f"msg-{i}", "service": "test", "request": {}}
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
]
tasks = [
asyncio.create_task(dispatcher.handle_message(m))
asyncio.create_task(dispatcher.handle_message(m, sender))
for m in messages
]
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
original_process = dispatcher._process_message
async def tracking_process(message):
async def tracking_process(message, sender):
nonlocal task_was_tracked
# During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0:
task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
{"id": "test", "service": "test", "request": {}},
AsyncMock(),
)
assert task_was_tracked
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
"""Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message):
async def failing_process(message, sender):
raise RuntimeError("process failed")
dispatcher._process_message = failing_process
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
# Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError):
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
{"id": "test", "service": "test", "request": {}},
AsyncMock(),
)
# Semaphore should be back at max
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
order = []
async def ordered_process(message):
async def ordered_process(message, sender):
msg_id = message["id"]
order.append(f"start-{msg_id}")
await asyncio.sleep(0.02)
order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process
sender = AsyncMock()
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts

View file

@ -0,0 +1,389 @@
"""
Tests for TripleConverter domain/range enforcement and
OntologySelector bypass for small ontologies.
Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation).
"""
import pytest
from unittest.mock import Mock, AsyncMock
from trustgraph.extract.kg.ontology.triple_converter import TripleConverter
from trustgraph.extract.kg.ontology.ontology_selector import (
OntologySelector,
OntologySubset,
)
from trustgraph.extract.kg.ontology.ontology_loader import (
Ontology,
OntologyClass,
OntologyProperty,
)
from trustgraph.extract.kg.ontology.simplified_parser import (
Relationship,
Attribute,
)
from trustgraph.extract.kg.ontology.text_processor import TextSegment
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def ontology_subset():
"""Ontology subset with classes, hierarchy, and constrained properties."""
return OntologySubset(
ontology_id="test",
classes={
"Person": {
"uri": "http://example.org/Person",
"type": "owl:Class",
"labels": [{"value": "Person"}],
"subclass_of": None,
},
"Employee": {
"uri": "http://example.org/Employee",
"type": "owl:Class",
"labels": [{"value": "Employee"}],
"subclass_of": "Person",
},
"Manager": {
"uri": "http://example.org/Manager",
"type": "owl:Class",
"labels": [{"value": "Manager"}],
"subclass_of": "Employee",
},
"Company": {
"uri": "http://example.org/Company",
"type": "owl:Class",
"labels": [{"value": "Company"}],
"subclass_of": None,
},
"Product": {
"uri": "http://example.org/Product",
"type": "owl:Class",
"labels": [{"value": "Product"}],
"subclass_of": None,
},
},
object_properties={
"worksFor": {
"uri": "http://example.org/worksFor",
"type": "owl:ObjectProperty",
"labels": [{"value": "works for"}],
"domain": "Person",
"range": "Company",
},
"manages": {
"uri": "http://example.org/manages",
"type": "owl:ObjectProperty",
"labels": [{"value": "manages"}],
"domain": "Manager",
"range": "Employee",
},
"relatedTo": {
"uri": "http://example.org/relatedTo",
"type": "owl:ObjectProperty",
"labels": [{"value": "related to"}],
"domain": None,
"range": None,
},
},
datatype_properties={
"employeeId": {
"uri": "http://example.org/employeeId",
"type": "owl:DatatypeProperty",
"labels": [{"value": "employee ID"}],
"domain": "Employee",
},
"description": {
"uri": "http://example.org/description",
"type": "owl:DatatypeProperty",
"labels": [{"value": "description"}],
"domain": None,
},
},
metadata={"name": "Test Ontology"},
)
@pytest.fixture
def converter(ontology_subset):
return TripleConverter(ontology_subset=ontology_subset, ontology_id="test")
# ---------------------------------------------------------------------------
# Domain/range enforcement — relationships
# ---------------------------------------------------------------------------
class TestRelationshipDomainRange:
def test_valid_domain_and_range(self, converter):
rel = Relationship(
subject="Alice", subject_type="Person",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
triple = converter.convert_relationship(rel)
assert triple is not None
def test_domain_violation_rejected(self, converter):
rel = Relationship(
subject="Widget", subject_type="Product",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is None
def test_range_violation_rejected(self, converter):
rel = Relationship(
subject="Alice", subject_type="Person",
relation="worksFor",
object="Widget", object_type="Product",
)
assert converter.convert_relationship(rel) is None
def test_both_domain_and_range_violated(self, converter):
rel = Relationship(
subject="Widget", subject_type="Product",
relation="worksFor",
object="Gadget", object_type="Product",
)
assert converter.convert_relationship(rel) is None
# ---------------------------------------------------------------------------
# Subclass acceptance
# ---------------------------------------------------------------------------
class TestSubclassAcceptance:
def test_direct_subclass_matches_domain(self, converter):
"""Employee is subclass of Person; worksFor domain is Person."""
rel = Relationship(
subject="Bob", subject_type="Employee",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is not None
def test_transitive_subclass_matches_domain(self, converter):
"""Manager → Employee → Person; worksFor domain is Person."""
rel = Relationship(
subject="Carol", subject_type="Manager",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is not None
def test_subclass_matches_range(self, converter):
"""manages range is Employee; Manager is subclass of Employee."""
rel = Relationship(
subject="Carol", subject_type="Manager",
relation="manages",
object="Dave", object_type="Manager",
)
assert converter.convert_relationship(rel) is not None
def test_superclass_does_not_match_subclass_constraint(self, converter):
"""manages domain is Manager; Person is NOT a subclass of Manager."""
rel = Relationship(
subject="Alice", subject_type="Person",
relation="manages",
object="Bob", object_type="Employee",
)
assert converter.convert_relationship(rel) is None
# ---------------------------------------------------------------------------
# Polymorphic properties (no domain/range)
# ---------------------------------------------------------------------------
class TestPolymorphicProperties:
def test_no_domain_no_range_allows_anything(self, converter):
rel = Relationship(
subject="Alice", subject_type="Person",
relation="relatedTo",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is not None
def test_polymorphic_with_unrelated_types(self, converter):
rel = Relationship(
subject="Widget", subject_type="Product",
relation="relatedTo",
object="Bob", object_type="Employee",
)
assert converter.convert_relationship(rel) is not None
# ---------------------------------------------------------------------------
# Datatype property domain enforcement
# ---------------------------------------------------------------------------
class TestAttributeDomainValidation:
def test_valid_domain(self, converter):
attr = Attribute(
entity="Bob", entity_type="Employee",
attribute="employeeId", value="E-1234",
)
assert converter.convert_attribute(attr) is not None
def test_subclass_matches_domain(self, converter):
"""Manager is subclass of Employee; employeeId domain is Employee."""
attr = Attribute(
entity="Carol", entity_type="Manager",
attribute="employeeId", value="M-5678",
)
assert converter.convert_attribute(attr) is not None
def test_domain_violation_rejected(self, converter):
attr = Attribute(
entity="Acme Corp", entity_type="Company",
attribute="employeeId", value="E-0000",
)
assert converter.convert_attribute(attr) is None
def test_no_domain_allows_anything(self, converter):
attr = Attribute(
entity="Widget", entity_type="Product",
attribute="description", value="A useful widget",
)
assert converter.convert_attribute(attr) is not None
# ---------------------------------------------------------------------------
# OntologySelector bypass for small ontologies (#908)
# ---------------------------------------------------------------------------
def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0):
classes = {
f"C{i}": OntologyClass(uri=f"http://example.org/C{i}")
for i in range(n_classes)
}
obj_props = {
f"op{i}": OntologyProperty(
uri=f"http://example.org/op{i}", type="owl:ObjectProperty"
)
for i in range(n_obj_props)
}
dt_props = {
f"dp{i}": OntologyProperty(
uri=f"http://example.org/dp{i}", type="owl:DatatypeProperty"
)
for i in range(n_dt_props)
}
return Ontology(
id="tiny",
metadata={"name": "Tiny"},
classes=classes,
object_properties=obj_props,
datatype_properties=dt_props,
)
def _make_loader(ontology):
loader = Mock()
loader.get_ontology.return_value = ontology
loader.get_all_ontologies.return_value = {"tiny": ontology}
return loader
class TestBypassSelectorBelow:
async def test_bypass_returns_full_ontology(self):
"""With 3 elements and bypass_selector_below=5, selector is bypassed."""
ont = _make_ontology(2, 1, 0)
loader = _make_loader(ont)
embedder = Mock()
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=5,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
assert len(subsets) == 1
assert subsets[0].ontology_id == "tiny"
assert len(subsets[0].classes) == 2
assert len(subsets[0].object_properties) == 1
assert subsets[0].relevance_score == 1.0
# Embedder should never be called
embedder.embed_text.assert_not_called()
async def test_no_bypass_when_above_threshold(self):
"""With 10 elements and bypass_selector_below=5, selector runs normally."""
ont = _make_ontology(6, 3, 1)
loader = _make_loader(ont)
embedder = Mock()
embedder.embed_text = AsyncMock(return_value=[0.1, 0.2])
vector_store = Mock()
vector_store.size.return_value = 10
vector_store.search.return_value = []
embedder.get_vector_store.return_value = vector_store
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=5,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
# Vector store was consulted (selector ran normally)
vector_store.size.assert_called_once()
async def test_bypass_at_exact_threshold_not_triggered(self):
"""With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=)."""
ont = _make_ontology(3, 1, 1) # total = 5
loader = _make_loader(ont)
embedder = Mock()
embedder.embed_text = AsyncMock(return_value=[0.1, 0.2])
vector_store = Mock()
vector_store.size.return_value = 5
vector_store.search.return_value = []
embedder.get_vector_store.return_value = vector_store
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=5,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
# Should NOT bypass — 5 is not < 5
vector_store.size.assert_called_once()
async def test_bypass_zero_disables(self):
"""bypass_selector_below=0 means bypass never triggers."""
ont = _make_ontology(0, 0, 0) # empty ontology
loader = _make_loader(ont)
embedder = Mock()
embedder.embed_text = AsyncMock(return_value=[0.1])
vector_store = Mock()
vector_store.size.return_value = 0
vector_store.search.return_value = []
embedder.get_vector_store.return_value = vector_store
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=0,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
# 0 is not < 0, so bypass doesn't trigger
vector_store.size.assert_called_once()

View file

@ -165,22 +165,37 @@ class TestIamAuthDispatch:
by shape of the bearer."""
@pytest.mark.asyncio
async def test_no_authorization_header_raises_401(self):
async def test_no_authorization_header_tries_anonymous(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(None))
async def fake_with_client(op):
raise RuntimeError("auth-failed: anonymous access not permitted")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(None))
@pytest.mark.asyncio
async def test_non_bearer_header_raises_401(self):
async def test_non_bearer_header_tries_anonymous(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Basic whatever"))
async def fake_with_client(op):
raise RuntimeError("auth-failed: anonymous access not permitted")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Basic whatever"))
@pytest.mark.asyncio
async def test_empty_bearer_raises_401(self):
async def test_empty_bearer_tries_anonymous(self):
auth = IamAuth(backend=Mock())
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer "))
async def fake_with_client(op):
raise RuntimeError("auth-failed: anonymous access not permitted")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer "))
@pytest.mark.asyncio
async def test_unknown_format_raises_401(self):
@ -445,3 +460,121 @@ class TestAuthorise:
# Different resource → different cache key → two IAM calls.
assert calls["n"] == 2
# -- Anonymous authentication boundary ------------------------------------
class TestAnonymousAuthBoundary:
"""The gateway must only attempt anonymous auth when no credential
is presented. A malformed token must NOT fall through to the
anonymous path that would let an attacker bypass a broken token
by simply sending garbage."""
@pytest.mark.asyncio
async def test_no_header_attempts_anonymous(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
return await op(Mock(
authenticate_anonymous=AsyncMock(
return_value=("anon", "default", ["reader"]),
)
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = await auth.authenticate(make_request(None))
assert ident.handle == "anon"
assert ident.source == "anonymous"
@pytest.mark.asyncio
async def test_empty_bearer_attempts_anonymous(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
return await op(Mock(
authenticate_anonymous=AsyncMock(
return_value=("anon", "default", ["reader"]),
)
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
ident = await auth.authenticate(make_request("Bearer "))
assert ident.handle == "anon"
assert ident.source == "anonymous"
@pytest.mark.asyncio
async def test_malformed_token_does_not_fall_through_to_anonymous(self):
auth = IamAuth(backend=Mock())
called = {"anonymous": False}
original = auth._authenticate_anonymous
async def spy_anonymous():
called["anonymous"] = True
return await original()
auth._authenticate_anonymous = spy_anonymous
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer garbage"))
assert not called["anonymous"]
@pytest.mark.asyncio
async def test_bad_api_key_does_not_fall_through_to_anonymous(self):
auth = IamAuth(backend=Mock())
called = {"anonymous": False}
async def spy_anonymous():
called["anonymous"] = True
auth._authenticate_anonymous = spy_anonymous
async def fake_with_client(op):
raise RuntimeError("auth-failed: unknown key")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer tg_bad"))
assert not called["anonymous"]
@pytest.mark.asyncio
async def test_bad_jwt_does_not_fall_through_to_anonymous(self):
auth = IamAuth(backend=Mock())
auth._signing_public_pem = "not-a-real-pem"
called = {"anonymous": False}
async def spy_anonymous():
called["anonymous"] = True
auth._authenticate_anonymous = spy_anonymous
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request("Bearer a.b.c"))
assert not called["anonymous"]
@pytest.mark.asyncio
async def test_anonymous_rejected_by_iam_raises_401(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
raise RuntimeError("auth-failed: anonymous access not permitted")
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(None))
@pytest.mark.asyncio
async def test_anonymous_with_empty_user_id_raises_401(self):
auth = IamAuth(backend=Mock())
async def fake_with_client(op):
return await op(Mock(
authenticate_anonymous=AsyncMock(
return_value=("", "default", []),
)
))
with patch.object(auth, "_with_client", side_effect=fake_with_client):
with pytest.raises(web.HTTPUnauthorized):
await auth.authenticate(make_request(None))

View file

View file

@ -0,0 +1,44 @@
"""
Contract test: the full iam-svc MUST reject authenticate-anonymous.
This is a safety pin if someone accidentally adds anonymous access
to the production IAM handler, this test catches it.
"""
import asyncio
from unittest.mock import Mock, AsyncMock
import pytest
from trustgraph.iam.service.iam import IamService
def _make_request(**kwargs):
req = Mock()
for k, v in kwargs.items():
setattr(req, k, v)
return req
class TestIamRejectsAnonymous:
@pytest.fixture
def handler(self):
svc = object.__new__(IamService)
svc.table_store = Mock(spec=[])
svc.bootstrap_mode = "token"
svc.bootstrap_token = "tok"
svc._on_workspace_created = None
svc._on_workspace_deleted = None
svc._signing_key = None
svc._signing_key_lock = asyncio.Lock()
return svc
@pytest.mark.asyncio
async def test_authenticate_anonymous_returns_auth_failed(self, handler):
resp = await handler.handle(
_make_request(operation="authenticate-anonymous")
)
assert resp.error is not None
assert resp.error.type == "auth-failed"
assert "anonymous" in resp.error.message.lower()

View file

@ -0,0 +1,138 @@
"""
Tests for the no-auth IAM handler.
Verifies that NoAuthHandler returns the expected permissive responses
and that the always-allow authorise path returns the correct shape.
"""
import json
from unittest.mock import Mock
import pytest
from trustgraph.iam.noauth.handler import NoAuthHandler
def _make_request(**kwargs):
req = Mock()
for k, v in kwargs.items():
setattr(req, k, v)
return req
class TestAuthenticateAnonymous:
@pytest.mark.asyncio
async def test_returns_default_identity(self):
h = NoAuthHandler(
default_user_id="anon", default_workspace="ws",
)
resp = await h.handle(
_make_request(operation="authenticate-anonymous")
)
assert resp.error is None
assert resp.resolved_user_id == "anon"
assert resp.resolved_workspace == "ws"
assert "admin" in list(resp.resolved_roles)
@pytest.mark.asyncio
async def test_custom_defaults_propagate(self):
h = NoAuthHandler(
default_user_id="dev-user", default_workspace="dev-ws",
)
resp = await h.handle(
_make_request(operation="authenticate-anonymous")
)
assert resp.resolved_user_id == "dev-user"
assert resp.resolved_workspace == "dev-ws"
class TestResolveApiKey:
@pytest.mark.asyncio
async def test_any_key_resolves_to_default_identity(self):
h = NoAuthHandler()
resp = await h.handle(
_make_request(operation="resolve-api-key", api_key="tg_bogus")
)
assert resp.error is None
assert resp.resolved_user_id == "anonymous"
assert resp.resolved_workspace == "default"
class TestAuthorise:
@pytest.mark.asyncio
async def test_always_allows(self):
h = NoAuthHandler()
resp = await h.handle(
_make_request(
operation="authorise",
user_id="anyone",
capability="anything",
resource_json="{}",
parameters_json="{}",
)
)
assert resp.error is None
assert resp.decision_allow is True
assert resp.decision_ttl_seconds > 0
@pytest.mark.asyncio
async def test_authorise_many_returns_matching_count(self):
h = NoAuthHandler()
checks = [
{"capability": "a", "resource": {}, "parameters": {}},
{"capability": "b", "resource": {}, "parameters": {}},
{"capability": "c", "resource": {}, "parameters": {}},
]
resp = await h.handle(
_make_request(
operation="authorise-many",
user_id="u",
authorise_checks=json.dumps(checks),
)
)
assert resp.error is None
decisions = json.loads(resp.decisions_json)
assert len(decisions) == 3
assert all(d["allow"] is True for d in decisions)
class TestCreateWorkspaceCallback:
@pytest.mark.asyncio
async def test_create_workspace_calls_callback(self):
called_with = []
async def on_created(ws_id):
called_with.append(ws_id)
h = NoAuthHandler(on_workspace_created=on_created)
req = _make_request(operation="create-workspace")
req.workspace_record = Mock()
req.workspace_record.id = "test-ws"
resp = await h.handle(req)
assert resp.error is None
assert called_with == ["test-ws"]
@pytest.mark.asyncio
async def test_create_workspace_without_callback_still_succeeds(self):
h = NoAuthHandler()
req = _make_request(operation="create-workspace")
req.workspace_record = Mock()
req.workspace_record.id = "test-ws"
resp = await h.handle(req)
assert resp.error is None
class TestUnknownOperation:
@pytest.mark.asyncio
async def test_unknown_op_returns_error(self):
h = NoAuthHandler()
resp = await h.handle(
_make_request(operation="not-a-real-op")
)
assert resp.error is not None
assert resp.error.type == "invalid-argument"

View file

@ -7,7 +7,7 @@ including template rendering, term merging, JSON validation, and error handling.
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock
from trustgraph.template.prompt_manager import PromptManager, PromptConfiguration, Prompt
@ -344,6 +344,42 @@ class TestPromptManager:
assert pm.terms == {} # Default empty terms
assert len(pm.prompts) == 0
def test_load_config_does_not_swallow_keyboard_interrupt(self, monkeypatch):
"""KeyboardInterrupt should propagate out of config parsing."""
pm = PromptManager()
def interrupt(_value):
raise KeyboardInterrupt
monkeypatch.setattr("trustgraph.template.prompt_manager.json.loads", interrupt)
with pytest.raises(KeyboardInterrupt):
pm.load_config({"system": json.dumps("Test")})
@pytest.mark.asyncio
async def test_json_parse_does_not_swallow_system_exit(self):
"""SystemExit should propagate out of JSON response parsing."""
pm = PromptManager()
config = {
"system": json.dumps("Test"),
"template-index": json.dumps(["json_response"]),
"template.json_response": json.dumps({
"prompt": "Generate JSON",
"response-type": "json"
})
}
pm.load_config(config)
def exit_parse(_text):
raise SystemExit(2)
pm.parse_json = exit_parse
mock_llm = AsyncMock()
mock_llm.return_value = "{}"
with pytest.raises(SystemExit):
await pm.invoke("json_response", {}, mock_llm)
@pytest.mark.unit
class TestPromptManagerJsonl:
@ -585,4 +621,4 @@ not json at all
assert len(result) == 2
assert result[0] == {"any": "structure"}
assert result[1] == {"completely": "different"}
assert result[1] == {"completely": "different"}

View file

@ -8,6 +8,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock, call
import json
from trustgraph.api.socket_client import SocketClient
from trustgraph.api import (
Api,
Triple,
@ -222,6 +223,82 @@ class TestSocketClient:
for method in expected_methods:
assert hasattr(flow_instance, method), f"Missing method: {method}"
def test_socket_client_close_does_not_swallow_base_exceptions(self):
"""Test close cleanup does not suppress process-level interrupts."""
class InterruptingLoop:
def is_closed(self):
return False
def run_until_complete(self, awaitable):
if hasattr(awaitable, "close"):
awaitable.close()
raise SystemExit("stop")
socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._loop = InterruptingLoop()
with pytest.raises(SystemExit):
socket.close()
@pytest.mark.parametrize(
("generator_method", "async_method"),
[
("_streaming_generator", "_send_request_async_streaming"),
("_streaming_generator_raw", "_send_request_async_streaming_raw"),
],
)
def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions(
self, generator_method, async_method
):
"""Test streaming cleanup does not suppress process-level interrupts."""
class FakeAsyncGenerator:
def __anext__(self):
return "next"
def aclose(self):
return "close"
class InterruptingLoop:
def run_until_complete(self, awaitable):
if awaitable == "next":
raise StopAsyncIteration
if awaitable == "close":
raise SystemExit("stop")
raise AssertionError(f"unexpected awaitable: {awaitable!r}")
socket = SocketClient(url="http://test/", timeout=60, token=None)
setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator())
generator = getattr(socket, generator_method)(
"agent", "default", {}, InterruptingLoop()
)
with pytest.raises(SystemExit):
next(generator)
@pytest.mark.asyncio
async def test_socket_client_reader_does_not_swallow_base_exceptions(self):
"""Test reader error fanout does not suppress process-level interrupts."""
class FailingSocket:
def __aiter__(self):
return self
async def __anext__(self):
raise ValueError("reader failed")
class InterruptingQueue:
async def put(self, message):
raise SystemExit("stop")
socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._socket = FailingSocket()
socket._pending = {"req-1": InterruptingQueue()}
with pytest.raises(SystemExit):
await socket._reader()
class TestBulkClient:
"""Test bulk operations client"""

View file

@ -0,0 +1,56 @@
"""
Tests for ontology monitoring metrics.
"""
from trustgraph.query.ontology.monitoring import (
PerformanceMonitor,
_extract_metric_label,
)
def test_extract_metric_label_reads_unquoted_label_value():
metric_name = "cache_requests_total{cache_type=entity,component=ontology}"
assert _extract_metric_label(metric_name, "cache_type") == "entity"
def test_extract_metric_label_reads_quoted_label_value():
metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}'
assert _extract_metric_label(metric_name, "cache_type") == "entity"
def test_extract_metric_label_returns_none_when_label_missing():
metric_name = "cache_requests_total{component=ontology}"
assert _extract_metric_label(metric_name, "cache_type") is None
def test_performance_report_ignores_counters_without_cache_type_label():
monitor = PerformanceMonitor({"enabled": False})
monitor.metrics_collector.increment(
"cache_requests_total",
labels={"component": "ontology"},
)
monitor.metrics_collector.increment(
"cache_type=not_a_label",
labels={"component": "ontology"},
)
monitor.metrics_collector.increment(
"cache_requests_total",
labels={"cache_type": "entity"},
)
monitor.metrics_collector.increment(
"cache_hits_total",
labels={"cache_type": "entity"},
)
report = monitor.get_performance_report()
assert report["cache_performance"] == {
"entity": {
"hit_rate": 1.0,
"total_requests": 1.0,
"total_hits": 1.0,
}
}

View file

@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
import asyncio
processor = MagicMock()
processor.schemas = {}
processor.schema_builders = {}
processor.graphql_schemas = {}
processor.config_key = "schema"
processor.query_cassandra = MagicMock()
processor._setup_lock = asyncio.Lock()
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.connect_cassandra = AsyncMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.connect_cassandra = AsyncMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)

View file

@ -0,0 +1,580 @@
"""
Tests for the SPARQL algebra evaluator.
Verifies that evaluate() and _query_pattern() call TriplesClient.query()
with the correct arguments, and in particular that workspace is never
passed workspace isolation is handled by pub/sub topic routing.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, call
from rdflib.term import Variable, URIRef, Literal
from rdflib.plugins.sparql.parserutils import CompValue
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.algebra import (
evaluate, materialise, _query_pattern, _eval_bgp,
)
# --- Helpers ---
def iri(v):
return Term(type=IRI, iri=v)
def lit(v):
return Term(type=LITERAL, value=v)
def make_tc(query_return=None, query_side_effect=None):
"""Create a mock TriplesClient with both query() and query_gen() support."""
tc = AsyncMock()
if query_side_effect is not None:
tc.query.side_effect = query_side_effect
async def gen_side_effect(**kwargs):
results = await query_side_effect(**kwargs)
for r in results:
yield r
tc.query_gen = gen_side_effect
else:
items = query_return or []
tc.query.return_value = items
async def gen(**kwargs):
for item in items:
yield item
tc.query_gen = gen
return tc
def make_triple(s, p, o):
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
def make_bgp(*patterns):
"""Build a CompValue BGP node from (s, p, o) tuples of rdflib terms."""
node = CompValue("BGP")
node.triples = list(patterns)
return node
def make_project(inner, variables):
node = CompValue("Project")
node.p = inner
node.PV = [Variable(v) for v in variables]
return node
def make_select(inner):
node = CompValue("SelectQuery")
node.p = inner
return node
def make_join(left, right):
node = CompValue("Join")
node.p1 = left
node.p2 = right
return node
def make_union(left, right):
node = CompValue("Union")
node.p1 = left
node.p2 = right
return node
def make_slice(inner, start, length):
node = CompValue("Slice")
node.p = inner
node.start = start
node.length = length
return node
def make_distinct(inner):
node = CompValue("Distinct")
node.p = inner
return node
def make_filter(inner, expr):
node = CompValue("Filter")
node.p = inner
node.expr = expr
return node
def make_minus(left, right):
node = CompValue("Minus")
node.p1 = left
node.p2 = right
return node
class TestQueryPattern:
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
@pytest.mark.asyncio
async def test_passes_correct_args(self):
tc = AsyncMock()
tc.query.return_value = []
await _query_pattern(
tc,
s=iri("http://example.com/s"),
p=iri("http://example.com/p"),
o=None,
collection="my-collection",
limit=100,
)
tc.query.assert_called_once_with(
s=iri("http://example.com/s"),
p=iri("http://example.com/p"),
o=None,
limit=100,
collection="my-collection",
)
@pytest.mark.asyncio
async def test_workspace_not_passed(self):
tc = AsyncMock()
tc.query.return_value = []
await _query_pattern(tc, None, None, None, "default", 10)
kwargs = tc.query.call_args.kwargs
assert "workspace" not in kwargs
@pytest.mark.asyncio
async def test_returns_query_results(self):
tc = AsyncMock()
triple = make_triple(iri("http://a"), iri("http://b"), lit("c"))
tc.query.return_value = [triple]
results = await _query_pattern(tc, None, None, None, "default", 10)
assert len(results) == 1
assert results[0].s.iri == "http://a"
class TestEvalBgp:
"""Tests for BGP evaluation — triple pattern queries."""
@pytest.mark.asyncio
async def test_single_pattern_all_variables(self):
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc = make_tc(query_return=[triple])
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
solutions = await materialise(bgp, tc, collection="default", limit=100)
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://s"
assert solutions[0]["p"].iri == "http://p"
assert solutions[0]["o"].value == "o"
@pytest.mark.asyncio
async def test_single_pattern_bound_subject(self):
tc = make_tc(query_return=[
make_triple(iri("http://s"), iri("http://p"), lit("val")),
])
bgp = make_bgp(
(URIRef("http://s"), Variable("p"), Variable("o")),
)
solutions = await materialise(bgp, tc, collection="default")
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_empty_bgp_returns_empty_solution(self):
tc = make_tc()
bgp = make_bgp()
solutions = await materialise(bgp, tc, collection="default")
assert solutions == [{}]
@pytest.mark.asyncio
async def test_no_results_returns_empty(self):
tc = make_tc(query_return=[])
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
solutions = await materialise(bgp, tc, collection="default")
assert solutions == []
class TestEvaluate:
"""Tests for the top-level evaluate() dispatcher."""
@pytest.mark.asyncio
async def test_select_query_node(self):
tc = make_tc(query_return=[
make_triple(iri("http://s"), iri("http://p"), lit("o")),
])
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
select = make_select(make_project(bgp, ["s", "p"]))
solutions = await materialise(select, tc, collection="default")
assert len(solutions) == 1
assert "s" in solutions[0]
assert "p" in solutions[0]
assert "o" not in solutions[0]
@pytest.mark.asyncio
async def test_workspace_never_in_query_calls(self):
"""Verify that no matter the algebra structure, workspace is never
passed to TriplesClient.query()."""
tc = make_tc(query_return=[
make_triple(iri("http://s"), iri("http://p"), lit("o")),
])
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
tree = make_select(make_project(
make_union(bgp1, bgp2), ["s", "p", "o"]
))
await materialise(tree, tc, collection="test-coll")
@pytest.mark.asyncio
async def test_join(self):
call_count = 0
async def mock_query(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return [make_triple(iri("http://a"), iri("http://p"), lit("v"))]
else:
return [make_triple(iri("http://a"), iri("http://q"), lit("w"))]
tc = make_tc(query_side_effect=mock_query)
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
tree = make_join(bgp1, bgp2)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://a"
@pytest.mark.asyncio
async def test_slice(self):
triples = [
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
for i in range(5)
]
tc = make_tc(query_return=triples)
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_slice(bgp, start=1, length=2)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 2
@pytest.mark.asyncio
async def test_distinct(self):
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc = make_tc(query_return=[triple, triple])
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_distinct(bgp)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_minus_removes_matching(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
knows = iri("http://example.com/knows")
hates = iri("http://example.com/hates")
charlie = iri("http://example.com/charlie")
left_triple = make_triple(alice, knows, bob)
right_triple2 = make_triple(alice, hates, charlie)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/knows":
return [left_triple]
elif pred and pred.iri == "http://example.com/hates":
return [right_triple2]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
)
right_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/hates"), Variable("r"))
)
tree = make_select(
make_project(
make_minus(left_bgp, right_bgp),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 0
@pytest.mark.asyncio
async def test_minus_no_shared_vars_preserves_all(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
left_triple = make_triple(alice, iri("http://example.com/p"), bob)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/p":
return [left_triple]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/p"), Variable("o"))
)
right_bgp = make_bgp(
(Variable("x"), URIRef("http://example.com/q"), Variable("y"))
)
tree = make_select(
make_project(
make_minus(left_bgp, right_bgp),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_filter_exists_keeps_matching(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
charlie = iri("http://example.com/charlie")
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/knows":
return [left_triple1, left_triple2]
elif pred and pred.iri == "http://example.com/likes":
return [exists_triple]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
)
exists_bgp = make_bgp(
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
)
exists_expr = CompValue("Builtin_EXISTS")
exists_expr.graph = exists_bgp
tree = make_select(
make_project(
make_filter(left_bgp, exists_expr),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
result_objects = [s["o"].iri for s in solutions]
assert "http://example.com/bob" in result_objects
assert "http://example.com/charlie" not in result_objects
@pytest.mark.asyncio
async def test_filter_not_exists_removes_matching(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
charlie = iri("http://example.com/charlie")
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/knows":
return [left_triple1, left_triple2]
elif pred and pred.iri == "http://example.com/likes":
return [exists_triple]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
)
exists_bgp = make_bgp(
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
)
not_exists_expr = CompValue("Builtin_NOTEXISTS")
not_exists_expr.graph = exists_bgp
tree = make_select(
make_project(
make_filter(left_bgp, not_exists_expr),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
result_objects = [s["o"].iri for s in solutions]
assert "http://example.com/charlie" in result_objects
assert "http://example.com/bob" not in result_objects
@pytest.mark.asyncio
async def test_join_values_uses_bind_join(self):
"""When VALUES is joined with a BGP, the bind join should pass
the VALUES bindings into the BGP evaluation so the triple store
query is selective (not a wildcard)."""
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
knows = iri("http://example.com/knows")
queries_issued = []
async def mock_query(**kwargs):
queries_issued.append(kwargs)
s, p = kwargs.get("s"), kwargs.get("p")
if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows":
return [make_triple(alice, knows, bob)]
return []
tc = make_tc(query_side_effect=mock_query)
# VALUES ?s { <alice> }
values_node = CompValue("values")
values_node.var = [Variable("s")]
values_node.value = [[URIRef("http://example.com/alice")]]
values_node.res = None
to_multiset = CompValue("ToMultiSet")
to_multiset.p = values_node
bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
)
tree = make_join(to_multiset, bgp)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://example.com/alice"
assert solutions[0]["o"].iri == "http://example.com/bob"
# The key assertion: the BGP query should have received
# s=alice (bound from VALUES), NOT s=None (wildcard)
assert len(queries_issued) == 1
assert queries_issued[0]["s"] is not None
assert queries_issued[0]["s"].iri == "http://example.com/alice"
@pytest.mark.asyncio
async def test_join_values_multiple_bindings(self):
"""Bind join with multiple VALUES bindings."""
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
knows = iri("http://example.com/knows")
charlie = iri("http://example.com/charlie")
async def mock_query(**kwargs):
s = kwargs.get("s")
if s and s.iri == "http://example.com/alice":
return [make_triple(alice, knows, bob)]
elif s and s.iri == "http://example.com/bob":
return [make_triple(bob, knows, charlie)]
return []
tc = make_tc(query_side_effect=mock_query)
values_node = CompValue("values")
values_node.var = [Variable("s")]
values_node.value = [
[URIRef("http://example.com/alice")],
[URIRef("http://example.com/bob")],
]
values_node.res = None
to_multiset = CompValue("ToMultiSet")
to_multiset.p = values_node
bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
)
tree = make_join(to_multiset, bgp)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 2
subjects = {s["s"].iri for s in solutions}
assert subjects == {
"http://example.com/alice",
"http://example.com/bob",
}
@pytest.mark.asyncio
async def test_unsupported_node_returns_empty_solution(self):
tc = make_tc()
node = CompValue("SomethingUnknown")
solutions = await materialise(node, tc, collection="default")
assert solutions == [{}]
@pytest.mark.asyncio
async def test_non_compvalue_returns_empty_solution(self):
tc = make_tc()
solutions = await materialise("not a node", tc, collection="default")
assert solutions == [{}]

View file

@ -300,6 +300,438 @@ class TestBuiltinFunctions:
flags=None)
assert evaluate_expression(expr, {"x": lit("hello")}) is False
def test_substr_three_args(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Literal(1),
length=Literal(4))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.type == LITERAL
assert result.value == "2024"
def test_substr_two_args(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Literal(6),
length=None)
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.type == LITERAL
assert result.value == "03-15"
def test_substr_middle(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Literal(6),
length=Literal(2))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.type == LITERAL
assert result.value == "03"
def test_substr_null_start(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Variable("missing"),
length=None)
result = evaluate_expression(expr, {"x": lit("hello")})
assert result is None
def test_year(self):
from rdflib.term import Variable
expr = self._make_builtin("YEAR", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 2024
def test_month(self):
from rdflib.term import Variable
expr = self._make_builtin("MONTH", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 3
def test_day(self):
from rdflib.term import Variable
expr = self._make_builtin("DAY", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 15
def test_hours(self):
from rdflib.term import Variable
expr = self._make_builtin("HOURS", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 10
def test_minutes(self):
from rdflib.term import Variable
expr = self._make_builtin("MINUTES", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 30
def test_seconds(self):
from rdflib.term import Variable
expr = self._make_builtin("SECONDS", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 45
def test_year_from_datetime(self):
from rdflib.term import Variable
expr = self._make_builtin("YEAR", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 2024
def test_hours_from_date_returns_zero(self):
from rdflib.term import Variable
expr = self._make_builtin("HOURS", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 0
def test_year_invalid_date(self):
from rdflib.term import Variable
expr = self._make_builtin("YEAR", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("not-a-date")}
)
assert result is None
def test_floor(self):
from rdflib.term import Variable
expr = self._make_builtin("FLOOR", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.7")}) == 3
def test_floor_negative(self):
from rdflib.term import Variable
expr = self._make_builtin("FLOOR", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3
def test_floor_none(self):
from rdflib.term import Variable
expr = self._make_builtin("FLOOR", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("abc")}) is None
def test_ceil(self):
from rdflib.term import Variable
expr = self._make_builtin("CEIL", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.2")}) == 4
def test_ceil_negative(self):
from rdflib.term import Variable
expr = self._make_builtin("CEIL", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2
def test_abs_positive(self):
from rdflib.term import Variable
expr = self._make_builtin("ABS", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("42")}) == 42
def test_abs_negative(self):
from rdflib.term import Variable
expr = self._make_builtin("ABS", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("-42")}) == 42
def test_abs_none(self):
from rdflib.term import Variable
expr = self._make_builtin("ABS", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("abc")}) is None
def test_replace_simple(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REPLACE",
arg=Variable("x"),
pattern=Literal(" BC"),
replacement=Literal(""),
flags=None)
result = evaluate_expression(expr, {"x": lit("500 BC")})
assert result.type == LITERAL
assert result.value == "500"
def test_replace_regex(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REPLACE",
arg=Variable("x"),
pattern=Literal("[0-9]+"),
replacement=Literal("X"),
flags=None)
result = evaluate_expression(expr, {"x": lit("abc123def456")})
assert result.value == "abcXdefX"
def test_replace_case_insensitive(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REPLACE",
arg=Variable("x"),
pattern=Literal("hello"),
replacement=Literal("world"),
flags=Literal("i"))
result = evaluate_expression(expr, {"x": lit("HELLO there")})
assert result.value == "world there"
def test_round_up(self):
from rdflib.term import Variable
expr = self._make_builtin("ROUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.7")}) == 4
def test_round_down(self):
from rdflib.term import Variable
expr = self._make_builtin("ROUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.2")}) == 3
def test_round_none(self):
from rdflib.term import Variable
expr = self._make_builtin("ROUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("abc")}) is None
def test_strbefore(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRBEFORE",
arg1=Variable("x"), arg2=Literal("-"))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.value == "2024"
def test_strbefore_not_found(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRBEFORE",
arg1=Variable("x"), arg2=Literal("/"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.value == ""
def test_strafter(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRAFTER",
arg1=Variable("x"), arg2=Literal("-"))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.value == "03-15"
def test_strafter_not_found(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRAFTER",
arg1=Variable("x"), arg2=Literal("/"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.value == ""
def test_encode_for_uri(self):
from rdflib.term import Variable
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello world")})
assert result.value == "hello%20world"
def test_encode_for_uri_special_chars(self):
from rdflib.term import Variable
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")})
assert result.value == "a%2Fb%3Fc%3Dd%26e"
def test_langmatches_basic(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("en"), arg2=Literal("en"))
assert evaluate_expression(expr, {}) is True
def test_langmatches_subtag(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("en-US"), arg2=Literal("en"))
assert evaluate_expression(expr, {}) is True
def test_langmatches_wildcard(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("fr"), arg2=Literal("*"))
assert evaluate_expression(expr, {}) is True
def test_langmatches_wildcard_empty(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal(""), arg2=Literal("*"))
assert evaluate_expression(expr, {}) is False
def test_langmatches_no_match(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("fr"), arg2=Literal("en"))
assert evaluate_expression(expr, {}) is False
def test_iri_constructor(self):
from rdflib.term import Variable
expr = self._make_builtin("IRI", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("http://example.com/test")}
)
assert result.type == IRI
assert result.iri == "http://example.com/test"
def test_uri_constructor(self):
from rdflib.term import Variable
expr = self._make_builtin("URI", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("http://example.com/test")}
)
assert result.type == IRI
assert result.iri == "http://example.com/test"
def test_bnode_no_arg(self):
expr = self._make_builtin("BNODE")
result = evaluate_expression(expr, {})
assert result.type == BLANK
assert len(result.id) > 0
def test_bnode_with_label(self):
from rdflib import Literal
expr = self._make_builtin("BNODE", arg=Literal("mynode"))
result = evaluate_expression(expr, {})
assert result.type == BLANK
assert result.id == "mynode"
def test_now(self):
import re as re_mod
expr = self._make_builtin("NOW")
result = evaluate_expression(expr, {})
assert result.type == LITERAL
assert result.datatype == XSD + "dateTime"
assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value)
def test_tz_with_utc(self):
from rdflib.term import Variable
expr = self._make_builtin("TZ", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45+0000",
datatype=XSD + "dateTime")}
)
assert result.type == LITERAL
assert result.value == "+00:00"
def test_tz_no_timezone(self):
from rdflib.term import Variable
expr = self._make_builtin("TZ", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45",
datatype=XSD + "dateTime")}
)
assert result.value == ""
def test_rand(self):
expr = self._make_builtin("RAND")
result = evaluate_expression(expr, {})
assert isinstance(result, float)
assert 0.0 <= result < 1.0
def test_uuid(self):
import re as re_mod
expr = self._make_builtin("UUID")
result = evaluate_expression(expr, {})
assert result.type == IRI
assert result.iri.startswith("urn:uuid:")
uuid_part = result.iri[len("urn:uuid:"):]
assert re_mod.match(
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
uuid_part
)
def test_struuid(self):
import re as re_mod
expr = self._make_builtin("STRUUID")
result = evaluate_expression(expr, {})
assert result.type == LITERAL
assert re_mod.match(
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
result.value
)
def test_md5(self):
from rdflib.term import Variable
expr = self._make_builtin("MD5", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert result.value == "5d41402abc4b2a76b9719d911017c592"
def test_sha1(self):
from rdflib.term import Variable
expr = self._make_builtin("SHA1", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d"
def test_sha256(self):
from rdflib.term import Variable
expr = self._make_builtin("SHA256", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert result.value == (
"2cf24dba5fb0a30e26e83b2ac5b9e29e"
"1b161e5c1fa7425e73043362938b9824"
)
def test_sha512(self):
from rdflib.term import Variable
expr = self._make_builtin("SHA512", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert len(result.value) == 128
def test_exists_with_callback(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("EXISTS", graph=graph)
cb = lambda g, s: True
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is True
def test_exists_callback_false(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("EXISTS", graph=graph)
cb = lambda g, s: False
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is False
def test_notexists_with_callback(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("NOTEXISTS", graph=graph)
cb = lambda g, s: True
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is False
def test_notexists_callback_false(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("NOTEXISTS", graph=graph)
cb = lambda g, s: False
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is True
class TestEffectiveBoolean:

View file

@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations.
import pytest
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.solutions import (
hash_join, left_join, union, project, distinct,
hash_join, left_join, minus, union, project, distinct,
order_by, slice_solutions, _terms_equal, _compatible,
)
@ -311,6 +311,30 @@ class TestOrderBy:
result = order_by(solutions, [])
assert len(result) == 1
def test_order_by_numeric_literals(self):
solutions = [
{"year": lit("1950")},
{"year": lit("700")},
{"year": lit("2000")},
{"year": lit("450")},
{"year": lit("1200")},
]
key_fns = [(lambda sol: sol.get("year"), True)]
result = order_by(solutions, key_fns)
values = [s["year"].value for s in result]
assert values == ["450", "700", "1200", "1950", "2000"]
def test_order_by_numeric_descending(self):
solutions = [
{"year": lit("1950")},
{"year": lit("700")},
{"year": lit("2000")},
]
key_fns = [(lambda sol: sol.get("year"), False)]
result = order_by(solutions, key_fns)
values = [s["year"].value for s in result]
assert values == ["2000", "1950", "700"]
class TestSlice:
@ -343,3 +367,37 @@ class TestSlice:
solutions = [{"s": alice}, {"s": bob}]
result = slice_solutions(solutions)
assert len(result) == 2
class TestMinus:
def test_removes_compatible(self, alice, bob):
left = [{"s": alice}, {"s": bob}]
right = [{"s": alice}]
result = minus(left, right)
assert len(result) == 1
assert result[0]["s"].iri == "http://example.com/bob"
def test_empty_right_preserves_all(self, alice, bob):
left = [{"s": alice}, {"s": bob}]
result = minus(left, [])
assert len(result) == 2
def test_no_shared_variables_preserves_all(self, alice, bob):
left = [{"s": alice}]
right = [{"t": bob}]
result = minus(left, right)
assert len(result) == 1
def test_all_removed(self, alice):
left = [{"s": alice}]
right = [{"s": alice}]
result = minus(left, right)
assert len(result) == 0
def test_partial_shared_variables(self, alice, bob):
left = [{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}]
right = [{"s": alice}]
result = minus(left, right)
assert len(result) == 1
assert result[0]["s"].iri == "http://example.com/bob"

View file

@ -2,8 +2,10 @@
Tests for Cassandra triples query service
"""
import asyncio
import pytest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.query.triples.cassandra.service import Processor, create_term
from trustgraph.schema import Term, IRI, LITERAL
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
return Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
graph_host='localhost'
cassandra_host='localhost'
)
def test_create_term_with_http_uri(self, processor):
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(
taskgroup=MagicMock(),
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
keyspace='test_user'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
# Verify async_get_spo was called with correct parameters
mock_tg_instance.async_get_spo.assert_called_once_with(
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
)
@ -130,23 +132,25 @@ class TestCassandraQueryProcessor:
assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None
assert processor.cassandra_password is None
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
cassandra_host='cassandra.example.com',
cassandra_username='queryuser',
cassandra_password='querypass'
)
assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'queryuser'
assert processor.cassandra_password == 'querypass'
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_sp.return_value = [mock_result]
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'test_predicate'
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_s.return_value = [mock_result]
mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'result_predicate'
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_p.return_value = [mock_result]
mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'test_predicate'
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_o.return_value = [mock_result]
mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'result_predicate'
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
mock_result.s = 'all_subject'
mock_result.p = 'all_predicate'
mock_result.o = 'all_object'
mock_result.g = ''
mock_result.d = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_all.return_value = [mock_result]
mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.iri == 'all_subject'
assert result[0].p.iri == 'all_predicate'
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(
taskgroup=MagicMock(),
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
mock_result.lang = None
mock_result.p = 'p'
mock_result.o = 'o'
mock_tg_instance1.get_s.return_value = [mock_result]
mock_tg_instance2.get_s.return_value = [mock_result]
mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples('user1', query1)
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples('user2', query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice
# Verify TrustGraph was created twice for different workspaces
assert mock_kg_class.call_count == 2
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
mock_tg_instance = MagicMock()
mock_kg_class.return_value = mock_tg_instance
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
processor = Processor(taskgroup=MagicMock())
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
mock_result2.otype = None
mock_result2.dtype = None
mock_result2.lang = None
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
processor = Processor(taskgroup=MagicMock())
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_po.return_value = [mock_result]
mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('test_user', query)
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
# Verify async_get_po was called (should use optimized po_table)
mock_tg_instance.async_get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
)
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_os.return_value = [mock_result]
mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('test_user', query)
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
# Verify async_get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.async_get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', g=None, limit=25
)
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
mock_kg_class.return_value = mock_tg_instance
# Mock empty results for all queries
mock_tg_instance.get_all.return_value = []
mock_tg_instance.get_s.return_value = []
mock_tg_instance.get_p.return_value = []
mock_tg_instance.get_o.return_value = []
mock_tg_instance.get_sp.return_value = []
mock_tg_instance.get_po.return_value = []
mock_tg_instance.get_os.return_value = []
mock_tg_instance.get_spo.return_value = []
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
mock_tg_instance.async_get_s = AsyncMock(return_value=[])
mock_tg_instance.async_get_p = AsyncMock(return_value=[])
mock_tg_instance.async_get_o = AsyncMock(return_value=[])
mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
mock_tg_instance.async_get_po = AsyncMock(return_value=[])
mock_tg_instance.async_get_os = AsyncMock(return_value=[])
mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
processor = Processor(taskgroup=MagicMock())
# Test each query pattern
test_patterns = [
# (s, p, o, expected_method)
(None, None, None, 'get_all'), # All triples
('s1', None, None, 'get_s'), # Subject only
(None, 'p1', None, 'get_p'), # Predicate only
(None, None, 'o1', 'get_o'), # Object only
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'get_os'), # Object + Subject
('s1', 'p1', 'o1', 'get_spo'), # All three
(None, None, None, 'async_get_all'), # All triples
('s1', None, None, 'async_get_s'), # Subject only
(None, 'p1', None, 'async_get_p'), # Predicate only
(None, None, 'o1', 'async_get_o'), # Object only
('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'async_get_os'), # Object + Subject
('s1', 'p1', 'o1', 'async_get_spo'), # All three
]
for s, p, o, expected_method in test_patterns:
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.lang = None
mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results
mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
processor = Processor(taskgroup=MagicMock())
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('large_dataset_user', query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with(
# Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.async_get_po.assert_called_once_with(
'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person',

View file

@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_valid_embedding_upserted(self):
import asyncio
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "col1"
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_dimension_in_collection_name(self):
"""Collection name should include vector dimension."""
import asyncio
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "docs"
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_valid_entity_and_vector_upserted(self):
import asyncio
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "col1"
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_lazy_collection_creation_on_new_dimension(self):
import asyncio
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = False
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "graphs"

View file

@ -337,6 +337,57 @@ class TestQuery:
cache_key = "test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
async def test_triples_query_never_passes_workspace(self):
"""Workspace isolation is handled by pub/sub topic routing, not
by passing workspace to TriplesClient.query(). Verify that
GraphRAG never passes workspace as a keyword argument."""
mock_rag = MagicMock()
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triple = MagicMock()
mock_triple.o = "Label"
mock_triples_client.query.return_value = [mock_triple]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
await query.maybe_label("http://example.com/entity")
for c in mock_triples_client.query.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""

View file

View file

@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
class TestWebSocketResponder:
"""Test cases for WebSocketResponder class"""
def test_websocket_responder_initialization(self):
"""Test WebSocketResponder initialization"""
responder = WebSocketResponder()
assert responder.response is None
assert responder.completed is False
@pytest.mark.asyncio
async def test_websocket_responder_send_method(self):
"""Test WebSocketResponder send method"""
responder = WebSocketResponder()
test_response = {"data": "test response"}
# Call send method
await responder.send(test_response)
# Verify response was stored
assert responder.response == test_response
@pytest.mark.asyncio
async def test_websocket_responder_call_method(self):
"""Test WebSocketResponder __call__ method"""
responder = WebSocketResponder()
test_response = {"result": "success"}
test_completed = True
# Call the responder
await responder(test_response, test_completed)
# Verify response and completed status were set
assert responder.response == test_response
assert responder.completed == test_completed
@pytest.mark.asyncio
async def test_websocket_responder_call_method_with_false_completion(self):
"""Test WebSocketResponder __call__ method with incomplete response"""
responder = WebSocketResponder()
test_response = {"partial": "data"}
test_completed = False
# Call the responder
await responder(test_response, test_completed)
# Verify response was set and completed is True (since send() always sets completed=True)
assert responder.response == test_response
assert responder.completed is True
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestMessageDispatcher:
"""Test cases for MessageDispatcher class"""
def test_message_dispatcher_initialization_with_defaults(self):
"""Test MessageDispatcher initialization with default parameters"""
dispatcher = MessageDispatcher()
assert dispatcher.max_workers == 10
assert dispatcher.semaphore._value == 10
assert dispatcher.active_tasks == set()
assert dispatcher.backend is None
assert dispatcher.auth is None
assert dispatcher.dispatcher_manager is None
assert len(dispatcher.service_mapping) > 0
def test_message_dispatcher_initialization_with_custom_workers(self):
"""Test MessageDispatcher initialization with custom max_workers"""
dispatcher = MessageDispatcher(max_workers=5)
assert dispatcher.max_workers == 5
assert dispatcher.semaphore._value == 5
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
def test_message_dispatcher_initialization_with_backend(
self, mock_dispatcher_manager,
):
mock_backend = MagicMock()
mock_config_receiver = MagicMock()
mock_auth = MagicMock()
mock_dispatcher_instance = MagicMock()
mock_dispatcher_manager.return_value = mock_dispatcher_instance
dispatcher = MessageDispatcher(
max_workers=8,
config_receiver=mock_config_receiver,
backend=mock_backend
backend=mock_backend,
auth=mock_auth,
timeout=300,
)
assert dispatcher.max_workers == 8
assert dispatcher.backend == mock_backend
assert dispatcher.auth == mock_auth
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
mock_dispatcher_manager.assert_called_once_with(
mock_backend, mock_config_receiver, prefix="rev-gateway"
mock_backend, mock_config_receiver,
auth=mock_auth, prefix="rev-gateway", timeout=300,
)
def test_message_dispatcher_service_mapping(self):
"""Test MessageDispatcher service mapping contains expected services"""
dispatcher = MessageDispatcher()
expected_services = [
"text-completion", "graph-rag", "agent", "embeddings",
"graph-embeddings", "triples", "document-load", "text-load",
"flow", "knowledge", "config", "librarian", "document-rag"
"flow", "knowledge", "config", "librarian", "document-rag",
]
for service in expected_services:
assert service in dispatcher.service_mapping
# Test specific mappings
assert dispatcher.service_mapping["text-completion"] == "text-completion"
assert dispatcher.service_mapping["document-load"] == "document"
assert dispatcher.service_mapping["text-load"] == "text-document"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
"""Test MessageDispatcher handle_message without dispatcher manager"""
async def test_handle_message_without_dispatcher_manager(self):
dispatcher = MessageDispatcher()
test_message = {
"id": "test-123",
"service": "test-service",
"request": {"data": "test"}
}
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-123"
assert "error" in result["response"]
assert "DispatcherManager not available" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="default")
)
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-1", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-1"
assert sent["error"]["message"] == "DispatcherManager not available"
assert sent["error"]["type"] == "error"
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_with_exception(self):
"""Test MessageDispatcher handle_message with exception during processing"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
async def test_handle_message_auth_failure(self):
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-456",
"service": "text-completion",
"request": {"prompt": "test"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-456"
assert "error" in result["response"]
assert "Test error" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
side_effect=Exception("auth failure")
)
dispatcher.dispatcher_manager = MagicMock()
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-2"
assert "auth failure" in sent["error"]["message"]
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_global_service(self):
"""Test MessageDispatcher handle_message with global service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"result": "success"}
async def test_handle_message_global_service(self):
mock_dm = MagicMock()
mock_dm.invoke_global_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-789",
"service": "text-completion",
"request": {"prompt": "hello"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-789"
assert result["response"] == {"result": "success"}
mock_dispatcher_manager.invoke_global_service.assert_called_once()
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-3",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hello"},
},
sender,
)
mock_dm.invoke_global_service.assert_called_once()
args, kwargs = mock_dm.invoke_global_service.call_args
assert args[0] == {"prompt": "hello"}
assert args[2] == "text-completion"
assert kwargs["workspace"] == "ws1"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_flow_service(self):
"""Test MessageDispatcher handle_message with flow service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"data": "flow_result"}
async def test_handle_message_flow_service(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-flow-123",
"service": "document-rag",
"request": {"query": "test"},
"flow": "custom-flow"
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-flow-123"
assert result["response"] == {"data": "flow_result"}
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws2")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-4",
"token": "tg_key",
"service": "document-rag",
"request": {"query": "test"},
"flow": "my-flow",
},
sender,
)
mock_dm.invoke_flow_service.assert_called_once_with(
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
)
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_incomplete_response(self):
"""Test MessageDispatcher handle_message with incomplete response"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = False
mock_responder.response = None
async def test_handle_message_responder_sends_frames(self):
mock_dm = MagicMock()
async def fake_invoke(data, responder, svc, workspace=None):
await responder({"partial": 1}, False)
await responder({"partial": 2}, True)
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-incomplete",
"service": "agent",
"request": {"input": "test"}
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-5",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hi"},
},
sender,
)
assert sender.call_count == 2
first = sender.call_args_list[0][0][0]
second = sender.call_args_list[1][0][0]
assert first == {
"id": "test-5", "response": {"partial": 1}, "complete": False,
}
assert second == {
"id": "test-5", "response": {"partial": 2}, "complete": True,
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-incomplete"
assert result["response"] == {"error": "No response received"}
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown(self):
"""Test MessageDispatcher shutdown method"""
import asyncio
async def test_handle_message_workspace_from_identity(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
# Create actual async tasks
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="derived-ws")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-6",
"token": "tg_key",
"service": "agent",
"request": {"question": "test"},
"flow": "default",
},
sender,
)
args = mock_dm.invoke_flow_service.call_args[0]
assert args[2] == "derived-ws"
@pytest.mark.asyncio
async def test_shutdown(self):
dispatcher = MessageDispatcher()
async def dummy_task():
await asyncio.sleep(0.01)
return "done"
task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task())
dispatcher.active_tasks = {task1, task2}
# Call shutdown
await dispatcher.shutdown()
# Verify tasks were completed
assert task1.done()
assert task2.done()
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown_with_no_tasks(self):
"""Test MessageDispatcher shutdown with no active tasks"""
async def test_shutdown_with_no_tasks(self):
dispatcher = MessageDispatcher()
# Call shutdown with no active tasks
await dispatcher.shutdown()
# Should complete without error
assert dispatcher.active_tasks == set()
assert dispatcher.active_tasks == set()

View file

@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
from aiohttp import WSMsgType, ClientWebSocketResponse
import json
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
from trustgraph.rev_gateway.service import ReverseGateway, run
MOCK_PATCHES = [
'trustgraph.rev_gateway.service.IamAuth',
'trustgraph.rev_gateway.service.ConfigReceiver',
'trustgraph.rev_gateway.service.MessageDispatcher',
'trustgraph.rev_gateway.service.get_pubsub',
]
def make_gateway(**overrides):
config = {"websocket_uri": "ws://localhost:7650/out"}
config.update(overrides)
return ReverseGateway(**config)
class TestReverseGateway:
"""Test cases for ReverseGateway class"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with default parameters"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_defaults(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
assert gateway.websocket_uri == "ws://localhost:7650/out"
assert gateway.host == "localhost"
assert gateway.port == 7650
@ -33,25 +49,22 @@ class TestReverseGateway:
assert gateway.max_workers == 10
assert gateway.running is False
assert gateway.reconnect_delay == 3.0
assert gateway.pulsar_host == "pulsar://pulsar:6650"
assert gateway.pulsar_api_key is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with custom parameters"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_custom_params(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(
websocket_uri="wss://example.com:8080/websocket",
max_workers=20,
pulsar_host="pulsar://custom:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
)
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
assert gateway.host == "example.com"
assert gateway.port == 8080
@ -59,340 +72,360 @@ class TestReverseGateway:
assert gateway.path == "/websocket"
assert gateway.url == "wss://example.com:8080/websocket"
assert gateway.max_workers == 20
assert gateway.pulsar_host == "pulsar://custom:6650"
assert gateway.pulsar_api_key == "test-key"
assert gateway.pulsar_listener == "test-listener"
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with WebSocket URI missing path"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(websocket_uri="ws://example.com")
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_with_missing_path(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(websocket_uri="ws://example.com")
assert gateway.path == "/ws"
assert gateway.url == "ws://example.com/ws"
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_invalid_scheme(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
ReverseGateway(websocket_uri="http://example.com")
make_gateway(websocket_uri="http://example.com")
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with missing hostname"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_missing_hostname(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
ReverseGateway(websocket_uri="ws://")
make_gateway(websocket_uri="ws://")
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway creates backend with authentication"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_iam_auth_created(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(
pulsar_api_key="test-key",
pulsar_listener="test-listener"
gateway = make_gateway(id="test-rev-gw")
mock_iam_auth.assert_called_once_with(
backend=mock_backend,
id="test-rev-gw",
)
# Verify get_pubsub was called with the correct parameters
mock_get_pubsub.assert_called_once_with(
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_config_receiver_gets_auth(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_auth_instance = MagicMock()
mock_iam_auth.return_value = mock_auth_instance
gateway = make_gateway()
mock_config_receiver.assert_called_once_with(
mock_backend, auth=mock_auth_instance,
)
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway successful connection"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_connect_success(
self, mock_session_class, mock_get_pubsub,
mock_dispatcher, mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock()
mock_ws = AsyncMock()
mock_session.ws_connect.return_value = mock_ws
mock_session_class.return_value = mock_session
gateway = ReverseGateway()
gateway = make_gateway()
result = await gateway.connect()
assert result is True
assert gateway.session == mock_session
assert gateway.ws == mock_ws
mock_session.ws_connect.assert_called_once_with(gateway.url)
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway connection failure"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_connect_failure(
self, mock_session_class, mock_get_pubsub,
mock_dispatcher, mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock()
mock_session.ws_connect.side_effect = Exception("Connection failed")
mock_session_class.return_value = mock_session
gateway = ReverseGateway()
gateway = make_gateway()
result = await gateway.connect()
assert result is False
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway disconnect"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock websocket and session
async def test_reverse_gateway_disconnect(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = False
mock_session = AsyncMock()
mock_session.closed = False
gateway.ws = mock_ws
gateway.session = mock_session
await gateway.disconnect()
mock_ws.close.assert_called_once()
mock_session.close.assert_called_once()
assert gateway.ws is None
assert gateway.session is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway send message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock websocket
async def test_reverse_gateway_send_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message)
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway send message with closed connection"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock closed websocket
async def test_reverse_gateway_send_message_closed_connection(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = True
gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message)
# Should not call send_str on closed connection
mock_ws.send_str.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock()
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
mock_dispatcher_instance.handle_message.assert_called_once_with({
"id": "test",
"service": "test-service",
"request": {"data": "test"}
})
gateway.send_message.assert_called_once_with({"response": "success"})
async def test_reverse_gateway_handle_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = make_gateway()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message with invalid JSON"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = 'invalid json'
# Should not raise exception
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
# Should not call send_message due to error
mock_dispatcher_instance.handle_message.assert_called_once_with(
{
"id": "test",
"service": "test-service",
"request": {"data": "test"},
},
gateway.send_message,
)
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.send_message = AsyncMock()
await gateway.handle_message('invalid json')
gateway.send_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with text message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_text_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.TEXT
mock_msg.data = '{"test": "message"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "message"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with binary message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_binary_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.BINARY
mock_msg.data = b'{"test": "binary"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with close message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_close_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.CLOSE
# Mock receive to return close message
mock_ws.receive.return_value = mock_msg
await gateway.listen()
# Should not call handle_message for close message
gateway.handle_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway shutdown"""
async def test_reverse_gateway_shutdown(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True
# Mock disconnect
gateway.disconnect = AsyncMock()
await gateway.shutdown()
@ -402,46 +435,50 @@ class TestReverseGateway:
gateway.disconnect.assert_called_once()
mock_backend.close.assert_called_once()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway stop"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_stop(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
gateway.stop()
assert gateway.running is False
class TestReverseGatewayRun:
"""Test cases for ReverseGateway run method"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway run method with successful connect/listen cycle"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_run_successful_cycle(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_auth_instance = AsyncMock()
mock_iam_auth.return_value = mock_auth_instance
mock_config_receiver_instance = AsyncMock()
mock_config_receiver.return_value = mock_config_receiver_instance
gateway = ReverseGateway()
# Mock methods
gateway.connect = AsyncMock(return_value=True)
gateway = make_gateway()
gateway.listen = AsyncMock()
gateway.disconnect = AsyncMock()
gateway.shutdown = AsyncMock()
# Stop after one iteration
call_count = 0
async def mock_connect():
nonlocal call_count
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
else:
gateway.running = False
return False
gateway.connect = mock_connect
await gateway.run()
mock_auth_instance.start.assert_called_once()
mock_config_receiver_instance.start.assert_called_once()
gateway.listen.assert_called_once()
# disconnect is called twice: once in the main loop, once in shutdown
assert gateway.disconnect.call_count == 2
gateway.shutdown.assert_called_once()
class TestReverseGatewayArgs:
"""Test cases for argument parsing and run function"""
def test_parse_args_defaults(self):
"""Test parse_args with default values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway']
try:
args = parse_args()
assert args.websocket_uri is None
assert args.max_workers == 10
assert args.pulsar_host is None
assert args.pulsar_api_key is None
assert args.pulsar_listener is None
finally:
sys.argv = original_argv
def test_parse_args_custom_values(self):
"""Test parse_args with custom values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = [
'reverse-gateway',
'--websocket-uri', 'ws://custom:8080/ws',
'--max-workers', '20',
'--pulsar-host', 'pulsar://custom:6650',
'--pulsar-api-key', 'test-key',
'--pulsar-listener', 'test-listener'
]
try:
args = parse_args()
assert args.websocket_uri == 'ws://custom:8080/ws'
assert args.max_workers == 20
assert args.pulsar_host == 'pulsar://custom:6650'
assert args.pulsar_api_key == 'test-key'
assert args.pulsar_listener == 'test-listener'
finally:
sys.argv = original_argv
@patch('trustgraph.rev_gateway.service.ReverseGateway')
@patch('asyncio.run')
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
"""Test run function"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway', '--max-workers', '15']
try:
mock_gateway_instance = MagicMock()
mock_gateway_instance.url = "ws://localhost:7650/out"
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
mock_gateway_class.return_value = mock_gateway_instance
run()
mock_gateway_class.assert_called_once_with(
websocket_uri=None,
max_workers=15,
pulsar_host=None,
pulsar_api_key=None,
pulsar_listener=None
)
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
finally:
sys.argv = original_argv

View file

@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
# Verify collection existence is checked on each write
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Second write uses cached collection state — no collection_exists check
mock_qdrant_instance.collection_exists.assert_not_called()
# But upsert should still be called
mock_qdrant_instance.upsert.assert_called_once()

View file

@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
processor.ensure_collection("test_collection", 384)
await processor.ensure_collection("test_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
mock_qdrant_instance.create_collection.assert_called_once()
# Verify the collection is cached
assert "test_collection" in processor.created_collections
assert "test_collection" in processor._known_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
processor.ensure_collection("existing_collection", 384)
await processor.ensure_collection("existing_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once()
mock_qdrant_instance.create_collection.assert_not_called()
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add("cached_collection")
processor._known_collections.add("cached_collection")
processor.ensure_collection("cached_collection", 384)
await processor.ensure_collection("cached_collection", 384)
# Should not check or create - just return
mock_qdrant_instance.collection_exists.assert_not_called()
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_workspace', 'test_collection')
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):

View file

@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
import asyncio
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.registered_partitions = set()
processor._setup_lock = asyncio.Lock()
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration

View file

@ -2,6 +2,8 @@
Tests for Cassandra triples storage service
"""
import asyncio
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
@ -24,12 +26,13 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None
assert processor.cassandra_password is None
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters (new cassandra_* names)"""
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
id='custom-storage',
@ -37,11 +40,12 @@ class TestCassandraStorageProcessor:
cassandra_username='testuser',
cassandra_password='testpass'
)
assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'testuser'
assert processor.cassandra_password == 'testpass'
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_partial_auth(self):
"""Test processor initialization with only username (no password)"""
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
"""Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
username='testuser',
password='testpass'
)
assert processor.table == 'user1'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
"""Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
hosts=['cassandra'], # Updated default
keyspace='user2'
)
assert processor.table == 'user2'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
"""Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
"""Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
await processor.store_triples('user1', mock_message)
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2
mock_tg_instance.insert.assert_any_call(
assert mock_tg_instance.async_insert.call_count == 2
mock_tg_instance.async_insert.assert_any_call(
'collection1', 'subject1', 'predicate1', 'object1',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
mock_tg_instance.insert.assert_any_call(
mock_tg_instance.async_insert.assert_any_call(
'collection1', 'subject2', 'predicate2', 'object2',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
"""Test behavior when message has no triples"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
await processor.store_triples('user1', mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
mock_tg_instance.async_insert.assert_not_called()
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
async def test_exception_handling_on_connection_failure(self, mock_kg_class):
"""Test exception handling during TrustGraph creation"""
taskgroup_mock = MagicMock()
mock_kg_class.side_effect = Exception("Connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples('user1', mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
mock_message1.triples = []
await processor.store_triples('user1', mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
mock_message2.triples = []
await processor.store_triples('user2', mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables
# Verify TrustGraph was created twice for different workspaces
assert mock_kg_class.call_count == 2
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
"""Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
await processor.store_triples('test_workspace', mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
mock_tg_instance.async_insert.assert_called_once_with(
'test_collection',
'subject with spaces & symbols',
'predicate:with/colons',
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
"""Test that table remains unchanged when TrustGraph creation fails"""
async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
"""Test that a failed connection doesn't leave stale cached state"""
taskgroup_mock = MagicMock()
mock_good_instance = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
# Set an initial table
processor.table = ('old_user', 'old_collection')
# Mock TrustGraph to raise exception
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.collection = 'new_collection'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call fails
mock_kg_class.side_effect = Exception("Connection failed")
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples('new_user', mock_message)
await processor.store_triples('user1', mock_message)
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
# TrustGraph should be set to None though
assert processor.tg is None
# Second call succeeds — should retry connection, not use stale state
mock_kg_class.side_effect = None
mock_kg_class.return_value = mock_good_instance
await processor.store_triples('user1', mock_message)
# Connection was attempted twice (failed + succeeded)
assert mock_kg_class.call_count == 2
class TestCassandraPerformanceOptimizations:
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that legacy mode still works with single table"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that optimized mode uses multi-table schema"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that all tables stay consistent during batch writes"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples('user1', mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(
mock_tg_instance.async_insert.assert_called_once_with(
'collection1', 'test_subject', 'test_predicate', 'test_object',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)

View file

@ -89,7 +89,8 @@ class TestSanitizeName:
class TestFindCollection:
def test_finds_matching_collection(self):
@pytest.mark.asyncio
async def test_finds_matching_collection(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_test_workspace_test_col_customers_384"
@ -98,11 +99,12 @@ class TestFindCollection:
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-workspace", "test-col", "customers")
result = await proc.find_collection("test-workspace", "test-col", "customers")
assert result == "rows_test_workspace_test_col_customers_384"
def test_returns_none_when_no_match(self):
@pytest.mark.asyncio
async def test_returns_none_when_no_match(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_other_workspace_other_col_schema_768"
@ -111,14 +113,15 @@ class TestFindCollection:
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-workspace", "test-col", "customers")
result = await proc.find_collection("test-workspace", "test-col", "customers")
assert result is None
def test_returns_none_on_error(self):
@pytest.mark.asyncio
async def test_returns_none_on_error(self):
proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("workspace", "col", "schema")
result = await proc.find_collection("workspace", "col", "schema")
assert result is None
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_no_collection_returns_empty(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value=None)
proc.find_collection = AsyncMock(return_value=None)
request = _make_request()
result = await proc.query_row_embeddings("test-workspace", request)
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_successful_query_returns_matches(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
point = MagicMock()
point.payload = {} # Empty payload
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_qdrant_error_propagates(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request()