mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
Merge remote-tracking branch 'origin/master' into ts-port-effect-v4
This commit is contained in:
commit
92dae8c374
117 changed files with 7392 additions and 3410 deletions
296
tests/unit/test_api/test_library_api.py
Normal file
296
tests/unit/test_api/test_library_api.py
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
"""
|
||||
Tests for the Library API wrapper round-trip behavior.
|
||||
Covers the get_documents → update_document path and edge cases
|
||||
from issue #893.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.api.library import Library, to_value, from_value
|
||||
from trustgraph.api.types import DocumentMetadata, Triple
|
||||
from trustgraph.knowledge import Uri, Literal
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_library(response=None):
|
||||
api = MagicMock()
|
||||
api.workspace = "default"
|
||||
api.request.return_value = response or {}
|
||||
lib = Library(api)
|
||||
return lib, api
|
||||
|
||||
|
||||
def _wire_triple(s_iri, p_iri, o_val):
|
||||
return {
|
||||
"s": {"t": "i", "i": s_iri},
|
||||
"p": {"t": "i", "i": p_iri},
|
||||
"o": {"t": "l", "v": o_val},
|
||||
}
|
||||
|
||||
|
||||
def _doc_wire(id="doc-1", time=1700000000, title="Test Doc",
|
||||
kind="text/plain", comments="", tags=None,
|
||||
metadata=None, parent_id="", document_type="source",
|
||||
include_title=True):
|
||||
doc = {
|
||||
"id": id,
|
||||
"time": time,
|
||||
"kind": kind,
|
||||
"comments": comments,
|
||||
"metadata": metadata or [],
|
||||
"tags": tags or [],
|
||||
"parent-id": parent_id,
|
||||
"document-type": document_type,
|
||||
}
|
||||
if include_title:
|
||||
doc["title"] = title
|
||||
return doc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bug 1: get_documents tolerates missing title
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDocumentsMissingTitle:
|
||||
|
||||
def test_missing_title_defaults_to_empty(self):
|
||||
doc = _doc_wire(include_title=False)
|
||||
lib, api = _make_library({"document-metadatas": [doc]})
|
||||
|
||||
result = lib.get_documents()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].title == ""
|
||||
|
||||
def test_present_title_preserved(self):
|
||||
doc = _doc_wire(title="My Title")
|
||||
lib, api = _make_library({"document-metadatas": [doc]})
|
||||
|
||||
result = lib.get_documents()
|
||||
|
||||
assert result[0].title == "My Title"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bug 2: update_document handles Triple objects (attribute access)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateDocumentTripleAccess:
|
||||
|
||||
def test_triple_objects_serialized_correctly(self):
|
||||
lib, api = _make_library({})
|
||||
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=datetime.datetime.fromtimestamp(1700000000),
|
||||
kind="text/plain",
|
||||
title="Test",
|
||||
comments="",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Uri("http://example.org/entity/alice"),
|
||||
p=Uri("http://example.org/rel/knows"),
|
||||
o=Literal("Bob"),
|
||||
),
|
||||
],
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
call_args = api.request.call_args[0][1]
|
||||
triples = call_args["document-metadata"]["metadata"]
|
||||
|
||||
assert len(triples) == 1
|
||||
assert triples[0]["s"]["i"] == "http://example.org/entity/alice"
|
||||
assert triples[0]["p"]["i"] == "http://example.org/rel/knows"
|
||||
assert triples[0]["o"]["v"] == "Bob"
|
||||
|
||||
def test_empty_metadata_list(self):
|
||||
lib, api = _make_library({})
|
||||
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=datetime.datetime.fromtimestamp(1700000000),
|
||||
kind="text/plain",
|
||||
title="Test",
|
||||
comments="",
|
||||
metadata=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
call_args = api.request.call_args[0][1]
|
||||
assert call_args["document-metadata"]["metadata"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bug 3: update_document serializes datetime to int seconds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateDocumentTimeSerialization:
|
||||
|
||||
def test_datetime_serialized_to_int(self):
|
||||
lib, api = _make_library({})
|
||||
|
||||
ts = 1700000000
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=datetime.datetime.fromtimestamp(ts),
|
||||
kind="text/plain",
|
||||
title="Test",
|
||||
comments="",
|
||||
metadata=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
call_args = api.request.call_args[0][1]
|
||||
wire_time = call_args["document-metadata"]["time"]
|
||||
|
||||
assert isinstance(wire_time, int)
|
||||
assert wire_time == ts
|
||||
|
||||
def test_int_time_passed_through(self):
|
||||
lib, api = _make_library({})
|
||||
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=1700000000,
|
||||
kind="text/plain",
|
||||
title="Test",
|
||||
comments="",
|
||||
metadata=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
call_args = api.request.call_args[0][1]
|
||||
assert call_args["document-metadata"]["time"] == 1700000000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bug 4: update_document handles empty server response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateDocumentEmptyResponse:
|
||||
|
||||
def test_empty_response_returns_input_metadata(self):
|
||||
lib, api = _make_library({})
|
||||
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=datetime.datetime.fromtimestamp(1700000000),
|
||||
kind="text/plain",
|
||||
title="Updated Title",
|
||||
comments="notes",
|
||||
metadata=[],
|
||||
tags=["a"],
|
||||
)
|
||||
|
||||
result = lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
assert result is metadata
|
||||
|
||||
def test_full_response_parsed(self):
|
||||
response_doc = _doc_wire(
|
||||
id="doc-1", title="Server Title", tags=["b"],
|
||||
)
|
||||
lib, api = _make_library({"document-metadata": response_doc})
|
||||
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=datetime.datetime.fromtimestamp(1700000000),
|
||||
kind="text/plain",
|
||||
title="Client Title",
|
||||
comments="",
|
||||
metadata=[],
|
||||
tags=["a"],
|
||||
)
|
||||
|
||||
result = lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
assert result.title == "Server Title"
|
||||
assert result.tags == ["b"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bug 5: update_document sends both id and document-id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateDocumentIdKeys:
|
||||
|
||||
def test_both_id_keys_sent(self):
|
||||
lib, api = _make_library({})
|
||||
|
||||
metadata = DocumentMetadata(
|
||||
id="doc-1",
|
||||
time=datetime.datetime.fromtimestamp(1700000000),
|
||||
kind="text/plain",
|
||||
title="Test",
|
||||
comments="",
|
||||
metadata=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
lib.update_document(id="doc-1", metadata=metadata)
|
||||
|
||||
call_args = api.request.call_args[0][1]
|
||||
doc_meta = call_args["document-metadata"]
|
||||
|
||||
assert doc_meta["id"] == "doc-1"
|
||||
assert doc_meta["document-id"] == "doc-1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip: get_documents → update_document
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetUpdateRoundTrip:
|
||||
|
||||
def test_full_round_trip(self):
|
||||
wire_doc = _doc_wire(
|
||||
id="doc-42",
|
||||
title="Original",
|
||||
tags=["v1"],
|
||||
metadata=[_wire_triple(
|
||||
"http://example.org/e/1",
|
||||
"http://example.org/r/type",
|
||||
"report",
|
||||
)],
|
||||
)
|
||||
|
||||
lib, api = _make_library({"document-metadatas": [wire_doc]})
|
||||
|
||||
docs = lib.get_documents()
|
||||
assert len(docs) == 1
|
||||
|
||||
doc = docs[0]
|
||||
doc.title = "Updated"
|
||||
doc.tags.append("v2")
|
||||
|
||||
# Server returns empty on update
|
||||
api.request.return_value = {}
|
||||
result = lib.update_document(id=doc.id, metadata=doc)
|
||||
|
||||
# Should not raise, should return the input metadata
|
||||
assert result.title == "Updated"
|
||||
assert "v2" in result.tags
|
||||
|
||||
# Verify the wire format sent
|
||||
call_args = api.request.call_args[0][1]
|
||||
doc_meta = call_args["document-metadata"]
|
||||
|
||||
assert doc_meta["id"] == "doc-42"
|
||||
assert doc_meta["title"] == "Updated"
|
||||
assert isinstance(doc_meta["time"], int)
|
||||
assert len(doc_meta["metadata"]) == 1
|
||||
assert doc_meta["metadata"][0]["o"]["v"] == "report"
|
||||
|
|
@ -272,23 +272,22 @@ class TestMetricsIntegration:
|
|||
class TestPollTimeout:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_timeout_is_100ms(self):
|
||||
"""Consumer receive timeout should be 100ms, not the original 2000ms.
|
||||
async def test_poll_timeout_is_2000ms(self):
|
||||
"""Consumer receive timeout should be 2000ms.
|
||||
|
||||
A 2000ms poll timeout means every service adds up to 2s of idle
|
||||
blocking between message bursts. With many sequential hops in a
|
||||
query pipeline, this compounds into seconds of unnecessary latency.
|
||||
100ms keeps responsiveness high without significant CPU overhead.
|
||||
receive() is a blocking call that returns immediately when a
|
||||
message arrives — the timeout only governs how often the loop
|
||||
checks the shutdown flag during idle periods. Lower values
|
||||
(e.g. 100ms) generate excessive C++ client WARN logging with
|
||||
no latency benefit.
|
||||
"""
|
||||
consumer = _make_consumer()
|
||||
|
||||
# Wire up a mock Pulsar consumer that records the receive kwargs
|
||||
mock_pulsar_consumer = MagicMock()
|
||||
received_kwargs = {}
|
||||
|
||||
def capture_receive(**kwargs):
|
||||
received_kwargs.update(kwargs)
|
||||
# Stop after one call
|
||||
consumer.running = False
|
||||
raise type('Timeout', (Exception,), {})("timeout")
|
||||
|
||||
|
|
@ -296,7 +295,7 @@ class TestPollTimeout:
|
|||
|
||||
await consumer.consume_from_queue(mock_pulsar_consumer)
|
||||
|
||||
assert received_kwargs.get("timeout_millis") == 100
|
||||
assert received_kwargs.get("timeout_millis") == 2000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
|
|||
max_concurrent = 0
|
||||
processing_event = asyncio.Event()
|
||||
|
||||
async def slow_process(message):
|
||||
async def slow_process(message, sender):
|
||||
nonlocal concurrent_count, max_concurrent
|
||||
concurrent_count += 1
|
||||
max_concurrent = max(max_concurrent, concurrent_count)
|
||||
await asyncio.sleep(0.05)
|
||||
concurrent_count -= 1
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = slow_process
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
# Launch more tasks than max_workers
|
||||
messages = [
|
||||
{"id": f"msg-{i}", "service": "test", "request": {}}
|
||||
|
|
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
|
|||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(dispatcher.handle_message(m))
|
||||
asyncio.create_task(dispatcher.handle_message(m, sender))
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
|
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
|
|||
|
||||
original_process = dispatcher._process_message
|
||||
|
||||
async def tracking_process(message):
|
||||
async def tracking_process(message, sender):
|
||||
nonlocal task_was_tracked
|
||||
# During processing, our task should be in active_tasks
|
||||
if len(dispatcher.active_tasks) > 0:
|
||||
task_was_tracked = True
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = tracking_process
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
{"id": "test", "service": "test", "request": {}},
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
assert task_was_tracked
|
||||
|
|
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
|
|||
"""Semaphore should be released even if processing raises."""
|
||||
dispatcher = MessageDispatcher(max_workers=2)
|
||||
|
||||
async def failing_process(message):
|
||||
async def failing_process(message, sender):
|
||||
raise RuntimeError("process failed")
|
||||
|
||||
dispatcher._process_message = failing_process
|
||||
|
|
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
|
|||
# Should not deadlock — semaphore must be released on error
|
||||
with pytest.raises(RuntimeError):
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
{"id": "test", "service": "test", "request": {}},
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
# Semaphore should be back at max
|
||||
|
|
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
|
|||
|
||||
order = []
|
||||
|
||||
async def ordered_process(message):
|
||||
async def ordered_process(message, sender):
|
||||
msg_id = message["id"]
|
||||
order.append(f"start-{msg_id}")
|
||||
await asyncio.sleep(0.02)
|
||||
order.append(f"end-{msg_id}")
|
||||
return {"id": msg_id, "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = ordered_process
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# With semaphore=1, each message should complete before next starts
|
||||
|
|
|
|||
|
|
@ -0,0 +1,389 @@
|
|||
"""
|
||||
Tests for TripleConverter domain/range enforcement and
|
||||
OntologySelector bypass for small ontologies.
|
||||
|
||||
Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
from trustgraph.extract.kg.ontology.triple_converter import TripleConverter
|
||||
from trustgraph.extract.kg.ontology.ontology_selector import (
|
||||
OntologySelector,
|
||||
OntologySubset,
|
||||
)
|
||||
from trustgraph.extract.kg.ontology.ontology_loader import (
|
||||
Ontology,
|
||||
OntologyClass,
|
||||
OntologyProperty,
|
||||
)
|
||||
from trustgraph.extract.kg.ontology.simplified_parser import (
|
||||
Relationship,
|
||||
Attribute,
|
||||
)
|
||||
from trustgraph.extract.kg.ontology.text_processor import TextSegment
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def ontology_subset():
|
||||
"""Ontology subset with classes, hierarchy, and constrained properties."""
|
||||
return OntologySubset(
|
||||
ontology_id="test",
|
||||
classes={
|
||||
"Person": {
|
||||
"uri": "http://example.org/Person",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Person"}],
|
||||
"subclass_of": None,
|
||||
},
|
||||
"Employee": {
|
||||
"uri": "http://example.org/Employee",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Employee"}],
|
||||
"subclass_of": "Person",
|
||||
},
|
||||
"Manager": {
|
||||
"uri": "http://example.org/Manager",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Manager"}],
|
||||
"subclass_of": "Employee",
|
||||
},
|
||||
"Company": {
|
||||
"uri": "http://example.org/Company",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Company"}],
|
||||
"subclass_of": None,
|
||||
},
|
||||
"Product": {
|
||||
"uri": "http://example.org/Product",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Product"}],
|
||||
"subclass_of": None,
|
||||
},
|
||||
},
|
||||
object_properties={
|
||||
"worksFor": {
|
||||
"uri": "http://example.org/worksFor",
|
||||
"type": "owl:ObjectProperty",
|
||||
"labels": [{"value": "works for"}],
|
||||
"domain": "Person",
|
||||
"range": "Company",
|
||||
},
|
||||
"manages": {
|
||||
"uri": "http://example.org/manages",
|
||||
"type": "owl:ObjectProperty",
|
||||
"labels": [{"value": "manages"}],
|
||||
"domain": "Manager",
|
||||
"range": "Employee",
|
||||
},
|
||||
"relatedTo": {
|
||||
"uri": "http://example.org/relatedTo",
|
||||
"type": "owl:ObjectProperty",
|
||||
"labels": [{"value": "related to"}],
|
||||
"domain": None,
|
||||
"range": None,
|
||||
},
|
||||
},
|
||||
datatype_properties={
|
||||
"employeeId": {
|
||||
"uri": "http://example.org/employeeId",
|
||||
"type": "owl:DatatypeProperty",
|
||||
"labels": [{"value": "employee ID"}],
|
||||
"domain": "Employee",
|
||||
},
|
||||
"description": {
|
||||
"uri": "http://example.org/description",
|
||||
"type": "owl:DatatypeProperty",
|
||||
"labels": [{"value": "description"}],
|
||||
"domain": None,
|
||||
},
|
||||
},
|
||||
metadata={"name": "Test Ontology"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter(ontology_subset):
|
||||
return TripleConverter(ontology_subset=ontology_subset, ontology_id="test")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain/range enforcement — relationships
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRelationshipDomainRange:
|
||||
|
||||
def test_valid_domain_and_range(self, converter):
|
||||
rel = Relationship(
|
||||
subject="Alice", subject_type="Person",
|
||||
relation="worksFor",
|
||||
object="Acme Corp", object_type="Company",
|
||||
)
|
||||
triple = converter.convert_relationship(rel)
|
||||
assert triple is not None
|
||||
|
||||
def test_domain_violation_rejected(self, converter):
|
||||
rel = Relationship(
|
||||
subject="Widget", subject_type="Product",
|
||||
relation="worksFor",
|
||||
object="Acme Corp", object_type="Company",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is None
|
||||
|
||||
def test_range_violation_rejected(self, converter):
|
||||
rel = Relationship(
|
||||
subject="Alice", subject_type="Person",
|
||||
relation="worksFor",
|
||||
object="Widget", object_type="Product",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is None
|
||||
|
||||
def test_both_domain_and_range_violated(self, converter):
|
||||
rel = Relationship(
|
||||
subject="Widget", subject_type="Product",
|
||||
relation="worksFor",
|
||||
object="Gadget", object_type="Product",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subclass acceptance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubclassAcceptance:
|
||||
|
||||
def test_direct_subclass_matches_domain(self, converter):
|
||||
"""Employee is subclass of Person; worksFor domain is Person."""
|
||||
rel = Relationship(
|
||||
subject="Bob", subject_type="Employee",
|
||||
relation="worksFor",
|
||||
object="Acme Corp", object_type="Company",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is not None
|
||||
|
||||
def test_transitive_subclass_matches_domain(self, converter):
|
||||
"""Manager → Employee → Person; worksFor domain is Person."""
|
||||
rel = Relationship(
|
||||
subject="Carol", subject_type="Manager",
|
||||
relation="worksFor",
|
||||
object="Acme Corp", object_type="Company",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is not None
|
||||
|
||||
def test_subclass_matches_range(self, converter):
|
||||
"""manages range is Employee; Manager is subclass of Employee."""
|
||||
rel = Relationship(
|
||||
subject="Carol", subject_type="Manager",
|
||||
relation="manages",
|
||||
object="Dave", object_type="Manager",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is not None
|
||||
|
||||
def test_superclass_does_not_match_subclass_constraint(self, converter):
|
||||
"""manages domain is Manager; Person is NOT a subclass of Manager."""
|
||||
rel = Relationship(
|
||||
subject="Alice", subject_type="Person",
|
||||
relation="manages",
|
||||
object="Bob", object_type="Employee",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Polymorphic properties (no domain/range)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPolymorphicProperties:
|
||||
|
||||
def test_no_domain_no_range_allows_anything(self, converter):
|
||||
rel = Relationship(
|
||||
subject="Alice", subject_type="Person",
|
||||
relation="relatedTo",
|
||||
object="Acme Corp", object_type="Company",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is not None
|
||||
|
||||
def test_polymorphic_with_unrelated_types(self, converter):
|
||||
rel = Relationship(
|
||||
subject="Widget", subject_type="Product",
|
||||
relation="relatedTo",
|
||||
object="Bob", object_type="Employee",
|
||||
)
|
||||
assert converter.convert_relationship(rel) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Datatype property domain enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAttributeDomainValidation:
|
||||
|
||||
def test_valid_domain(self, converter):
|
||||
attr = Attribute(
|
||||
entity="Bob", entity_type="Employee",
|
||||
attribute="employeeId", value="E-1234",
|
||||
)
|
||||
assert converter.convert_attribute(attr) is not None
|
||||
|
||||
def test_subclass_matches_domain(self, converter):
|
||||
"""Manager is subclass of Employee; employeeId domain is Employee."""
|
||||
attr = Attribute(
|
||||
entity="Carol", entity_type="Manager",
|
||||
attribute="employeeId", value="M-5678",
|
||||
)
|
||||
assert converter.convert_attribute(attr) is not None
|
||||
|
||||
def test_domain_violation_rejected(self, converter):
|
||||
attr = Attribute(
|
||||
entity="Acme Corp", entity_type="Company",
|
||||
attribute="employeeId", value="E-0000",
|
||||
)
|
||||
assert converter.convert_attribute(attr) is None
|
||||
|
||||
def test_no_domain_allows_anything(self, converter):
|
||||
attr = Attribute(
|
||||
entity="Widget", entity_type="Product",
|
||||
attribute="description", value="A useful widget",
|
||||
)
|
||||
assert converter.convert_attribute(attr) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OntologySelector bypass for small ontologies (#908)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0):
|
||||
classes = {
|
||||
f"C{i}": OntologyClass(uri=f"http://example.org/C{i}")
|
||||
for i in range(n_classes)
|
||||
}
|
||||
obj_props = {
|
||||
f"op{i}": OntologyProperty(
|
||||
uri=f"http://example.org/op{i}", type="owl:ObjectProperty"
|
||||
)
|
||||
for i in range(n_obj_props)
|
||||
}
|
||||
dt_props = {
|
||||
f"dp{i}": OntologyProperty(
|
||||
uri=f"http://example.org/dp{i}", type="owl:DatatypeProperty"
|
||||
)
|
||||
for i in range(n_dt_props)
|
||||
}
|
||||
return Ontology(
|
||||
id="tiny",
|
||||
metadata={"name": "Tiny"},
|
||||
classes=classes,
|
||||
object_properties=obj_props,
|
||||
datatype_properties=dt_props,
|
||||
)
|
||||
|
||||
|
||||
def _make_loader(ontology):
|
||||
loader = Mock()
|
||||
loader.get_ontology.return_value = ontology
|
||||
loader.get_all_ontologies.return_value = {"tiny": ontology}
|
||||
return loader
|
||||
|
||||
|
||||
class TestBypassSelectorBelow:
|
||||
|
||||
async def test_bypass_returns_full_ontology(self):
|
||||
"""With 3 elements and bypass_selector_below=5, selector is bypassed."""
|
||||
ont = _make_ontology(2, 1, 0)
|
||||
loader = _make_loader(ont)
|
||||
embedder = Mock()
|
||||
|
||||
selector = OntologySelector(
|
||||
ontology_embedder=embedder,
|
||||
ontology_loader=loader,
|
||||
bypass_selector_below=5,
|
||||
)
|
||||
|
||||
segments = [TextSegment(text="some text", type="sentence", position=0)]
|
||||
subsets = await selector.select_ontology_subset(segments)
|
||||
|
||||
assert len(subsets) == 1
|
||||
assert subsets[0].ontology_id == "tiny"
|
||||
assert len(subsets[0].classes) == 2
|
||||
assert len(subsets[0].object_properties) == 1
|
||||
assert subsets[0].relevance_score == 1.0
|
||||
# Embedder should never be called
|
||||
embedder.embed_text.assert_not_called()
|
||||
|
||||
async def test_no_bypass_when_above_threshold(self):
|
||||
"""With 10 elements and bypass_selector_below=5, selector runs normally."""
|
||||
ont = _make_ontology(6, 3, 1)
|
||||
loader = _make_loader(ont)
|
||||
|
||||
embedder = Mock()
|
||||
embedder.embed_text = AsyncMock(return_value=[0.1, 0.2])
|
||||
vector_store = Mock()
|
||||
vector_store.size.return_value = 10
|
||||
vector_store.search.return_value = []
|
||||
embedder.get_vector_store.return_value = vector_store
|
||||
|
||||
selector = OntologySelector(
|
||||
ontology_embedder=embedder,
|
||||
ontology_loader=loader,
|
||||
bypass_selector_below=5,
|
||||
)
|
||||
|
||||
segments = [TextSegment(text="some text", type="sentence", position=0)]
|
||||
subsets = await selector.select_ontology_subset(segments)
|
||||
|
||||
# Vector store was consulted (selector ran normally)
|
||||
vector_store.size.assert_called_once()
|
||||
|
||||
async def test_bypass_at_exact_threshold_not_triggered(self):
|
||||
"""With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=)."""
|
||||
ont = _make_ontology(3, 1, 1) # total = 5
|
||||
loader = _make_loader(ont)
|
||||
|
||||
embedder = Mock()
|
||||
embedder.embed_text = AsyncMock(return_value=[0.1, 0.2])
|
||||
vector_store = Mock()
|
||||
vector_store.size.return_value = 5
|
||||
vector_store.search.return_value = []
|
||||
embedder.get_vector_store.return_value = vector_store
|
||||
|
||||
selector = OntologySelector(
|
||||
ontology_embedder=embedder,
|
||||
ontology_loader=loader,
|
||||
bypass_selector_below=5,
|
||||
)
|
||||
|
||||
segments = [TextSegment(text="some text", type="sentence", position=0)]
|
||||
subsets = await selector.select_ontology_subset(segments)
|
||||
|
||||
# Should NOT bypass — 5 is not < 5
|
||||
vector_store.size.assert_called_once()
|
||||
|
||||
async def test_bypass_zero_disables(self):
|
||||
"""bypass_selector_below=0 means bypass never triggers."""
|
||||
ont = _make_ontology(0, 0, 0) # empty ontology
|
||||
loader = _make_loader(ont)
|
||||
|
||||
embedder = Mock()
|
||||
embedder.embed_text = AsyncMock(return_value=[0.1])
|
||||
vector_store = Mock()
|
||||
vector_store.size.return_value = 0
|
||||
vector_store.search.return_value = []
|
||||
embedder.get_vector_store.return_value = vector_store
|
||||
|
||||
selector = OntologySelector(
|
||||
ontology_embedder=embedder,
|
||||
ontology_loader=loader,
|
||||
bypass_selector_below=0,
|
||||
)
|
||||
|
||||
segments = [TextSegment(text="some text", type="sentence", position=0)]
|
||||
subsets = await selector.select_ontology_subset(segments)
|
||||
|
||||
# 0 is not < 0, so bypass doesn't trigger
|
||||
vector_store.size.assert_called_once()
|
||||
|
|
@ -165,22 +165,37 @@ class TestIamAuthDispatch:
|
|||
by shape of the bearer."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_authorization_header_raises_401(self):
|
||||
async def test_no_authorization_header_tries_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(None))
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: anonymous access not permitted")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(None))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_bearer_header_raises_401(self):
|
||||
async def test_non_bearer_header_tries_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Basic whatever"))
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: anonymous access not permitted")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Basic whatever"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_bearer_raises_401(self):
|
||||
async def test_empty_bearer_tries_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer "))
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: anonymous access not permitted")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer "))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_format_raises_401(self):
|
||||
|
|
@ -445,3 +460,121 @@ class TestAuthorise:
|
|||
|
||||
# Different resource → different cache key → two IAM calls.
|
||||
assert calls["n"] == 2
|
||||
|
||||
|
||||
# -- Anonymous authentication boundary ------------------------------------
|
||||
|
||||
|
||||
class TestAnonymousAuthBoundary:
|
||||
"""The gateway must only attempt anonymous auth when no credential
|
||||
is presented. A malformed token must NOT fall through to the
|
||||
anonymous path — that would let an attacker bypass a broken token
|
||||
by simply sending garbage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_header_attempts_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(
|
||||
authenticate_anonymous=AsyncMock(
|
||||
return_value=("anon", "default", ["reader"]),
|
||||
)
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = await auth.authenticate(make_request(None))
|
||||
assert ident.handle == "anon"
|
||||
assert ident.source == "anonymous"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_bearer_attempts_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(
|
||||
authenticate_anonymous=AsyncMock(
|
||||
return_value=("anon", "default", ["reader"]),
|
||||
)
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
ident = await auth.authenticate(make_request("Bearer "))
|
||||
assert ident.handle == "anon"
|
||||
assert ident.source == "anonymous"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_token_does_not_fall_through_to_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
called = {"anonymous": False}
|
||||
|
||||
original = auth._authenticate_anonymous
|
||||
|
||||
async def spy_anonymous():
|
||||
called["anonymous"] = True
|
||||
return await original()
|
||||
|
||||
auth._authenticate_anonymous = spy_anonymous
|
||||
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer garbage"))
|
||||
assert not called["anonymous"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_api_key_does_not_fall_through_to_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
called = {"anonymous": False}
|
||||
|
||||
async def spy_anonymous():
|
||||
called["anonymous"] = True
|
||||
|
||||
auth._authenticate_anonymous = spy_anonymous
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: unknown key")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer tg_bad"))
|
||||
assert not called["anonymous"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_jwt_does_not_fall_through_to_anonymous(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
auth._signing_public_pem = "not-a-real-pem"
|
||||
called = {"anonymous": False}
|
||||
|
||||
async def spy_anonymous():
|
||||
called["anonymous"] = True
|
||||
|
||||
auth._authenticate_anonymous = spy_anonymous
|
||||
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request("Bearer a.b.c"))
|
||||
assert not called["anonymous"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anonymous_rejected_by_iam_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
raise RuntimeError("auth-failed: anonymous access not permitted")
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(None))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anonymous_with_empty_user_id_raises_401(self):
|
||||
auth = IamAuth(backend=Mock())
|
||||
|
||||
async def fake_with_client(op):
|
||||
return await op(Mock(
|
||||
authenticate_anonymous=AsyncMock(
|
||||
return_value=("", "default", []),
|
||||
)
|
||||
))
|
||||
|
||||
with patch.object(auth, "_with_client", side_effect=fake_with_client):
|
||||
with pytest.raises(web.HTTPUnauthorized):
|
||||
await auth.authenticate(make_request(None))
|
||||
|
|
|
|||
0
tests/unit/test_iam/__init__.py
Normal file
0
tests/unit/test_iam/__init__.py
Normal file
44
tests/unit/test_iam/test_iam_rejects_anonymous.py
Normal file
44
tests/unit/test_iam/test_iam_rejects_anonymous.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Contract test: the full iam-svc MUST reject authenticate-anonymous.
|
||||
|
||||
This is a safety pin — if someone accidentally adds anonymous access
|
||||
to the production IAM handler, this test catches it.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.iam.service.iam import IamService
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
req = Mock()
|
||||
for k, v in kwargs.items():
|
||||
setattr(req, k, v)
|
||||
return req
|
||||
|
||||
|
||||
class TestIamRejectsAnonymous:
|
||||
|
||||
@pytest.fixture
|
||||
def handler(self):
|
||||
svc = object.__new__(IamService)
|
||||
svc.table_store = Mock(spec=[])
|
||||
svc.bootstrap_mode = "token"
|
||||
svc.bootstrap_token = "tok"
|
||||
svc._on_workspace_created = None
|
||||
svc._on_workspace_deleted = None
|
||||
svc._signing_key = None
|
||||
svc._signing_key_lock = asyncio.Lock()
|
||||
return svc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_anonymous_returns_auth_failed(self, handler):
|
||||
resp = await handler.handle(
|
||||
_make_request(operation="authenticate-anonymous")
|
||||
)
|
||||
assert resp.error is not None
|
||||
assert resp.error.type == "auth-failed"
|
||||
assert "anonymous" in resp.error.message.lower()
|
||||
138
tests/unit/test_iam/test_noauth_handler.py
Normal file
138
tests/unit/test_iam/test_noauth_handler.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
Tests for the no-auth IAM handler.
|
||||
|
||||
Verifies that NoAuthHandler returns the expected permissive responses
|
||||
and that the always-allow authorise path returns the correct shape.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.iam.noauth.handler import NoAuthHandler
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
req = Mock()
|
||||
for k, v in kwargs.items():
|
||||
setattr(req, k, v)
|
||||
return req
|
||||
|
||||
|
||||
class TestAuthenticateAnonymous:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_default_identity(self):
|
||||
h = NoAuthHandler(
|
||||
default_user_id="anon", default_workspace="ws",
|
||||
)
|
||||
resp = await h.handle(
|
||||
_make_request(operation="authenticate-anonymous")
|
||||
)
|
||||
assert resp.error is None
|
||||
assert resp.resolved_user_id == "anon"
|
||||
assert resp.resolved_workspace == "ws"
|
||||
assert "admin" in list(resp.resolved_roles)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_defaults_propagate(self):
|
||||
h = NoAuthHandler(
|
||||
default_user_id="dev-user", default_workspace="dev-ws",
|
||||
)
|
||||
resp = await h.handle(
|
||||
_make_request(operation="authenticate-anonymous")
|
||||
)
|
||||
assert resp.resolved_user_id == "dev-user"
|
||||
assert resp.resolved_workspace == "dev-ws"
|
||||
|
||||
|
||||
class TestResolveApiKey:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_any_key_resolves_to_default_identity(self):
|
||||
h = NoAuthHandler()
|
||||
resp = await h.handle(
|
||||
_make_request(operation="resolve-api-key", api_key="tg_bogus")
|
||||
)
|
||||
assert resp.error is None
|
||||
assert resp.resolved_user_id == "anonymous"
|
||||
assert resp.resolved_workspace == "default"
|
||||
|
||||
|
||||
class TestAuthorise:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_allows(self):
|
||||
h = NoAuthHandler()
|
||||
resp = await h.handle(
|
||||
_make_request(
|
||||
operation="authorise",
|
||||
user_id="anyone",
|
||||
capability="anything",
|
||||
resource_json="{}",
|
||||
parameters_json="{}",
|
||||
)
|
||||
)
|
||||
assert resp.error is None
|
||||
assert resp.decision_allow is True
|
||||
assert resp.decision_ttl_seconds > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorise_many_returns_matching_count(self):
|
||||
h = NoAuthHandler()
|
||||
checks = [
|
||||
{"capability": "a", "resource": {}, "parameters": {}},
|
||||
{"capability": "b", "resource": {}, "parameters": {}},
|
||||
{"capability": "c", "resource": {}, "parameters": {}},
|
||||
]
|
||||
resp = await h.handle(
|
||||
_make_request(
|
||||
operation="authorise-many",
|
||||
user_id="u",
|
||||
authorise_checks=json.dumps(checks),
|
||||
)
|
||||
)
|
||||
assert resp.error is None
|
||||
decisions = json.loads(resp.decisions_json)
|
||||
assert len(decisions) == 3
|
||||
assert all(d["allow"] is True for d in decisions)
|
||||
|
||||
|
||||
class TestCreateWorkspaceCallback:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workspace_calls_callback(self):
|
||||
called_with = []
|
||||
|
||||
async def on_created(ws_id):
|
||||
called_with.append(ws_id)
|
||||
|
||||
h = NoAuthHandler(on_workspace_created=on_created)
|
||||
req = _make_request(operation="create-workspace")
|
||||
req.workspace_record = Mock()
|
||||
req.workspace_record.id = "test-ws"
|
||||
resp = await h.handle(req)
|
||||
assert resp.error is None
|
||||
assert called_with == ["test-ws"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workspace_without_callback_still_succeeds(self):
|
||||
h = NoAuthHandler()
|
||||
req = _make_request(operation="create-workspace")
|
||||
req.workspace_record = Mock()
|
||||
req.workspace_record.id = "test-ws"
|
||||
resp = await h.handle(req)
|
||||
assert resp.error is None
|
||||
|
||||
|
||||
class TestUnknownOperation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_op_returns_error(self):
|
||||
h = NoAuthHandler()
|
||||
resp = await h.handle(
|
||||
_make_request(operation="not-a-real-op")
|
||||
)
|
||||
assert resp.error is not None
|
||||
assert resp.error.type == "invalid-argument"
|
||||
|
|
@ -7,7 +7,7 @@ including template rendering, term merging, JSON validation, and error handling.
|
|||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from trustgraph.template.prompt_manager import PromptManager, PromptConfiguration, Prompt
|
||||
|
||||
|
|
@ -344,6 +344,42 @@ class TestPromptManager:
|
|||
assert pm.terms == {} # Default empty terms
|
||||
assert len(pm.prompts) == 0
|
||||
|
||||
def test_load_config_does_not_swallow_keyboard_interrupt(self, monkeypatch):
|
||||
"""KeyboardInterrupt should propagate out of config parsing."""
|
||||
pm = PromptManager()
|
||||
|
||||
def interrupt(_value):
|
||||
raise KeyboardInterrupt
|
||||
|
||||
monkeypatch.setattr("trustgraph.template.prompt_manager.json.loads", interrupt)
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
pm.load_config({"system": json.dumps("Test")})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_parse_does_not_swallow_system_exit(self):
|
||||
"""SystemExit should propagate out of JSON response parsing."""
|
||||
pm = PromptManager()
|
||||
config = {
|
||||
"system": json.dumps("Test"),
|
||||
"template-index": json.dumps(["json_response"]),
|
||||
"template.json_response": json.dumps({
|
||||
"prompt": "Generate JSON",
|
||||
"response-type": "json"
|
||||
})
|
||||
}
|
||||
pm.load_config(config)
|
||||
|
||||
def exit_parse(_text):
|
||||
raise SystemExit(2)
|
||||
|
||||
pm.parse_json = exit_parse
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = "{}"
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
await pm.invoke("json_response", {}, mock_llm)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPromptManagerJsonl:
|
||||
|
|
@ -585,4 +621,4 @@ not json at all
|
|||
|
||||
assert len(result) == 2
|
||||
assert result[0] == {"any": "structure"}
|
||||
assert result[1] == {"completely": "different"}
|
||||
assert result[1] == {"completely": "different"}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from unittest.mock import Mock, patch, MagicMock, call
|
||||
import json
|
||||
|
||||
from trustgraph.api.socket_client import SocketClient
|
||||
from trustgraph.api import (
|
||||
Api,
|
||||
Triple,
|
||||
|
|
@ -222,6 +223,82 @@ class TestSocketClient:
|
|||
for method in expected_methods:
|
||||
assert hasattr(flow_instance, method), f"Missing method: {method}"
|
||||
|
||||
def test_socket_client_close_does_not_swallow_base_exceptions(self):
|
||||
"""Test close cleanup does not suppress process-level interrupts."""
|
||||
|
||||
class InterruptingLoop:
|
||||
def is_closed(self):
|
||||
return False
|
||||
|
||||
def run_until_complete(self, awaitable):
|
||||
if hasattr(awaitable, "close"):
|
||||
awaitable.close()
|
||||
raise SystemExit("stop")
|
||||
|
||||
socket = SocketClient(url="http://test/", timeout=60, token=None)
|
||||
socket._loop = InterruptingLoop()
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
socket.close()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("generator_method", "async_method"),
|
||||
[
|
||||
("_streaming_generator", "_send_request_async_streaming"),
|
||||
("_streaming_generator_raw", "_send_request_async_streaming_raw"),
|
||||
],
|
||||
)
|
||||
def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions(
|
||||
self, generator_method, async_method
|
||||
):
|
||||
"""Test streaming cleanup does not suppress process-level interrupts."""
|
||||
|
||||
class FakeAsyncGenerator:
|
||||
def __anext__(self):
|
||||
return "next"
|
||||
|
||||
def aclose(self):
|
||||
return "close"
|
||||
|
||||
class InterruptingLoop:
|
||||
def run_until_complete(self, awaitable):
|
||||
if awaitable == "next":
|
||||
raise StopAsyncIteration
|
||||
if awaitable == "close":
|
||||
raise SystemExit("stop")
|
||||
raise AssertionError(f"unexpected awaitable: {awaitable!r}")
|
||||
|
||||
socket = SocketClient(url="http://test/", timeout=60, token=None)
|
||||
setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator())
|
||||
generator = getattr(socket, generator_method)(
|
||||
"agent", "default", {}, InterruptingLoop()
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
next(generator)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_client_reader_does_not_swallow_base_exceptions(self):
|
||||
"""Test reader error fanout does not suppress process-level interrupts."""
|
||||
|
||||
class FailingSocket:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
raise ValueError("reader failed")
|
||||
|
||||
class InterruptingQueue:
|
||||
async def put(self, message):
|
||||
raise SystemExit("stop")
|
||||
|
||||
socket = SocketClient(url="http://test/", timeout=60, token=None)
|
||||
socket._socket = FailingSocket()
|
||||
socket._pending = {"req-1": InterruptingQueue()}
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
await socket._reader()
|
||||
|
||||
|
||||
class TestBulkClient:
|
||||
"""Test bulk operations client"""
|
||||
|
|
|
|||
56
tests/unit/test_query/test_ontology_monitoring.py
Normal file
56
tests/unit/test_query/test_ontology_monitoring.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
Tests for ontology monitoring metrics.
|
||||
"""
|
||||
|
||||
from trustgraph.query.ontology.monitoring import (
|
||||
PerformanceMonitor,
|
||||
_extract_metric_label,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_metric_label_reads_unquoted_label_value():
|
||||
metric_name = "cache_requests_total{cache_type=entity,component=ontology}"
|
||||
|
||||
assert _extract_metric_label(metric_name, "cache_type") == "entity"
|
||||
|
||||
|
||||
def test_extract_metric_label_reads_quoted_label_value():
|
||||
metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}'
|
||||
|
||||
assert _extract_metric_label(metric_name, "cache_type") == "entity"
|
||||
|
||||
|
||||
def test_extract_metric_label_returns_none_when_label_missing():
|
||||
metric_name = "cache_requests_total{component=ontology}"
|
||||
|
||||
assert _extract_metric_label(metric_name, "cache_type") is None
|
||||
|
||||
|
||||
def test_performance_report_ignores_counters_without_cache_type_label():
|
||||
monitor = PerformanceMonitor({"enabled": False})
|
||||
monitor.metrics_collector.increment(
|
||||
"cache_requests_total",
|
||||
labels={"component": "ontology"},
|
||||
)
|
||||
monitor.metrics_collector.increment(
|
||||
"cache_type=not_a_label",
|
||||
labels={"component": "ontology"},
|
||||
)
|
||||
monitor.metrics_collector.increment(
|
||||
"cache_requests_total",
|
||||
labels={"cache_type": "entity"},
|
||||
)
|
||||
monitor.metrics_collector.increment(
|
||||
"cache_hits_total",
|
||||
labels={"cache_type": "entity"},
|
||||
)
|
||||
|
||||
report = monitor.get_performance_report()
|
||||
|
||||
assert report["cache_performance"] == {
|
||||
"entity": {
|
||||
"hit_rate": 1.0,
|
||||
"total_requests": 1.0,
|
||||
"total_hits": 1.0,
|
||||
}
|
||||
}
|
||||
|
|
@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
|
|||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configuration"""
|
||||
import asyncio
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
|
|||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.connect_cassandra = AsyncMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
|
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
|
|||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.connect_cassandra = AsyncMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
|
|
|||
580
tests/unit/test_query/test_sparql_algebra.py
Normal file
580
tests/unit/test_query/test_sparql_algebra.py
Normal file
|
|
@ -0,0 +1,580 @@
|
|||
"""
|
||||
Tests for the SPARQL algebra evaluator.
|
||||
|
||||
Verifies that evaluate() and _query_pattern() call TriplesClient.query()
|
||||
with the correct arguments, and in particular that workspace is never
|
||||
passed — workspace isolation is handled by pub/sub topic routing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
from rdflib.term import Variable, URIRef, Literal
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.query.sparql.algebra import (
|
||||
evaluate, materialise, _query_pattern, _eval_bgp,
|
||||
)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def iri(v):
|
||||
return Term(type=IRI, iri=v)
|
||||
|
||||
|
||||
def lit(v):
|
||||
return Term(type=LITERAL, value=v)
|
||||
|
||||
|
||||
def make_tc(query_return=None, query_side_effect=None):
|
||||
"""Create a mock TriplesClient with both query() and query_gen() support."""
|
||||
tc = AsyncMock()
|
||||
|
||||
if query_side_effect is not None:
|
||||
tc.query.side_effect = query_side_effect
|
||||
|
||||
async def gen_side_effect(**kwargs):
|
||||
results = await query_side_effect(**kwargs)
|
||||
for r in results:
|
||||
yield r
|
||||
|
||||
tc.query_gen = gen_side_effect
|
||||
else:
|
||||
items = query_return or []
|
||||
tc.query.return_value = items
|
||||
|
||||
async def gen(**kwargs):
|
||||
for item in items:
|
||||
yield item
|
||||
|
||||
tc.query_gen = gen
|
||||
|
||||
return tc
|
||||
|
||||
|
||||
def make_triple(s, p, o):
|
||||
t = MagicMock()
|
||||
t.s = s
|
||||
t.p = p
|
||||
t.o = o
|
||||
return t
|
||||
|
||||
|
||||
def make_bgp(*patterns):
|
||||
"""Build a CompValue BGP node from (s, p, o) tuples of rdflib terms."""
|
||||
node = CompValue("BGP")
|
||||
node.triples = list(patterns)
|
||||
return node
|
||||
|
||||
|
||||
def make_project(inner, variables):
|
||||
node = CompValue("Project")
|
||||
node.p = inner
|
||||
node.PV = [Variable(v) for v in variables]
|
||||
return node
|
||||
|
||||
|
||||
def make_select(inner):
|
||||
node = CompValue("SelectQuery")
|
||||
node.p = inner
|
||||
return node
|
||||
|
||||
|
||||
def make_join(left, right):
|
||||
node = CompValue("Join")
|
||||
node.p1 = left
|
||||
node.p2 = right
|
||||
return node
|
||||
|
||||
|
||||
def make_union(left, right):
|
||||
node = CompValue("Union")
|
||||
node.p1 = left
|
||||
node.p2 = right
|
||||
return node
|
||||
|
||||
|
||||
def make_slice(inner, start, length):
|
||||
node = CompValue("Slice")
|
||||
node.p = inner
|
||||
node.start = start
|
||||
node.length = length
|
||||
return node
|
||||
|
||||
|
||||
def make_distinct(inner):
|
||||
node = CompValue("Distinct")
|
||||
node.p = inner
|
||||
return node
|
||||
|
||||
|
||||
def make_filter(inner, expr):
|
||||
node = CompValue("Filter")
|
||||
node.p = inner
|
||||
node.expr = expr
|
||||
return node
|
||||
|
||||
|
||||
def make_minus(left, right):
|
||||
node = CompValue("Minus")
|
||||
node.p1 = left
|
||||
node.p2 = right
|
||||
return node
|
||||
|
||||
|
||||
class TestQueryPattern:
|
||||
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_correct_args(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = []
|
||||
|
||||
await _query_pattern(
|
||||
tc,
|
||||
s=iri("http://example.com/s"),
|
||||
p=iri("http://example.com/p"),
|
||||
o=None,
|
||||
collection="my-collection",
|
||||
limit=100,
|
||||
)
|
||||
|
||||
tc.query.assert_called_once_with(
|
||||
s=iri("http://example.com/s"),
|
||||
p=iri("http://example.com/p"),
|
||||
o=None,
|
||||
limit=100,
|
||||
collection="my-collection",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_not_passed(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = []
|
||||
|
||||
await _query_pattern(tc, None, None, None, "default", 10)
|
||||
|
||||
kwargs = tc.query.call_args.kwargs
|
||||
assert "workspace" not in kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_query_results(self):
|
||||
tc = AsyncMock()
|
||||
triple = make_triple(iri("http://a"), iri("http://b"), lit("c"))
|
||||
tc.query.return_value = [triple]
|
||||
|
||||
results = await _query_pattern(tc, None, None, None, "default", 10)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].s.iri == "http://a"
|
||||
|
||||
|
||||
class TestEvalBgp:
|
||||
"""Tests for BGP evaluation — triple pattern queries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_pattern_all_variables(self):
|
||||
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
||||
tc = make_tc(query_return=[triple])
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
|
||||
solutions = await materialise(bgp, tc, collection="default", limit=100)
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert solutions[0]["s"].iri == "http://s"
|
||||
assert solutions[0]["p"].iri == "http://p"
|
||||
assert solutions[0]["o"].value == "o"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_pattern_bound_subject(self):
|
||||
tc = make_tc(query_return=[
|
||||
make_triple(iri("http://s"), iri("http://p"), lit("val")),
|
||||
])
|
||||
|
||||
bgp = make_bgp(
|
||||
(URIRef("http://s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
|
||||
solutions = await materialise(bgp, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_bgp_returns_empty_solution(self):
|
||||
tc = make_tc()
|
||||
|
||||
bgp = make_bgp()
|
||||
|
||||
solutions = await materialise(bgp, tc, collection="default")
|
||||
|
||||
assert solutions == [{}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_results_returns_empty(self):
|
||||
tc = make_tc(query_return=[])
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
|
||||
solutions = await materialise(bgp, tc, collection="default")
|
||||
|
||||
assert solutions == []
|
||||
|
||||
|
||||
class TestEvaluate:
|
||||
"""Tests for the top-level evaluate() dispatcher."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_query_node(self):
|
||||
tc = make_tc(query_return=[
|
||||
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
||||
])
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
select = make_select(make_project(bgp, ["s", "p"]))
|
||||
|
||||
solutions = await materialise(select, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert "s" in solutions[0]
|
||||
assert "p" in solutions[0]
|
||||
assert "o" not in solutions[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_never_in_query_calls(self):
|
||||
"""Verify that no matter the algebra structure, workspace is never
|
||||
passed to TriplesClient.query()."""
|
||||
tc = make_tc(query_return=[
|
||||
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
||||
])
|
||||
|
||||
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
|
||||
tree = make_select(make_project(
|
||||
make_union(bgp1, bgp2), ["s", "p", "o"]
|
||||
))
|
||||
|
||||
await materialise(tree, tc, collection="test-coll")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join(self):
|
||||
call_count = 0
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return [make_triple(iri("http://a"), iri("http://p"), lit("v"))]
|
||||
else:
|
||||
return [make_triple(iri("http://a"), iri("http://q"), lit("w"))]
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
|
||||
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
|
||||
tree = make_join(bgp1, bgp2)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert solutions[0]["s"].iri == "http://a"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slice(self):
|
||||
triples = [
|
||||
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
|
||||
for i in range(5)
|
||||
]
|
||||
tc = make_tc(query_return=triples)
|
||||
|
||||
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||
tree = make_slice(bgp, start=1, length=2)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distinct(self):
|
||||
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
||||
tc = make_tc(query_return=[triple, triple])
|
||||
|
||||
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||
tree = make_distinct(bgp)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minus_removes_matching(self):
|
||||
alice = iri("http://example.com/alice")
|
||||
bob = iri("http://example.com/bob")
|
||||
knows = iri("http://example.com/knows")
|
||||
hates = iri("http://example.com/hates")
|
||||
charlie = iri("http://example.com/charlie")
|
||||
|
||||
left_triple = make_triple(alice, knows, bob)
|
||||
right_triple2 = make_triple(alice, hates, charlie)
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
pred = kwargs.get("p")
|
||||
if pred and pred.iri == "http://example.com/knows":
|
||||
return [left_triple]
|
||||
elif pred and pred.iri == "http://example.com/hates":
|
||||
return [right_triple2]
|
||||
return []
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
left_bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
|
||||
)
|
||||
right_bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/hates"), Variable("r"))
|
||||
)
|
||||
|
||||
tree = make_select(
|
||||
make_project(
|
||||
make_minus(left_bgp, right_bgp),
|
||||
["s", "o"]
|
||||
)
|
||||
)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minus_no_shared_vars_preserves_all(self):
|
||||
alice = iri("http://example.com/alice")
|
||||
bob = iri("http://example.com/bob")
|
||||
|
||||
left_triple = make_triple(alice, iri("http://example.com/p"), bob)
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
pred = kwargs.get("p")
|
||||
if pred and pred.iri == "http://example.com/p":
|
||||
return [left_triple]
|
||||
return []
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
left_bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/p"), Variable("o"))
|
||||
)
|
||||
right_bgp = make_bgp(
|
||||
(Variable("x"), URIRef("http://example.com/q"), Variable("y"))
|
||||
)
|
||||
|
||||
tree = make_select(
|
||||
make_project(
|
||||
make_minus(left_bgp, right_bgp),
|
||||
["s", "o"]
|
||||
)
|
||||
)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_exists_keeps_matching(self):
|
||||
alice = iri("http://example.com/alice")
|
||||
bob = iri("http://example.com/bob")
|
||||
charlie = iri("http://example.com/charlie")
|
||||
|
||||
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
|
||||
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
|
||||
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
pred = kwargs.get("p")
|
||||
if pred and pred.iri == "http://example.com/knows":
|
||||
return [left_triple1, left_triple2]
|
||||
elif pred and pred.iri == "http://example.com/likes":
|
||||
return [exists_triple]
|
||||
return []
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
left_bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
|
||||
)
|
||||
exists_bgp = make_bgp(
|
||||
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
|
||||
)
|
||||
|
||||
exists_expr = CompValue("Builtin_EXISTS")
|
||||
exists_expr.graph = exists_bgp
|
||||
|
||||
tree = make_select(
|
||||
make_project(
|
||||
make_filter(left_bgp, exists_expr),
|
||||
["s", "o"]
|
||||
)
|
||||
)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
result_objects = [s["o"].iri for s in solutions]
|
||||
assert "http://example.com/bob" in result_objects
|
||||
assert "http://example.com/charlie" not in result_objects
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_not_exists_removes_matching(self):
|
||||
alice = iri("http://example.com/alice")
|
||||
bob = iri("http://example.com/bob")
|
||||
charlie = iri("http://example.com/charlie")
|
||||
|
||||
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
|
||||
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
|
||||
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
pred = kwargs.get("p")
|
||||
if pred and pred.iri == "http://example.com/knows":
|
||||
return [left_triple1, left_triple2]
|
||||
elif pred and pred.iri == "http://example.com/likes":
|
||||
return [exists_triple]
|
||||
return []
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
left_bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
|
||||
)
|
||||
exists_bgp = make_bgp(
|
||||
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
|
||||
)
|
||||
|
||||
not_exists_expr = CompValue("Builtin_NOTEXISTS")
|
||||
not_exists_expr.graph = exists_bgp
|
||||
|
||||
tree = make_select(
|
||||
make_project(
|
||||
make_filter(left_bgp, not_exists_expr),
|
||||
["s", "o"]
|
||||
)
|
||||
)
|
||||
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
result_objects = [s["o"].iri for s in solutions]
|
||||
assert "http://example.com/charlie" in result_objects
|
||||
assert "http://example.com/bob" not in result_objects
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_values_uses_bind_join(self):
|
||||
"""When VALUES is joined with a BGP, the bind join should pass
|
||||
the VALUES bindings into the BGP evaluation so the triple store
|
||||
query is selective (not a wildcard)."""
|
||||
alice = iri("http://example.com/alice")
|
||||
bob = iri("http://example.com/bob")
|
||||
knows = iri("http://example.com/knows")
|
||||
|
||||
queries_issued = []
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
queries_issued.append(kwargs)
|
||||
s, p = kwargs.get("s"), kwargs.get("p")
|
||||
if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows":
|
||||
return [make_triple(alice, knows, bob)]
|
||||
return []
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
# VALUES ?s { <alice> }
|
||||
values_node = CompValue("values")
|
||||
values_node.var = [Variable("s")]
|
||||
values_node.value = [[URIRef("http://example.com/alice")]]
|
||||
values_node.res = None
|
||||
|
||||
to_multiset = CompValue("ToMultiSet")
|
||||
to_multiset.p = values_node
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
|
||||
)
|
||||
|
||||
tree = make_join(to_multiset, bgp)
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert solutions[0]["s"].iri == "http://example.com/alice"
|
||||
assert solutions[0]["o"].iri == "http://example.com/bob"
|
||||
|
||||
# The key assertion: the BGP query should have received
|
||||
# s=alice (bound from VALUES), NOT s=None (wildcard)
|
||||
assert len(queries_issued) == 1
|
||||
assert queries_issued[0]["s"] is not None
|
||||
assert queries_issued[0]["s"].iri == "http://example.com/alice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_values_multiple_bindings(self):
|
||||
"""Bind join with multiple VALUES bindings."""
|
||||
alice = iri("http://example.com/alice")
|
||||
bob = iri("http://example.com/bob")
|
||||
knows = iri("http://example.com/knows")
|
||||
charlie = iri("http://example.com/charlie")
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
s = kwargs.get("s")
|
||||
if s and s.iri == "http://example.com/alice":
|
||||
return [make_triple(alice, knows, bob)]
|
||||
elif s and s.iri == "http://example.com/bob":
|
||||
return [make_triple(bob, knows, charlie)]
|
||||
return []
|
||||
|
||||
tc = make_tc(query_side_effect=mock_query)
|
||||
|
||||
values_node = CompValue("values")
|
||||
values_node.var = [Variable("s")]
|
||||
values_node.value = [
|
||||
[URIRef("http://example.com/alice")],
|
||||
[URIRef("http://example.com/bob")],
|
||||
]
|
||||
values_node.res = None
|
||||
|
||||
to_multiset = CompValue("ToMultiSet")
|
||||
to_multiset.p = values_node
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
|
||||
)
|
||||
|
||||
tree = make_join(to_multiset, bgp)
|
||||
solutions = await materialise(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 2
|
||||
subjects = {s["s"].iri for s in solutions}
|
||||
assert subjects == {
|
||||
"http://example.com/alice",
|
||||
"http://example.com/bob",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_node_returns_empty_solution(self):
|
||||
tc = make_tc()
|
||||
|
||||
node = CompValue("SomethingUnknown")
|
||||
|
||||
solutions = await materialise(node, tc, collection="default")
|
||||
|
||||
assert solutions == [{}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_compvalue_returns_empty_solution(self):
|
||||
tc = make_tc()
|
||||
|
||||
solutions = await materialise("not a node", tc, collection="default")
|
||||
|
||||
assert solutions == [{}]
|
||||
|
|
@ -300,6 +300,438 @@ class TestBuiltinFunctions:
|
|||
flags=None)
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
def test_substr_three_args(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("SUBSTR",
|
||||
arg=Variable("x"),
|
||||
start=Literal(1),
|
||||
length=Literal(4))
|
||||
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "2024"
|
||||
|
||||
def test_substr_two_args(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("SUBSTR",
|
||||
arg=Variable("x"),
|
||||
start=Literal(6),
|
||||
length=None)
|
||||
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "03-15"
|
||||
|
||||
def test_substr_middle(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("SUBSTR",
|
||||
arg=Variable("x"),
|
||||
start=Literal(6),
|
||||
length=Literal(2))
|
||||
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "03"
|
||||
|
||||
def test_substr_null_start(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("SUBSTR",
|
||||
arg=Variable("x"),
|
||||
start=Variable("missing"),
|
||||
length=None)
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result is None
|
||||
|
||||
def test_year(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("YEAR", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||
)
|
||||
assert result == 2024
|
||||
|
||||
def test_month(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("MONTH", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||
)
|
||||
assert result == 3
|
||||
|
||||
def test_day(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("DAY", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||
)
|
||||
assert result == 15
|
||||
|
||||
def test_hours(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("HOURS", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||
)
|
||||
assert result == 10
|
||||
|
||||
def test_minutes(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("MINUTES", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||
)
|
||||
assert result == 30
|
||||
|
||||
def test_seconds(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("SECONDS", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||
)
|
||||
assert result == 45
|
||||
|
||||
def test_year_from_datetime(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("YEAR", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||
)
|
||||
assert result == 2024
|
||||
|
||||
def test_hours_from_date_returns_zero(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("HOURS", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||
)
|
||||
assert result == 0
|
||||
|
||||
def test_year_invalid_date(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("YEAR", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("not-a-date")}
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_floor(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("FLOOR", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("3.7")}) == 3
|
||||
|
||||
def test_floor_negative(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("FLOOR", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3
|
||||
|
||||
def test_floor_none(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("FLOOR", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("abc")}) is None
|
||||
|
||||
def test_ceil(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("CEIL", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("3.2")}) == 4
|
||||
|
||||
def test_ceil_negative(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("CEIL", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2
|
||||
|
||||
def test_abs_positive(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ABS", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("42")}) == 42
|
||||
|
||||
def test_abs_negative(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ABS", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("-42")}) == 42
|
||||
|
||||
def test_abs_none(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ABS", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("abc")}) is None
|
||||
|
||||
def test_replace_simple(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REPLACE",
|
||||
arg=Variable("x"),
|
||||
pattern=Literal(" BC"),
|
||||
replacement=Literal(""),
|
||||
flags=None)
|
||||
result = evaluate_expression(expr, {"x": lit("500 BC")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "500"
|
||||
|
||||
def test_replace_regex(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REPLACE",
|
||||
arg=Variable("x"),
|
||||
pattern=Literal("[0-9]+"),
|
||||
replacement=Literal("X"),
|
||||
flags=None)
|
||||
result = evaluate_expression(expr, {"x": lit("abc123def456")})
|
||||
assert result.value == "abcXdefX"
|
||||
|
||||
def test_replace_case_insensitive(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REPLACE",
|
||||
arg=Variable("x"),
|
||||
pattern=Literal("hello"),
|
||||
replacement=Literal("world"),
|
||||
flags=Literal("i"))
|
||||
result = evaluate_expression(expr, {"x": lit("HELLO there")})
|
||||
assert result.value == "world there"
|
||||
|
||||
def test_round_up(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ROUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("3.7")}) == 4
|
||||
|
||||
def test_round_down(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ROUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("3.2")}) == 3
|
||||
|
||||
def test_round_none(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ROUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("abc")}) is None
|
||||
|
||||
def test_strbefore(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRBEFORE",
|
||||
arg1=Variable("x"), arg2=Literal("-"))
|
||||
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||
assert result.value == "2024"
|
||||
|
||||
def test_strbefore_not_found(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRBEFORE",
|
||||
arg1=Variable("x"), arg2=Literal("/"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.value == ""
|
||||
|
||||
def test_strafter(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRAFTER",
|
||||
arg1=Variable("x"), arg2=Literal("-"))
|
||||
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||
assert result.value == "03-15"
|
||||
|
||||
def test_strafter_not_found(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRAFTER",
|
||||
arg1=Variable("x"), arg2=Literal("/"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.value == ""
|
||||
|
||||
def test_encode_for_uri(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello world")})
|
||||
assert result.value == "hello%20world"
|
||||
|
||||
def test_encode_for_uri_special_chars(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")})
|
||||
assert result.value == "a%2Fb%3Fc%3Dd%26e"
|
||||
|
||||
def test_langmatches_basic(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("LANGMATCHES",
|
||||
arg1=Literal("en"), arg2=Literal("en"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_langmatches_subtag(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("LANGMATCHES",
|
||||
arg1=Literal("en-US"), arg2=Literal("en"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_langmatches_wildcard(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("LANGMATCHES",
|
||||
arg1=Literal("fr"), arg2=Literal("*"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_langmatches_wildcard_empty(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("LANGMATCHES",
|
||||
arg1=Literal(""), arg2=Literal("*"))
|
||||
assert evaluate_expression(expr, {}) is False
|
||||
|
||||
def test_langmatches_no_match(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("LANGMATCHES",
|
||||
arg1=Literal("fr"), arg2=Literal("en"))
|
||||
assert evaluate_expression(expr, {}) is False
|
||||
|
||||
def test_iri_constructor(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("IRI", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("http://example.com/test")}
|
||||
)
|
||||
assert result.type == IRI
|
||||
assert result.iri == "http://example.com/test"
|
||||
|
||||
def test_uri_constructor(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("URI", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("http://example.com/test")}
|
||||
)
|
||||
assert result.type == IRI
|
||||
assert result.iri == "http://example.com/test"
|
||||
|
||||
def test_bnode_no_arg(self):
|
||||
expr = self._make_builtin("BNODE")
|
||||
result = evaluate_expression(expr, {})
|
||||
assert result.type == BLANK
|
||||
assert len(result.id) > 0
|
||||
|
||||
def test_bnode_with_label(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("BNODE", arg=Literal("mynode"))
|
||||
result = evaluate_expression(expr, {})
|
||||
assert result.type == BLANK
|
||||
assert result.id == "mynode"
|
||||
|
||||
def test_now(self):
|
||||
import re as re_mod
|
||||
expr = self._make_builtin("NOW")
|
||||
result = evaluate_expression(expr, {})
|
||||
assert result.type == LITERAL
|
||||
assert result.datatype == XSD + "dateTime"
|
||||
assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value)
|
||||
|
||||
def test_tz_with_utc(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("TZ", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15T10:30:45+0000",
|
||||
datatype=XSD + "dateTime")}
|
||||
)
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "+00:00"
|
||||
|
||||
def test_tz_no_timezone(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("TZ", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("2024-03-15T10:30:45",
|
||||
datatype=XSD + "dateTime")}
|
||||
)
|
||||
assert result.value == ""
|
||||
|
||||
def test_rand(self):
|
||||
expr = self._make_builtin("RAND")
|
||||
result = evaluate_expression(expr, {})
|
||||
assert isinstance(result, float)
|
||||
assert 0.0 <= result < 1.0
|
||||
|
||||
def test_uuid(self):
|
||||
import re as re_mod
|
||||
expr = self._make_builtin("UUID")
|
||||
result = evaluate_expression(expr, {})
|
||||
assert result.type == IRI
|
||||
assert result.iri.startswith("urn:uuid:")
|
||||
uuid_part = result.iri[len("urn:uuid:"):]
|
||||
assert re_mod.match(
|
||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
|
||||
uuid_part
|
||||
)
|
||||
|
||||
def test_struuid(self):
|
||||
import re as re_mod
|
||||
expr = self._make_builtin("STRUUID")
|
||||
result = evaluate_expression(expr, {})
|
||||
assert result.type == LITERAL
|
||||
assert re_mod.match(
|
||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
|
||||
result.value
|
||||
)
|
||||
|
||||
def test_md5(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("MD5", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "5d41402abc4b2a76b9719d911017c592"
|
||||
|
||||
def test_sha1(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("SHA1", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d"
|
||||
|
||||
def test_sha256(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("SHA256", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == (
|
||||
"2cf24dba5fb0a30e26e83b2ac5b9e29e"
|
||||
"1b161e5c1fa7425e73043362938b9824"
|
||||
)
|
||||
|
||||
def test_sha512(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("SHA512", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.type == LITERAL
|
||||
assert len(result.value) == 128
|
||||
|
||||
def test_exists_with_callback(self):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
graph = CompValue("BGP")
|
||||
expr = self._make_builtin("EXISTS", graph=graph)
|
||||
cb = lambda g, s: True
|
||||
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||
assert result is True
|
||||
|
||||
def test_exists_callback_false(self):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
graph = CompValue("BGP")
|
||||
expr = self._make_builtin("EXISTS", graph=graph)
|
||||
cb = lambda g, s: False
|
||||
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||
assert result is False
|
||||
|
||||
def test_notexists_with_callback(self):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
graph = CompValue("BGP")
|
||||
expr = self._make_builtin("NOTEXISTS", graph=graph)
|
||||
cb = lambda g, s: True
|
||||
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||
assert result is False
|
||||
|
||||
def test_notexists_callback_false(self):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
graph = CompValue("BGP")
|
||||
expr = self._make_builtin("NOTEXISTS", graph=graph)
|
||||
cb = lambda g, s: False
|
||||
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestEffectiveBoolean:
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations.
|
|||
import pytest
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.query.sparql.solutions import (
|
||||
hash_join, left_join, union, project, distinct,
|
||||
hash_join, left_join, minus, union, project, distinct,
|
||||
order_by, slice_solutions, _terms_equal, _compatible,
|
||||
)
|
||||
|
||||
|
|
@ -311,6 +311,30 @@ class TestOrderBy:
|
|||
result = order_by(solutions, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_order_by_numeric_literals(self):
|
||||
solutions = [
|
||||
{"year": lit("1950")},
|
||||
{"year": lit("700")},
|
||||
{"year": lit("2000")},
|
||||
{"year": lit("450")},
|
||||
{"year": lit("1200")},
|
||||
]
|
||||
key_fns = [(lambda sol: sol.get("year"), True)]
|
||||
result = order_by(solutions, key_fns)
|
||||
values = [s["year"].value for s in result]
|
||||
assert values == ["450", "700", "1200", "1950", "2000"]
|
||||
|
||||
def test_order_by_numeric_descending(self):
|
||||
solutions = [
|
||||
{"year": lit("1950")},
|
||||
{"year": lit("700")},
|
||||
{"year": lit("2000")},
|
||||
]
|
||||
key_fns = [(lambda sol: sol.get("year"), False)]
|
||||
result = order_by(solutions, key_fns)
|
||||
values = [s["year"].value for s in result]
|
||||
assert values == ["2000", "1950", "700"]
|
||||
|
||||
|
||||
class TestSlice:
|
||||
|
||||
|
|
@ -343,3 +367,37 @@ class TestSlice:
|
|||
solutions = [{"s": alice}, {"s": bob}]
|
||||
result = slice_solutions(solutions)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestMinus:
|
||||
|
||||
def test_removes_compatible(self, alice, bob):
|
||||
left = [{"s": alice}, {"s": bob}]
|
||||
right = [{"s": alice}]
|
||||
result = minus(left, right)
|
||||
assert len(result) == 1
|
||||
assert result[0]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_empty_right_preserves_all(self, alice, bob):
|
||||
left = [{"s": alice}, {"s": bob}]
|
||||
result = minus(left, [])
|
||||
assert len(result) == 2
|
||||
|
||||
def test_no_shared_variables_preserves_all(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"t": bob}]
|
||||
result = minus(left, right)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_all_removed(self, alice):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": alice}]
|
||||
result = minus(left, right)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_partial_shared_variables(self, alice, bob):
|
||||
left = [{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}]
|
||||
right = [{"s": alice}]
|
||||
result = minus(left, right)
|
||||
assert len(result) == 1
|
||||
assert result[0]["s"].iri == "http://example.com/bob"
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
|
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
|
|||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
cassandra_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_term_with_http_uri(self, processor):
|
||||
|
|
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
|
|
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
|
|||
keyspace='test_user'
|
||||
)
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
# Verify async_get_spo was called with correct parameters
|
||||
mock_tg_instance.async_get_spo.assert_called_once_with(
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
||||
)
|
||||
|
||||
|
|
@ -130,23 +132,25 @@ class TestCassandraQueryProcessor:
|
|||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||
assert processor.cassandra_username is None
|
||||
assert processor.cassandra_password is None
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
cassandra_host='cassandra.example.com',
|
||||
cassandra_username='queryuser',
|
||||
cassandra_password='querypass'
|
||||
)
|
||||
|
||||
|
||||
assert processor.cassandra_host == ['cassandra.example.com']
|
||||
assert processor.cassandra_username == 'queryuser'
|
||||
assert processor.cassandra_password == 'querypass'
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
|
|
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
|
|
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
|
|
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
|
|
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_result.g = ''
|
||||
mock_result.d = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'all_subject'
|
||||
assert result[0].p.iri == 'all_predicate'
|
||||
|
|
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
|
|
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.lang = None
|
||||
mock_result.p = 'p'
|
||||
mock_result.o = 'o'
|
||||
mock_tg_instance1.get_s.return_value = [mock_result]
|
||||
mock_tg_instance2.get_s.return_value = [mock_result]
|
||||
mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
await processor.query_triples('user1', query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
|
|
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
await processor.query_triples('user2', query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
# Verify TrustGraph was created twice for different workspaces
|
||||
assert mock_kg_class.call_count == 2
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result2.otype = None
|
||||
mock_result2.dtype = None
|
||||
mock_result2.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_po.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
# Verify async_get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
||||
)
|
||||
|
||||
|
|
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_os.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
# Verify async_get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.async_get_os.assert_called_once_with(
|
||||
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
||||
)
|
||||
|
||||
|
|
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock empty results for all queries
|
||||
mock_tg_instance.get_all.return_value = []
|
||||
mock_tg_instance.get_s.return_value = []
|
||||
mock_tg_instance.get_p.return_value = []
|
||||
mock_tg_instance.get_o.return_value = []
|
||||
mock_tg_instance.get_sp.return_value = []
|
||||
mock_tg_instance.get_po.return_value = []
|
||||
mock_tg_instance.get_os.return_value = []
|
||||
mock_tg_instance.get_spo.return_value = []
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_s = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_p = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_o = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_os = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Test each query pattern
|
||||
test_patterns = [
|
||||
# (s, p, o, expected_method)
|
||||
(None, None, None, 'get_all'), # All triples
|
||||
('s1', None, None, 'get_s'), # Subject only
|
||||
(None, 'p1', None, 'get_p'), # Predicate only
|
||||
(None, None, 'o1', 'get_o'), # Object only
|
||||
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
|
||||
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||
('s1', None, 'o1', 'get_os'), # Object + Subject
|
||||
('s1', 'p1', 'o1', 'get_spo'), # All three
|
||||
(None, None, None, 'async_get_all'), # All triples
|
||||
('s1', None, None, 'async_get_s'), # Subject only
|
||||
(None, 'p1', None, 'async_get_p'), # Predicate only
|
||||
(None, None, 'o1', 'async_get_o'), # Object only
|
||||
('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
|
||||
(None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||
('s1', None, 'o1', 'async_get_os'), # Object + Subject
|
||||
('s1', 'p1', 'o1', 'async_get_spo'), # All three
|
||||
]
|
||||
|
||||
for s, p, o, expected_method in test_patterns:
|
||||
|
|
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.lang = None
|
||||
mock_results.append(mock_result)
|
||||
|
||||
mock_tg_instance.get_po.return_value = mock_results
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('large_dataset_user', query)
|
||||
|
||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
# Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||
'massive_collection',
|
||||
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
||||
'http://example.com/Person',
|
||||
|
|
|
|||
|
|
@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_embedding_upserted(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "col1"
|
||||
|
|
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
|
|||
@pytest.mark.asyncio
|
||||
async def test_dimension_in_collection_name(self):
|
||||
"""Collection name should include vector dimension."""
|
||||
import asyncio
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "docs"
|
||||
|
|
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_entity_and_vector_upserted(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "col1"
|
||||
|
|
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lazy_collection_creation_on_new_dimension(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = False
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "graphs"
|
||||
|
|
|
|||
|
|
@ -337,6 +337,57 @@ class TestQuery:
|
|||
cache_key = "test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triples_query_never_passes_workspace(self):
|
||||
"""Workspace isolation is handled by pub/sub topic routing, not
|
||||
by passing workspace to TriplesClient.query(). Verify that
|
||||
GraphRAG never passes workspace as a keyword argument."""
|
||||
mock_rag = MagicMock()
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
await query.maybe_label("http://example.com/entity")
|
||||
|
||||
for c in mock_triples_client.query.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_never_passes_workspace(self):
|
||||
"""Verify follow_edges never passes workspace to query_stream."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("e1", subgraph, path_length=1)
|
||||
|
||||
for c in mock_triples_client.query_stream.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
|
|
|
|||
0
tests/unit/test_rev_gateway/__init__.py
Normal file
0
tests/unit/test_rev_gateway/__init__.py
Normal file
|
|
@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, ANY
|
||||
|
||||
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
|
||||
|
||||
|
||||
class TestWebSocketResponder:
|
||||
"""Test cases for WebSocketResponder class"""
|
||||
|
||||
def test_websocket_responder_initialization(self):
|
||||
"""Test WebSocketResponder initialization"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
assert responder.response is None
|
||||
assert responder.completed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_send_method(self):
|
||||
"""Test WebSocketResponder send method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"data": "test response"}
|
||||
|
||||
# Call send method
|
||||
await responder.send(test_response)
|
||||
|
||||
# Verify response was stored
|
||||
assert responder.response == test_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method(self):
|
||||
"""Test WebSocketResponder __call__ method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"result": "success"}
|
||||
test_completed = True
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response and completed status were set
|
||||
assert responder.response == test_response
|
||||
assert responder.completed == test_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method_with_false_completion(self):
|
||||
"""Test WebSocketResponder __call__ method with incomplete response"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"partial": "data"}
|
||||
test_completed = False
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response was set and completed is True (since send() always sets completed=True)
|
||||
assert responder.response == test_response
|
||||
assert responder.completed is True
|
||||
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
|
||||
|
||||
|
||||
class TestMessageDispatcher:
|
||||
"""Test cases for MessageDispatcher class"""
|
||||
|
||||
def test_message_dispatcher_initialization_with_defaults(self):
|
||||
"""Test MessageDispatcher initialization with default parameters"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
|
||||
assert dispatcher.max_workers == 10
|
||||
assert dispatcher.semaphore._value == 10
|
||||
assert dispatcher.active_tasks == set()
|
||||
assert dispatcher.backend is None
|
||||
assert dispatcher.auth is None
|
||||
assert dispatcher.dispatcher_manager is None
|
||||
assert len(dispatcher.service_mapping) > 0
|
||||
|
||||
def test_message_dispatcher_initialization_with_custom_workers(self):
|
||||
"""Test MessageDispatcher initialization with custom max_workers"""
|
||||
dispatcher = MessageDispatcher(max_workers=5)
|
||||
|
||||
|
||||
assert dispatcher.max_workers == 5
|
||||
assert dispatcher.semaphore._value == 5
|
||||
|
||||
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
|
||||
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
|
||||
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
|
||||
def test_message_dispatcher_initialization_with_backend(
|
||||
self, mock_dispatcher_manager,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_config_receiver = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher_instance = MagicMock()
|
||||
mock_dispatcher_manager.return_value = mock_dispatcher_instance
|
||||
|
||||
|
||||
dispatcher = MessageDispatcher(
|
||||
max_workers=8,
|
||||
config_receiver=mock_config_receiver,
|
||||
backend=mock_backend
|
||||
backend=mock_backend,
|
||||
auth=mock_auth,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
|
||||
assert dispatcher.max_workers == 8
|
||||
assert dispatcher.backend == mock_backend
|
||||
assert dispatcher.auth == mock_auth
|
||||
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
|
||||
mock_dispatcher_manager.assert_called_once_with(
|
||||
mock_backend, mock_config_receiver, prefix="rev-gateway"
|
||||
mock_backend, mock_config_receiver,
|
||||
auth=mock_auth, prefix="rev-gateway", timeout=300,
|
||||
)
|
||||
|
||||
def test_message_dispatcher_service_mapping(self):
|
||||
"""Test MessageDispatcher service mapping contains expected services"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
|
||||
expected_services = [
|
||||
"text-completion", "graph-rag", "agent", "embeddings",
|
||||
"graph-embeddings", "triples", "document-load", "text-load",
|
||||
"flow", "knowledge", "config", "librarian", "document-rag"
|
||||
"flow", "knowledge", "config", "librarian", "document-rag",
|
||||
]
|
||||
|
||||
|
||||
for service in expected_services:
|
||||
assert service in dispatcher.service_mapping
|
||||
|
||||
# Test specific mappings
|
||||
assert dispatcher.service_mapping["text-completion"] == "text-completion"
|
||||
|
||||
assert dispatcher.service_mapping["document-load"] == "document"
|
||||
assert dispatcher.service_mapping["text-load"] == "text-document"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
|
||||
"""Test MessageDispatcher handle_message without dispatcher manager"""
|
||||
async def test_handle_message_without_dispatcher_manager(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
test_message = {
|
||||
"id": "test-123",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
}
|
||||
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-123"
|
||||
assert "error" in result["response"]
|
||||
assert "DispatcherManager not available" in result["response"]["error"]
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="default")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test-1", "service": "test", "request": {}},
|
||||
sender,
|
||||
)
|
||||
|
||||
sender.assert_called_once()
|
||||
sent = sender.call_args[0][0]
|
||||
assert sent["id"] == "test-1"
|
||||
assert sent["error"]["message"] == "DispatcherManager not available"
|
||||
assert sent["error"]["type"] == "error"
|
||||
assert sent["complete"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_with_exception(self):
|
||||
"""Test MessageDispatcher handle_message with exception during processing"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
async def test_handle_message_auth_failure(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-456",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "test"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-456"
|
||||
assert "error" in result["response"]
|
||||
assert "Test error" in result["response"]["error"]
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
side_effect=Exception("auth failure")
|
||||
)
|
||||
dispatcher.dispatcher_manager = MagicMock()
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
|
||||
sender,
|
||||
)
|
||||
|
||||
sender.assert_called_once()
|
||||
sent = sender.call_args[0][0]
|
||||
assert sent["id"] == "test-2"
|
||||
assert "auth failure" in sent["error"]["message"]
|
||||
assert sent["complete"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_global_service(self):
|
||||
"""Test MessageDispatcher handle_message with global service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"result": "success"}
|
||||
|
||||
async def test_handle_message_global_service(self):
|
||||
mock_dm = MagicMock()
|
||||
mock_dm.invoke_global_service = AsyncMock()
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-789",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hello"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-789"
|
||||
assert result["response"] == {"result": "success"}
|
||||
mock_dispatcher_manager.invoke_global_service.assert_called_once()
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="ws1")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers',
|
||||
{"text-completion": True},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-3",
|
||||
"token": "tg_key",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hello"},
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
mock_dm.invoke_global_service.assert_called_once()
|
||||
args, kwargs = mock_dm.invoke_global_service.call_args
|
||||
assert args[0] == {"prompt": "hello"}
|
||||
assert args[2] == "text-completion"
|
||||
assert kwargs["workspace"] == "ws1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_flow_service(self):
|
||||
"""Test MessageDispatcher handle_message with flow service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"data": "flow_result"}
|
||||
|
||||
async def test_handle_message_flow_service(self):
|
||||
mock_dm = MagicMock()
|
||||
mock_dm.invoke_flow_service = AsyncMock()
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-flow-123",
|
||||
"service": "document-rag",
|
||||
"request": {"query": "test"},
|
||||
"flow": "custom-flow"
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-flow-123"
|
||||
assert result["response"] == {"data": "flow_result"}
|
||||
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
|
||||
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="ws2")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-4",
|
||||
"token": "tg_key",
|
||||
"service": "document-rag",
|
||||
"request": {"query": "test"},
|
||||
"flow": "my-flow",
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
mock_dm.invoke_flow_service.assert_called_once_with(
|
||||
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_incomplete_response(self):
|
||||
"""Test MessageDispatcher handle_message with incomplete response"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = False
|
||||
mock_responder.response = None
|
||||
|
||||
async def test_handle_message_responder_sends_frames(self):
|
||||
mock_dm = MagicMock()
|
||||
|
||||
async def fake_invoke(data, responder, svc, workspace=None):
|
||||
await responder({"partial": 1}, False)
|
||||
await responder({"partial": 2}, True)
|
||||
|
||||
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-incomplete",
|
||||
"service": "agent",
|
||||
"request": {"input": "test"}
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="ws1")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers',
|
||||
{"text-completion": True},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-5",
|
||||
"token": "tg_key",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hi"},
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
assert sender.call_count == 2
|
||||
first = sender.call_args_list[0][0][0]
|
||||
second = sender.call_args_list[1][0][0]
|
||||
|
||||
assert first == {
|
||||
"id": "test-5", "response": {"partial": 1}, "complete": False,
|
||||
}
|
||||
assert second == {
|
||||
"id": "test-5", "response": {"partial": 2}, "complete": True,
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-incomplete"
|
||||
assert result["response"] == {"error": "No response received"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown(self):
|
||||
"""Test MessageDispatcher shutdown method"""
|
||||
import asyncio
|
||||
|
||||
async def test_handle_message_workspace_from_identity(self):
|
||||
mock_dm = MagicMock()
|
||||
mock_dm.invoke_flow_service = AsyncMock()
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Create actual async tasks
|
||||
dispatcher.dispatcher_manager = mock_dm
|
||||
dispatcher.auth = MagicMock()
|
||||
dispatcher.auth.authenticate = AsyncMock(
|
||||
return_value=MagicMock(workspace="derived-ws")
|
||||
)
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
|
||||
):
|
||||
await dispatcher.handle_message(
|
||||
{
|
||||
"id": "test-6",
|
||||
"token": "tg_key",
|
||||
"service": "agent",
|
||||
"request": {"question": "test"},
|
||||
"flow": "default",
|
||||
},
|
||||
sender,
|
||||
)
|
||||
|
||||
args = mock_dm.invoke_flow_service.call_args[0]
|
||||
assert args[2] == "derived-ws"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
async def dummy_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "done"
|
||||
|
||||
|
||||
task1 = asyncio.create_task(dummy_task())
|
||||
task2 = asyncio.create_task(dummy_task())
|
||||
dispatcher.active_tasks = {task1, task2}
|
||||
|
||||
# Call shutdown
|
||||
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Verify tasks were completed
|
||||
|
||||
assert task1.done()
|
||||
assert task2.done()
|
||||
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown_with_no_tasks(self):
|
||||
"""Test MessageDispatcher shutdown with no active tasks"""
|
||||
async def test_shutdown_with_no_tasks(self):
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Call shutdown with no active tasks
|
||||
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Should complete without error
|
||||
assert dispatcher.active_tasks == set()
|
||||
|
||||
assert dispatcher.active_tasks == set()
|
||||
|
|
|
|||
|
|
@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
|
|||
from aiohttp import WSMsgType, ClientWebSocketResponse
|
||||
import json
|
||||
|
||||
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
|
||||
from trustgraph.rev_gateway.service import ReverseGateway, run
|
||||
|
||||
|
||||
MOCK_PATCHES = [
|
||||
'trustgraph.rev_gateway.service.IamAuth',
|
||||
'trustgraph.rev_gateway.service.ConfigReceiver',
|
||||
'trustgraph.rev_gateway.service.MessageDispatcher',
|
||||
'trustgraph.rev_gateway.service.get_pubsub',
|
||||
]
|
||||
|
||||
|
||||
def make_gateway(**overrides):
|
||||
config = {"websocket_uri": "ws://localhost:7650/out"}
|
||||
config.update(overrides)
|
||||
return ReverseGateway(**config)
|
||||
|
||||
|
||||
class TestReverseGateway:
|
||||
"""Test cases for ReverseGateway class"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with default parameters"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_defaults(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
assert gateway.websocket_uri == "ws://localhost:7650/out"
|
||||
assert gateway.host == "localhost"
|
||||
assert gateway.port == 7650
|
||||
|
|
@ -33,25 +49,22 @@ class TestReverseGateway:
|
|||
assert gateway.max_workers == 10
|
||||
assert gateway.running is False
|
||||
assert gateway.reconnect_delay == 3.0
|
||||
assert gateway.pulsar_host == "pulsar://pulsar:6650"
|
||||
assert gateway.pulsar_api_key is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with custom parameters"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway(
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_custom_params(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway(
|
||||
websocket_uri="wss://example.com:8080/websocket",
|
||||
max_workers=20,
|
||||
pulsar_host="pulsar://custom:6650",
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
)
|
||||
|
||||
|
||||
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
|
||||
assert gateway.host == "example.com"
|
||||
assert gateway.port == 8080
|
||||
|
|
@ -59,340 +72,360 @@ class TestReverseGateway:
|
|||
assert gateway.path == "/websocket"
|
||||
assert gateway.url == "wss://example.com:8080/websocket"
|
||||
assert gateway.max_workers == 20
|
||||
assert gateway.pulsar_host == "pulsar://custom:6650"
|
||||
assert gateway.pulsar_api_key == "test-key"
|
||||
assert gateway.pulsar_listener == "test-listener"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with WebSocket URI missing path"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway(websocket_uri="ws://example.com")
|
||||
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_with_missing_path(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway(websocket_uri="ws://example.com")
|
||||
|
||||
assert gateway.path == "/ws"
|
||||
assert gateway.url == "ws://example.com/ws"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_invalid_scheme(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
|
||||
ReverseGateway(websocket_uri="http://example.com")
|
||||
make_gateway(websocket_uri="http://example.com")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with missing hostname"""
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_initialization_missing_hostname(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
|
||||
ReverseGateway(websocket_uri="ws://")
|
||||
make_gateway(websocket_uri="ws://")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway creates backend with authentication"""
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_iam_auth_created(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway(
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
gateway = make_gateway(id="test-rev-gw")
|
||||
|
||||
mock_iam_auth.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
id="test-rev-gw",
|
||||
)
|
||||
|
||||
# Verify get_pubsub was called with the correct parameters
|
||||
mock_get_pubsub.assert_called_once_with(
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_config_receiver_gets_auth(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_iam_auth.return_value = mock_auth_instance
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_config_receiver.assert_called_once_with(
|
||||
mock_backend, auth=mock_auth_instance,
|
||||
)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway successful connection"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
async def test_reverse_gateway_connect_success(
|
||||
self, mock_session_class, mock_get_pubsub,
|
||||
mock_dispatcher, mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_session.ws_connect.return_value = mock_ws
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
|
||||
assert result is True
|
||||
assert gateway.session == mock_session
|
||||
assert gateway.ws == mock_ws
|
||||
mock_session.ws_connect.assert_called_once_with(gateway.url)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway connection failure"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
async def test_reverse_gateway_connect_failure(
|
||||
self, mock_session_class, mock_get_pubsub,
|
||||
mock_dispatcher, mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.ws_connect.side_effect = Exception("Connection failed")
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway disconnect"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket and session
|
||||
async def test_reverse_gateway_disconnect(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_session = AsyncMock()
|
||||
mock_session.closed = False
|
||||
|
||||
|
||||
gateway.ws = mock_ws
|
||||
gateway.session = mock_session
|
||||
|
||||
|
||||
await gateway.disconnect()
|
||||
|
||||
|
||||
mock_ws.close.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
assert gateway.ws is None
|
||||
assert gateway.session is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket
|
||||
async def test_reverse_gateway_send_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
|
||||
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message with closed connection"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock closed websocket
|
||||
async def test_reverse_gateway_send_message_closed_connection(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = True
|
||||
gateway.ws = mock_ws
|
||||
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
# Should not call send_str on closed connection
|
||||
|
||||
mock_ws.send_str.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
mock_dispatcher_instance.handle_message.assert_called_once_with({
|
||||
"id": "test",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
})
|
||||
gateway.send_message.assert_called_once_with({"response": "success"})
|
||||
async def test_reverse_gateway_handle_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message with invalid JSON"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = 'invalid json'
|
||||
|
||||
# Should not raise exception
|
||||
|
||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
# Should not call send_message due to error
|
||||
|
||||
mock_dispatcher_instance.handle_message.assert_called_once_with(
|
||||
{
|
||||
"id": "test",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"},
|
||||
},
|
||||
gateway.send_message,
|
||||
)
|
||||
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
await gateway.handle_message('invalid json')
|
||||
|
||||
gateway.send_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with text message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
async def test_reverse_gateway_listen_text_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.TEXT
|
||||
mock_msg.data = '{"test": "message"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "message"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with binary message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
async def test_reverse_gateway_listen_binary_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.BINARY
|
||||
mock_msg.data = b'{"test": "binary"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with close message"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
async def test_reverse_gateway_listen_close_message(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.CLOSE
|
||||
|
||||
# Mock receive to return close message
|
||||
|
||||
mock_ws.receive.return_value = mock_msg
|
||||
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
# Should not call handle_message for close message
|
||||
|
||||
gateway.handle_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway shutdown"""
|
||||
async def test_reverse_gateway_shutdown(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock disconnect
|
||||
gateway.disconnect = AsyncMock()
|
||||
|
||||
await gateway.shutdown()
|
||||
|
|
@ -402,46 +435,50 @@ class TestReverseGateway:
|
|||
gateway.disconnect.assert_called_once()
|
||||
mock_backend.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway stop"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
def test_reverse_gateway_stop(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.running = True
|
||||
|
||||
|
||||
gateway.stop()
|
||||
|
||||
|
||||
assert gateway.running is False
|
||||
|
||||
|
||||
class TestReverseGatewayRun:
|
||||
"""Test cases for ReverseGateway run method"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@patch(*MOCK_PATCHES[0:1])
|
||||
@patch(*MOCK_PATCHES[1:2])
|
||||
@patch(*MOCK_PATCHES[2:3])
|
||||
@patch(*MOCK_PATCHES[3:4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
async def test_reverse_gateway_run_successful_cycle(
|
||||
self, mock_get_pubsub, mock_dispatcher,
|
||||
mock_config_receiver, mock_iam_auth,
|
||||
):
|
||||
mock_get_pubsub.return_value = MagicMock()
|
||||
|
||||
mock_auth_instance = AsyncMock()
|
||||
mock_iam_auth.return_value = mock_auth_instance
|
||||
|
||||
mock_config_receiver_instance = AsyncMock()
|
||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock methods
|
||||
gateway.connect = AsyncMock(return_value=True)
|
||||
|
||||
gateway = make_gateway()
|
||||
|
||||
gateway.listen = AsyncMock()
|
||||
gateway.disconnect = AsyncMock()
|
||||
gateway.shutdown = AsyncMock()
|
||||
|
||||
# Stop after one iteration
|
||||
|
||||
call_count = 0
|
||||
async def mock_connect():
|
||||
nonlocal call_count
|
||||
|
|
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
|
|||
else:
|
||||
gateway.running = False
|
||||
return False
|
||||
|
||||
|
||||
gateway.connect = mock_connect
|
||||
|
||||
|
||||
await gateway.run()
|
||||
|
||||
|
||||
mock_auth_instance.start.assert_called_once()
|
||||
mock_config_receiver_instance.start.assert_called_once()
|
||||
gateway.listen.assert_called_once()
|
||||
# disconnect is called twice: once in the main loop, once in shutdown
|
||||
assert gateway.disconnect.call_count == 2
|
||||
gateway.shutdown.assert_called_once()
|
||||
|
||||
|
||||
class TestReverseGatewayArgs:
|
||||
"""Test cases for argument parsing and run function"""
|
||||
|
||||
def test_parse_args_defaults(self):
|
||||
"""Test parse_args with default values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway']
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri is None
|
||||
assert args.max_workers == 10
|
||||
assert args.pulsar_host is None
|
||||
assert args.pulsar_api_key is None
|
||||
assert args.pulsar_listener is None
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
def test_parse_args_custom_values(self):
|
||||
"""Test parse_args with custom values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
'reverse-gateway',
|
||||
'--websocket-uri', 'ws://custom:8080/ws',
|
||||
'--max-workers', '20',
|
||||
'--pulsar-host', 'pulsar://custom:6650',
|
||||
'--pulsar-api-key', 'test-key',
|
||||
'--pulsar-listener', 'test-listener'
|
||||
]
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri == 'ws://custom:8080/ws'
|
||||
assert args.max_workers == 20
|
||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
||||
assert args.pulsar_api_key == 'test-key'
|
||||
assert args.pulsar_listener == 'test-listener'
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ReverseGateway')
|
||||
@patch('asyncio.run')
|
||||
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
|
||||
"""Test run function"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway', '--max-workers', '15']
|
||||
|
||||
try:
|
||||
mock_gateway_instance = MagicMock()
|
||||
mock_gateway_instance.url = "ws://localhost:7650/out"
|
||||
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
|
||||
mock_gateway_class.return_value = mock_gateway_instance
|
||||
|
||||
run()
|
||||
|
||||
mock_gateway_class.assert_called_once_with(
|
||||
websocket_uri=None,
|
||||
max_workers=15,
|
||||
pulsar_host=None,
|
||||
pulsar_api_key=None,
|
||||
pulsar_listener=None
|
||||
)
|
||||
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
|
@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||
|
||||
# Verify collection existence is checked on each write
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
# Second write uses cached collection state — no collection_exists check
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("test_collection", 384)
|
||||
await processor.ensure_collection("test_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify the collection is cached
|
||||
assert "test_collection" in processor.created_collections
|
||||
assert "test_collection" in processor._known_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
|
||||
|
|
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("existing_collection", 384)
|
||||
await processor.ensure_collection("existing_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
|
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add("cached_collection")
|
||||
processor._known_collections.add("cached_collection")
|
||||
|
||||
processor.ensure_collection("cached_collection", 384)
|
||||
await processor.ensure_collection("cached_collection", 384)
|
||||
|
||||
# Should not check or create - just return
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
|
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
|
||||
await processor.delete_collection('test_workspace', 'test_collection')
|
||||
|
||||
|
|
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||
|
||||
# Verify the cached collection was removed
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||
|
|
|
|||
|
|
@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
|
|||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
import asyncio
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.registered_partitions = set()
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Tests for Cassandra triples storage service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
|
|
@ -24,12 +26,13 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||
assert processor.cassandra_username is None
|
||||
assert processor.cassandra_password is None
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters (new cassandra_* names)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
id='custom-storage',
|
||||
|
|
@ -37,11 +40,12 @@ class TestCassandraStorageProcessor:
|
|||
cassandra_username='testuser',
|
||||
cassandra_password='testpass'
|
||||
)
|
||||
|
||||
|
||||
assert processor.cassandra_host == ['cassandra.example.com']
|
||||
assert processor.cassandra_username == 'testuser'
|
||||
assert processor.cassandra_password == 'testpass'
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_partial_auth(self):
|
||||
"""Test processor initialization with only username (no password)"""
|
||||
|
|
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
|
|
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
|
|||
username='testuser',
|
||||
password='testpass'
|
||||
)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
|
|||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user2'
|
||||
)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
assert mock_tg_instance.async_insert.call_count == 2
|
||||
mock_tg_instance.async_insert.assert_any_call(
|
||||
'collection1', 'subject1', 'predicate1', 'object1',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
mock_tg_instance.async_insert.assert_any_call(
|
||||
'collection1', 'subject2', 'predicate2', 'object2',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
|
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
mock_tg_instance.async_insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
|
||||
async def test_exception_handling_on_connection_failure(self, mock_kg_class):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
|
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
|
|||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
|
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
|
|||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples('user1', mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
|
|
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
|
|||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples('user2', mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
# Verify TrustGraph was created twice for different workspaces
|
||||
assert mock_kg_class.call_count == 2
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
mock_tg_instance.async_insert.assert_called_once_with(
|
||||
'test_collection',
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
|
|
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
|
||||
"""Test that a failed connection doesn't leave stale cached state"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_good_instance = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call fails
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples('new_user', mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
assert processor.tg is None
|
||||
# Second call succeeds — should retry connection, not use stale state
|
||||
mock_kg_class.side_effect = None
|
||||
mock_kg_class.return_value = mock_good_instance
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Connection was attempted twice (failed + succeeded)
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
|
||||
class TestCassandraPerformanceOptimizations:
|
||||
|
|
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that legacy mode still works with single table"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
||||
|
|
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that optimized mode uses multi-table schema"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
||||
|
|
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that all tables stay consistent during batch writes"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
mock_tg_instance.async_insert.assert_called_once_with(
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
|
|
|||
|
|
@ -89,7 +89,8 @@ class TestSanitizeName:
|
|||
|
||||
class TestFindCollection:
|
||||
|
||||
def test_finds_matching_collection(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_matching_collection(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
||||
|
|
@ -98,11 +99,12 @@ class TestFindCollection:
|
|||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||
|
||||
assert result == "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
def test_returns_none_when_no_match(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_match(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
||||
|
|
@ -111,14 +113,15 @@ class TestFindCollection:
|
|||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_error(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_error(self):
|
||||
proc = _make_processor()
|
||||
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
||||
|
||||
result = proc.find_collection("workspace", "col", "schema")
|
||||
result = await proc.find_collection("workspace", "col", "schema")
|
||||
assert result is None
|
||||
|
||||
|
||||
|
|
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_no_collection_returns_empty(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value=None)
|
||||
proc.find_collection = AsyncMock(return_value=None)
|
||||
request = _make_request()
|
||||
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
|
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_successful_query_returns_matches(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
points = [
|
||||
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
||||
|
|
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_index_name_filter_applied(self):
|
||||
"""When index_name is specified, a Qdrant filter should be used."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
|
|
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_no_index_name_no_filter(self):
|
||||
"""When index_name is empty, no filter should be applied."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
|
|
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_missing_payload_fields_default(self):
|
||||
"""Points with missing payload fields should use defaults."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
point = MagicMock()
|
||||
point.payload = {} # Empty payload
|
||||
|
|
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_qdrant_error_propagates(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
||||
|
||||
request = _make_request()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue