mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-16 19:05:14 +02:00
release/v2.4 -> master (#924)
* CLI auth migration, document embeddings core lifecycle (#913) Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient with first-frame auth (fixes broken raw websocket path). Fix wire format field names (root/vector). Remove ~600 lines of dead raw websocket code from invoke_graph_rag.py. Add document embeddings core lifecycle to the knowledge service: list/get/put/delete/load operations across schema, translator, Cassandra table store, knowledge manager, gateway registry, REST API, socket client, and CLI (tg-get-de-core, tg-put-de-core). Fix delete_kg_core to also clean up document embeddings rows. * Remove spurious workspace parameter from SPARQL algebra evaluator (#915) Fix threading of workspace paramater: - The SPARQL algebra evaluator was threading a workspace parameter through every function and passing it to TriplesClient.query(), which doesn't accept it. Workspace isolation is handled by pub/sub topic routing — the TriplesClient is already scoped to a workspace-specific flow, same as GraphRAG. Passing workspace explicitly was both incorrect and unnecessary. Update tests: - tests/unit/test_query/test_sparql_algebra.py (new) — Tests _query_pattern, _eval_bgp, and evaluate() with various algebra nodes. Key tests assert workspace is never in tc.query() kwargs, plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and edge cases. - tests/unit/test_retrieval/test_graph_rag.py — Added test_triples_query_never_passes_workspace (checks query()) and test_follow_edges_never_passes_workspace (checks query_stream()). * Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916) Cassandra triples services were using syncronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared state mutations across concurrent coroutines. - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge - Rewrite triples write + query services: replace threading.local with asyncio.Lock + dict cache for per-workspace connections, use async ECKG methods for all data operations, keep asyncio.to_thread only for one-time blocking ECKG construction - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), add asyncio.Lock + set cache for collection existence checks - Add asyncio.Lock to rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation - Update all affected tests to match new async patterns * Fixed error only returning a page of results (#921) The root cause: async_execute only materialises the first result page (by design — it says so in its docstring). The streaming query set fetch_size=20 and expected to iterate all results, but only got the first 20 rows back. The fix uses asyncio.to_thread(lambda: list(tg.session.execute(...))) which lets the sync driver iterate all pages in a worker thread — exactly what the pre-async code did. * Optional test warning suppression (#923) * Fix test collection module errors & silence upstream Pytest warnings (#823) * chore: add virtual environment and .env directories to gitignore * test: filter upstream DeprecationWarning and UserWarning messages * fix(namespace): remove empty __init__.py files to fix PEP 420 implicit namespace routing for trustgraph sub-packages * Revert __init__.py deletions * Add .ini changes but commented out, will be useful at times --------- Co-authored-by: Salil M <d2kyt@protonmail.com>
This commit is contained in:
parent
159b1e2824
commit
142dd0231c
42 changed files with 1910 additions and 1492 deletions
|
|
@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
|
|||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configuration"""
|
||||
import asyncio
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
|
|||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.connect_cassandra = AsyncMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
|
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
|
|||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.connect_cassandra = AsyncMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
|
|
|||
302
tests/unit/test_query/test_sparql_algebra.py
Normal file
302
tests/unit/test_query/test_sparql_algebra.py
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
"""
|
||||
Tests for the SPARQL algebra evaluator.
|
||||
|
||||
Verifies that evaluate() and _query_pattern() call TriplesClient.query()
|
||||
with the correct arguments, and in particular that workspace is never
|
||||
passed — workspace isolation is handled by pub/sub topic routing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
from rdflib.term import Variable, URIRef, Literal
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.query.sparql.algebra import (
|
||||
evaluate, _query_pattern, _eval_bgp,
|
||||
)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def iri(v):
|
||||
return Term(type=IRI, iri=v)
|
||||
|
||||
|
||||
def lit(v):
|
||||
return Term(type=LITERAL, value=v)
|
||||
|
||||
|
||||
def make_triple(s, p, o):
|
||||
t = MagicMock()
|
||||
t.s = s
|
||||
t.p = p
|
||||
t.o = o
|
||||
return t
|
||||
|
||||
|
||||
def make_bgp(*patterns):
|
||||
"""Build a CompValue BGP node from (s, p, o) tuples of rdflib terms."""
|
||||
node = CompValue("BGP")
|
||||
node.triples = list(patterns)
|
||||
return node
|
||||
|
||||
|
||||
def make_project(inner, variables):
|
||||
node = CompValue("Project")
|
||||
node.p = inner
|
||||
node.PV = [Variable(v) for v in variables]
|
||||
return node
|
||||
|
||||
|
||||
def make_select(inner):
|
||||
node = CompValue("SelectQuery")
|
||||
node.p = inner
|
||||
return node
|
||||
|
||||
|
||||
def make_join(left, right):
|
||||
node = CompValue("Join")
|
||||
node.p1 = left
|
||||
node.p2 = right
|
||||
return node
|
||||
|
||||
|
||||
def make_union(left, right):
|
||||
node = CompValue("Union")
|
||||
node.p1 = left
|
||||
node.p2 = right
|
||||
return node
|
||||
|
||||
|
||||
def make_slice(inner, start, length):
|
||||
node = CompValue("Slice")
|
||||
node.p = inner
|
||||
node.start = start
|
||||
node.length = length
|
||||
return node
|
||||
|
||||
|
||||
def make_distinct(inner):
|
||||
node = CompValue("Distinct")
|
||||
node.p = inner
|
||||
return node
|
||||
|
||||
|
||||
class TestQueryPattern:
|
||||
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_correct_args(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = []
|
||||
|
||||
await _query_pattern(
|
||||
tc,
|
||||
s=iri("http://example.com/s"),
|
||||
p=iri("http://example.com/p"),
|
||||
o=None,
|
||||
collection="my-collection",
|
||||
limit=100,
|
||||
)
|
||||
|
||||
tc.query.assert_called_once_with(
|
||||
s=iri("http://example.com/s"),
|
||||
p=iri("http://example.com/p"),
|
||||
o=None,
|
||||
limit=100,
|
||||
collection="my-collection",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_not_passed(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = []
|
||||
|
||||
await _query_pattern(tc, None, None, None, "default", 10)
|
||||
|
||||
kwargs = tc.query.call_args.kwargs
|
||||
assert "workspace" not in kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_query_results(self):
|
||||
tc = AsyncMock()
|
||||
triple = make_triple(iri("http://a"), iri("http://b"), lit("c"))
|
||||
tc.query.return_value = [triple]
|
||||
|
||||
results = await _query_pattern(tc, None, None, None, "default", 10)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].s.iri == "http://a"
|
||||
|
||||
|
||||
class TestEvalBgp:
|
||||
"""Tests for BGP evaluation — triple pattern queries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_pattern_all_variables(self):
|
||||
tc = AsyncMock()
|
||||
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
||||
tc.query.return_value = [triple]
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
|
||||
solutions = await evaluate(bgp, tc, collection="default", limit=100)
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert solutions[0]["s"].iri == "http://s"
|
||||
assert solutions[0]["p"].iri == "http://p"
|
||||
assert solutions[0]["o"].value == "o"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_pattern_bound_subject(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = [
|
||||
make_triple(iri("http://s"), iri("http://p"), lit("val")),
|
||||
]
|
||||
|
||||
bgp = make_bgp(
|
||||
(URIRef("http://s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
|
||||
solutions = await evaluate(bgp, tc, collection="default")
|
||||
|
||||
tc.query.assert_called_once()
|
||||
kwargs = tc.query.call_args.kwargs
|
||||
assert "workspace" not in kwargs
|
||||
assert kwargs["collection"] == "default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_bgp_returns_empty_solution(self):
|
||||
tc = AsyncMock()
|
||||
|
||||
bgp = make_bgp()
|
||||
|
||||
solutions = await evaluate(bgp, tc, collection="default")
|
||||
|
||||
assert solutions == [{}]
|
||||
tc.query.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_results_returns_empty(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = []
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
|
||||
solutions = await evaluate(bgp, tc, collection="default")
|
||||
|
||||
assert solutions == []
|
||||
|
||||
|
||||
class TestEvaluate:
|
||||
"""Tests for the top-level evaluate() dispatcher."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_query_node(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = [
|
||||
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
||||
]
|
||||
|
||||
bgp = make_bgp(
|
||||
(Variable("s"), Variable("p"), Variable("o")),
|
||||
)
|
||||
select = make_select(make_project(bgp, ["s", "p"]))
|
||||
|
||||
solutions = await evaluate(select, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert "s" in solutions[0]
|
||||
assert "p" in solutions[0]
|
||||
assert "o" not in solutions[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_never_in_query_calls(self):
|
||||
"""Verify that no matter the algebra structure, workspace is never
|
||||
passed to TriplesClient.query()."""
|
||||
tc = AsyncMock()
|
||||
tc.query.return_value = [
|
||||
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
||||
]
|
||||
|
||||
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
|
||||
tree = make_select(make_project(
|
||||
make_union(bgp1, bgp2), ["s", "p", "o"]
|
||||
))
|
||||
|
||||
await evaluate(tree, tc, collection="test-coll")
|
||||
|
||||
for c in tc.query.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join(self):
|
||||
tc = AsyncMock()
|
||||
tc.query.side_effect = [
|
||||
[make_triple(iri("http://a"), iri("http://p"), lit("v"))],
|
||||
[make_triple(iri("http://a"), iri("http://q"), lit("w"))],
|
||||
]
|
||||
|
||||
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
|
||||
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
|
||||
tree = make_join(bgp1, bgp2)
|
||||
|
||||
solutions = await evaluate(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
assert solutions[0]["s"].iri == "http://a"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slice(self):
|
||||
tc = AsyncMock()
|
||||
triples = [
|
||||
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
|
||||
for i in range(5)
|
||||
]
|
||||
tc.query.return_value = triples
|
||||
|
||||
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||
tree = make_slice(bgp, start=1, length=2)
|
||||
|
||||
solutions = await evaluate(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distinct(self):
|
||||
tc = AsyncMock()
|
||||
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
||||
tc.query.return_value = [triple, triple]
|
||||
|
||||
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||
tree = make_distinct(bgp)
|
||||
|
||||
solutions = await evaluate(tree, tc, collection="default")
|
||||
|
||||
assert len(solutions) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_node_returns_empty_solution(self):
|
||||
tc = AsyncMock()
|
||||
|
||||
node = CompValue("SomethingUnknown")
|
||||
|
||||
solutions = await evaluate(node, tc, collection="default")
|
||||
|
||||
assert solutions == [{}]
|
||||
tc.query.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_compvalue_returns_empty_solution(self):
|
||||
tc = AsyncMock()
|
||||
|
||||
solutions = await evaluate("not a node", tc, collection="default")
|
||||
|
||||
assert solutions == [{}]
|
||||
|
|
@ -2,8 +2,10 @@
|
|||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
|
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
|
|||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
cassandra_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_term_with_http_uri(self, processor):
|
||||
|
|
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
|
|
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
|
|||
keyspace='test_user'
|
||||
)
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
# Verify async_get_spo was called with correct parameters
|
||||
mock_tg_instance.async_get_spo.assert_called_once_with(
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
||||
)
|
||||
|
||||
|
|
@ -130,23 +132,25 @@ class TestCassandraQueryProcessor:
|
|||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||
assert processor.cassandra_username is None
|
||||
assert processor.cassandra_password is None
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
cassandra_host='cassandra.example.com',
|
||||
cassandra_username='queryuser',
|
||||
cassandra_password='querypass'
|
||||
)
|
||||
|
||||
|
||||
assert processor.cassandra_host == ['cassandra.example.com']
|
||||
assert processor.cassandra_username == 'queryuser'
|
||||
assert processor.cassandra_password == 'querypass'
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
|
|
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
|
|
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
|
|
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
|
|
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_result.g = ''
|
||||
mock_result.d = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.iri == 'all_subject'
|
||||
assert result[0].p.iri == 'all_predicate'
|
||||
|
|
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
|
|
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
|
|||
mock_result.lang = None
|
||||
mock_result.p = 'p'
|
||||
mock_result.o = 'o'
|
||||
mock_tg_instance1.get_s.return_value = [mock_result]
|
||||
mock_tg_instance2.get_s.return_value = [mock_result]
|
||||
mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
await processor.query_triples('user1', query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
|
|
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
await processor.query_triples('user2', query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
# Verify TrustGraph was created twice for different workspaces
|
||||
assert mock_kg_class.call_count == 2
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
|
|||
mock_result2.otype = None
|
||||
mock_result2.dtype = None
|
||||
mock_result2.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_po.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
# Verify async_get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
||||
)
|
||||
|
||||
|
|
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_os.return_value = [mock_result]
|
||||
mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
# Verify async_get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.async_get_os.assert_called_once_with(
|
||||
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
||||
)
|
||||
|
||||
|
|
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock empty results for all queries
|
||||
mock_tg_instance.get_all.return_value = []
|
||||
mock_tg_instance.get_s.return_value = []
|
||||
mock_tg_instance.get_p.return_value = []
|
||||
mock_tg_instance.get_o.return_value = []
|
||||
mock_tg_instance.get_sp.return_value = []
|
||||
mock_tg_instance.get_po.return_value = []
|
||||
mock_tg_instance.get_os.return_value = []
|
||||
mock_tg_instance.get_spo.return_value = []
|
||||
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_s = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_p = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_o = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_os = AsyncMock(return_value=[])
|
||||
mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Test each query pattern
|
||||
test_patterns = [
|
||||
# (s, p, o, expected_method)
|
||||
(None, None, None, 'get_all'), # All triples
|
||||
('s1', None, None, 'get_s'), # Subject only
|
||||
(None, 'p1', None, 'get_p'), # Predicate only
|
||||
(None, None, 'o1', 'get_o'), # Object only
|
||||
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
|
||||
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||
('s1', None, 'o1', 'get_os'), # Object + Subject
|
||||
('s1', 'p1', 'o1', 'get_spo'), # All three
|
||||
(None, None, None, 'async_get_all'), # All triples
|
||||
('s1', None, None, 'async_get_s'), # Subject only
|
||||
(None, 'p1', None, 'async_get_p'), # Predicate only
|
||||
(None, None, 'o1', 'async_get_o'), # Object only
|
||||
('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
|
||||
(None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
|
||||
('s1', None, 'o1', 'async_get_os'), # Object + Subject
|
||||
('s1', 'p1', 'o1', 'async_get_spo'), # All three
|
||||
]
|
||||
|
||||
for s, p, o, expected_method in test_patterns:
|
||||
|
|
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_result.lang = None
|
||||
mock_results.append(mock_result)
|
||||
|
||||
mock_tg_instance.get_po.return_value = mock_results
|
||||
mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
|
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
result = await processor.query_triples('large_dataset_user', query)
|
||||
|
||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
# Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.async_get_po.assert_called_once_with(
|
||||
'massive_collection',
|
||||
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
||||
'http://example.com/Person',
|
||||
|
|
|
|||
|
|
@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_embedding_upserted(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "col1"
|
||||
|
|
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
|
|||
@pytest.mark.asyncio
|
||||
async def test_dimension_in_collection_name(self):
|
||||
"""Collection name should include vector dimension."""
|
||||
import asyncio
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "docs"
|
||||
|
|
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_entity_and_vector_upserted(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = True
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "col1"
|
||||
|
|
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lazy_collection_creation_on_new_dimension(self):
|
||||
import asyncio
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.qdrant = MagicMock()
|
||||
proc.qdrant.collection_exists.return_value = False
|
||||
proc.collection_exists = MagicMock(return_value=True)
|
||||
proc._cache_lock = asyncio.Lock()
|
||||
proc._known_collections = set()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.collection = "graphs"
|
||||
|
|
|
|||
|
|
@ -337,6 +337,57 @@ class TestQuery:
|
|||
cache_key = "test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triples_query_never_passes_workspace(self):
|
||||
"""Workspace isolation is handled by pub/sub topic routing, not
|
||||
by passing workspace to TriplesClient.query(). Verify that
|
||||
GraphRAG never passes workspace as a keyword argument."""
|
||||
mock_rag = MagicMock()
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
await query.maybe_label("http://example.com/entity")
|
||||
|
||||
for c in mock_triples_client.query.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_never_passes_workspace(self):
|
||||
"""Verify follow_edges never passes workspace to query_stream."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("e1", subgraph, path_length=1)
|
||||
|
||||
for c in mock_triples_client.query_stream.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
|
|
|
|||
|
|
@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||
|
||||
# Verify collection existence is checked on each write
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
# Second write uses cached collection state — no collection_exists check
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("test_collection", 384)
|
||||
await processor.ensure_collection("test_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify the collection is cached
|
||||
assert "test_collection" in processor.created_collections
|
||||
assert "test_collection" in processor._known_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
|
||||
|
|
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("existing_collection", 384)
|
||||
await processor.ensure_collection("existing_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
|
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add("cached_collection")
|
||||
processor._known_collections.add("cached_collection")
|
||||
|
||||
processor.ensure_collection("cached_collection", 384)
|
||||
await processor.ensure_collection("cached_collection", 384)
|
||||
|
||||
# Should not check or create - just return
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
|
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
|
||||
await processor.delete_collection('test_workspace', 'test_collection')
|
||||
|
||||
|
|
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||
|
||||
# Verify the cached collection was removed
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||
|
|
|
|||
|
|
@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
|
|||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
import asyncio
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.registered_partitions = set()
|
||||
processor._setup_lock = asyncio.Lock()
|
||||
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Tests for Cassandra triples storage service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
|
|
@ -24,12 +26,13 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.cassandra_host == ['cassandra'] # Updated default
|
||||
assert processor.cassandra_username is None
|
||||
assert processor.cassandra_password is None
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters (new cassandra_* names)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
id='custom-storage',
|
||||
|
|
@ -37,11 +40,12 @@ class TestCassandraStorageProcessor:
|
|||
cassandra_username='testuser',
|
||||
cassandra_password='testpass'
|
||||
)
|
||||
|
||||
|
||||
assert processor.cassandra_host == ['cassandra.example.com']
|
||||
assert processor.cassandra_username == 'testuser'
|
||||
assert processor.cassandra_password == 'testpass'
|
||||
assert processor.table is None
|
||||
assert processor._connections == {}
|
||||
assert isinstance(processor._conn_lock, asyncio.Lock)
|
||||
|
||||
def test_processor_initialization_with_partial_auth(self):
|
||||
"""Test processor initialization with only username (no password)"""
|
||||
|
|
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
|
|
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
|
|||
username='testuser',
|
||||
password='testpass'
|
||||
)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
|
|||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user2'
|
||||
)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
assert mock_tg_instance.async_insert.call_count == 2
|
||||
mock_tg_instance.async_insert.assert_any_call(
|
||||
'collection1', 'subject1', 'predicate1', 'object1',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
mock_tg_instance.async_insert.assert_any_call(
|
||||
'collection1', 'subject2', 'predicate2', 'object2',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
|
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
mock_tg_instance.async_insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
|
||||
async def test_exception_handling_on_connection_failure(self, mock_kg_class):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
|
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
|
|||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
|
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
|
|||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples('user1', mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
|
|
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
|
|||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples('user2', mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
# Verify TrustGraph was created twice for different workspaces
|
||||
assert mock_kg_class.call_count == 2
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
|
||||
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
|
|||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
|
|||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
mock_tg_instance.async_insert.assert_called_once_with(
|
||||
'test_collection',
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
|
|
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
|
||||
"""Test that a failed connection doesn't leave stale cached state"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_good_instance = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call fails
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples('new_user', mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
assert processor.tg is None
|
||||
# Second call succeeds — should retry connection, not use stale state
|
||||
mock_kg_class.side_effect = None
|
||||
mock_kg_class.return_value = mock_good_instance
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Connection was attempted twice (failed + succeeded)
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
|
||||
class TestCassandraPerformanceOptimizations:
|
||||
|
|
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that legacy mode still works with single table"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
||||
|
|
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that optimized mode uses multi-table schema"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
||||
|
|
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test that all tables stay consistent during batch writes"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_tg_instance.async_insert = AsyncMock()
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
mock_tg_instance.async_insert.assert_called_once_with(
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
|
|
|||
|
|
@ -89,7 +89,8 @@ class TestSanitizeName:
|
|||
|
||||
class TestFindCollection:
|
||||
|
||||
def test_finds_matching_collection(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_matching_collection(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
||||
|
|
@ -98,11 +99,12 @@ class TestFindCollection:
|
|||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||
|
||||
assert result == "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
def test_returns_none_when_no_match(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_match(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
||||
|
|
@ -111,14 +113,15 @@ class TestFindCollection:
|
|||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
result = await proc.find_collection("test-workspace", "test-col", "customers")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_error(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_error(self):
|
||||
proc = _make_processor()
|
||||
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
||||
|
||||
result = proc.find_collection("workspace", "col", "schema")
|
||||
result = await proc.find_collection("workspace", "col", "schema")
|
||||
assert result is None
|
||||
|
||||
|
||||
|
|
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_no_collection_returns_empty(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value=None)
|
||||
proc.find_collection = AsyncMock(return_value=None)
|
||||
request = _make_request()
|
||||
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
|
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_successful_query_returns_matches(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
points = [
|
||||
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
||||
|
|
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_index_name_filter_applied(self):
|
||||
"""When index_name is specified, a Qdrant filter should be used."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
|
|
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_no_index_name_no_filter(self):
|
||||
"""When index_name is empty, no filter should be applied."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
|
|
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_missing_payload_fields_default(self):
|
||||
"""Points with missing payload fields should use defaults."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
|
||||
point = MagicMock()
|
||||
point.payload = {} # Empty payload
|
||||
|
|
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_qdrant_error_propagates(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
|
||||
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
||||
|
||||
request = _make_request()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue