Release 1.4 -> master (#524)

Catch up
This commit is contained in:
cybermaggedon 2025-09-20 16:00:37 +01:00 committed by GitHub
parent a8e437fc7f
commit 6c7af8789d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
216 changed files with 31360 additions and 1611 deletions

View file

@ -85,8 +85,10 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify search was called with correct parameters
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
)
# Verify results are document chunks
assert len(result) == 3
@ -116,10 +118,10 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3],), {"limit": 3}),
(([0.4, 0.5, 0.6],), {"limit": 3}),
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
@ -155,7 +157,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify search was called with the specified limit
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=2
)
# Verify all results are returned (Milvus handles limit internally)
assert len(result) == 4
@ -194,7 +198,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify search was called
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
)
# Verify empty results
assert len(result) == 0

View file

@ -120,7 +120,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
chunks = await processor.query_document_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "d-test_user-test_collection-3"
expected_index_name = "d-test_user-test_collection"
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
@ -239,7 +239,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions"""
"""Test querying with vectors of different dimensions using same index"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
@ -248,37 +248,33 @@ class TestPineconeDocEmbeddingsQueryProcessor:
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
processor.pinecone.Index.side_effect = mock_index_side_effect
# Mock results for different dimensions
# Mock single index that handles all dimensions
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})]
mock_index_2d.query.return_value = mock_results_2d
mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D query'})]
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})]
mock_index_4d.query.return_value = mock_results_4d
mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D query'})]
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
chunks = await processor.query_document_embeddings(message)
# Verify different indexes were used
# Verify same index used for both vectors
expected_index_name = "d-test_user-test_collection"
assert processor.pinecone.Index.call_count == 2
mock_index_2d.query.assert_called_once()
mock_index_4d.query.assert_called_once()
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify results from both dimensions
assert 'Document from 2D index' in chunks
assert 'Document from 4D index' in chunks
assert 'Document from 2D query' in chunks
assert 'Document from 4D query' in chunks
@pytest.mark.asyncio
async def test_query_document_embeddings_empty_vectors_list(self, processor):

View file

@ -104,7 +104,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
# Verify query was called with correct parameters
expected_collection = 'd_test_user_test_collection_3'
expected_collection = 'd_test_user_test_collection'
mock_qdrant_instance.query_points.assert_called_once_with(
collection_name=expected_collection,
query=[0.1, 0.2, 0.3],
@ -166,7 +166,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
assert mock_qdrant_instance.query_points.call_count == 2
# Verify both collections were queried
expected_collection = 'd_multi_user_multi_collection_2'
expected_collection = 'd_multi_user_multi_collection'
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
@ -303,11 +303,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection'
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection'
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results

View file

@ -133,8 +133,10 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
result = await processor.query_graph_embeddings(query)
# Verify search was called with correct parameters
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify results are converted to Value objects
assert len(result) == 3
@ -171,10 +173,10 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
result = await processor.query_graph_embeddings(query)
# Verify search was called twice with correct parameters
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3],), {"limit": 6}),
(([0.4, 0.5, 0.6],), {"limit": 6}),
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
@ -211,7 +213,9 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
result = await processor.query_graph_embeddings(query)
# Verify search was called with 2*limit for better deduplication
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
)
# Verify results are limited to the requested limit
assert len(result) == 2
@ -269,7 +273,9 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
result = await processor.query_graph_embeddings(query)
# Verify only first vector was searched (limit reached)
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
)
# Verify results are limited
assert len(result) == 2
@ -308,7 +314,9 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
result = await processor.query_graph_embeddings(query)
# Verify search was called
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify empty results
assert len(result) == 0

View file

@ -148,7 +148,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
entities = await processor.query_graph_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "t-test_user-test_collection-3"
expected_index_name = "t-test_user-test_collection"
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
@ -265,7 +265,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions"""
"""Test querying with vectors of different dimensions using same index"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
@ -274,34 +274,30 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
processor.pinecone.Index.side_effect = mock_index_side_effect
# Mock results for different dimensions
# Mock single index that handles all dimensions
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_index_2d.query.return_value = mock_results_2d
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index_4d.query.return_value = mock_results_4d
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
entities = await processor.query_graph_embeddings(message)
# Verify different indexes were used
# Verify same index used for both vectors
expected_index_name = "t-test_user-test_collection"
assert processor.pinecone.Index.call_count == 2
mock_index_2d.query.assert_called_once()
mock_index_4d.query.assert_called_once()
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify results from both dimensions
entity_values = [e.value for e in entities]
assert 'entity_2d' in entity_values

View file

@ -176,7 +176,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
# Verify query was called with correct parameters
expected_collection = 't_test_user_test_collection_3'
expected_collection = 't_test_user_test_collection'
mock_qdrant_instance.query_points.assert_called_once_with(
collection_name=expected_collection,
query=[0.1, 0.2, 0.3],
@ -236,7 +236,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
assert mock_qdrant_instance.query_points.call_count == 2
# Verify both collections were queried
expected_collection = 't_multi_user_multi_collection_2'
expected_collection = 't_multi_user_multi_collection'
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
@ -374,11 +374,11 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2'
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection'
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3'
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection'
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results

View file

@ -0,0 +1,432 @@
"""
Tests for Memgraph user/collection isolation in query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Value
class TestMemgraphQueryUserCollectionIsolation:
"""Test cases for Memgraph query service with user/collection isolation"""
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="test_object", is_uri=False),
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=None,
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=Value(value="http://example.com/o", is_uri=True),
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=None,
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="literal", is_uri=False),
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=None,
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
o=Value(value="test_value", is_uri=False),
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
o=None,
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
collection="test_collection",
database_='memgraph'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_with_defaults_when_not_provided(self, mock_graph_db):
"""Test that defaults are used when user/collection not provided"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
# Query without user/collection fields
query = TriplesQueryRequest(
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=None,
limit=1000
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
for call in calls:
if 'user' in call.kwargs:
assert call.kwargs['user'] == 'default'
if 'collection' in call.kwargs:
assert call.kwargs['collection'] == 'default'
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_results_properly_converted_to_triples(self, mock_graph_db):
"""Test that query results are properly converted to Triple objects"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=None,
limit=1000
)
# Mock some results
mock_record1 = MagicMock()
mock_record1.data.return_value = {
"rel": "http://example.com/p1",
"dest": "literal_value"
}
mock_record2 = MagicMock()
mock_record2.data.return_value = {
"rel": "http://example.com/p2",
"dest": "http://example.com/o"
}
# Return results for literal query, empty for node query
mock_driver.execute_query.side_effect = [
([mock_record1], MagicMock(), MagicMock()), # Literal query
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
# Verify results are proper Triple objects
assert len(result) == 2
# First triple (literal object)
assert result[0].s.value == "http://example.com/s"
assert result[0].s.is_uri == True
assert result[0].p.value == "http://example.com/p1"
assert result[0].p.is_uri == True
assert result[0].o.value == "literal_value"
assert result[0].o.is_uri == False
# Second triple (URI object)
assert result[1].s.value == "http://example.com/s"
assert result[1].s.is_uri == True
assert result[1].p.value == "http://example.com/p2"
assert result[1].p.is_uri == True
assert result[1].o.value == "http://example.com/o"
assert result[1].o.is_uri == True

View file

@ -0,0 +1,430 @@
"""
Tests for Neo4j user/collection isolation in query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Value
class TestNeo4jQueryUserCollectionIsolation:
"""Test cases for Neo4j query service with user/collection isolation"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="test_object", is_uri=False)
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=None
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=Value(value="http://example.com/o", is_uri=True)
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=None
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="literal", is_uri=False)
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src"
)
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=None
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest"
)
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
o=Value(value="test_value", is_uri=False)
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel"
)
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
o=None
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_with_defaults_when_not_provided(self, mock_graph_db):
"""Test that defaults are used when user/collection not provided"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
# Query without user/collection fields
query = TriplesQueryRequest(
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=None
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
for call in calls:
if 'user' in call.kwargs:
assert call.kwargs['user'] == 'default'
if 'collection' in call.kwargs:
assert call.kwargs['collection'] == 'default'
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_results_properly_converted_to_triples(self, mock_graph_db):
"""Test that query results are properly converted to Triple objects"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=None,
o=None
)
# Mock some results
mock_record1 = MagicMock()
mock_record1.data.return_value = {
"rel": "http://example.com/p1",
"dest": "literal_value"
}
mock_record2 = MagicMock()
mock_record2.data.return_value = {
"rel": "http://example.com/p2",
"dest": "http://example.com/o"
}
# Return results for literal query, empty for node query
mock_driver.execute_query.side_effect = [
([mock_record1], MagicMock(), MagicMock()), # Literal query
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
# Verify results are proper Triple objects
assert len(result) == 2
# First triple (literal object)
assert result[0].s.value == "http://example.com/s"
assert result[0].s.is_uri == True
assert result[0].p.value == "http://example.com/p1"
assert result[0].p.is_uri == True
assert result[0].o.value == "literal_value"
assert result[0].o.is_uri == False
# Second triple (URI object)
assert result[1].s.value == "http://example.com/s"
assert result[1].s.is_uri == True
assert result[1].p.value == "http://example.com/p2"
assert result[1].p.is_uri == True
assert result[1].o.value == "http://example.com/o"
assert result[1].o.is_uri == True

View file

@ -0,0 +1,551 @@
"""
Unit tests for Cassandra Objects GraphQL Query Processor
Tests the business logic of the GraphQL query processor including:
- GraphQL schema generation from RowSchema
- Query execution and validation
- CQL translation logic
- Message processing logic
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import strawberry
from strawberry import Schema
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field
class TestObjectsGraphQLQueryLogic:
"""Test business logic without external dependencies"""
def test_get_python_type_mapping(self):
"""Test schema field type conversion to Python types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_python_type("string") == str
assert processor.get_python_type("integer") == int
assert processor.get_python_type("float") == float
assert processor.get_python_type("boolean") == bool
assert processor.get_python_type("timestamp") == str
assert processor.get_python_type("date") == str
assert processor.get_python_type("time") == str
assert processor.get_python_type("uuid") == str
# Unknown type defaults to str
assert processor.get_python_type("unknown_type") == str
def test_create_graphql_type_basic_fields(self):
"""Test GraphQL type creation for basic field types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="test_table",
description="Test table",
fields=[
Field(
name="id",
type="string",
primary=True,
required=True,
description="Primary key"
),
Field(
name="name",
type="string",
required=True,
description="Name field"
),
Field(
name="age",
type="integer",
required=False,
description="Optional age"
),
Field(
name="active",
type="boolean",
required=False,
description="Status flag"
)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test_table", schema)
# Verify type was created
assert graphql_type is not None
assert hasattr(graphql_type, '__name__')
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
def test_sanitize_name_cassandra_compatibility(self):
"""Test name sanitization for Cassandra field names"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test field name sanitization (matches storage processor)
assert processor.sanitize_name("simple_field") == "simple_field"
assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
assert processor.sanitize_name("field.with.dots") == "field_with_dots"
assert processor.sanitize_name("123_field") == "o_123_field"
assert processor.sanitize_name("field with spaces") == "field_with_spaces"
assert processor.sanitize_name("special!@#chars") == "special___chars"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
assert processor.sanitize_name("CamelCase") == "camelcase"
def test_sanitize_table_name(self):
"""Test table name sanitization (always gets o_ prefix)"""
processor = MagicMock()
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Table names always get o_ prefix
assert processor.sanitize_table("simple_table") == "o_simple_table"
assert processor.sanitize_table("Table-Name") == "o_table_name"
assert processor.sanitize_table("123table") == "o_123table"
assert processor.sanitize_table("") == "o_"
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.graphql_types = {}
processor.graphql_schema = None
processor.config_key = "schema" # Set the config key
processor.generate_graphql_schema = AsyncMock()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
schema_config = {
"schema": {
"customer": json.dumps({
"name": "customer",
"description": "Customer table",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True,
"description": "Customer ID"
},
{
"name": "email",
"type": "string",
"indexed": True,
"required": True
},
{
"name": "status",
"type": "string",
"enum": ["active", "inactive"]
}
]
})
}
}
# Process config
await processor.on_schema_config(schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
# Verify fields
id_field = next(f for f in schema.fields if f.name == "id")
assert id_field.primary is True
# The field should have been created correctly from JSON
# Let's test what we can verify - that the field has the right attributes
assert hasattr(id_field, 'required') # Has the required attribute
assert hasattr(id_field, 'primary') # Has the primary attribute
email_field = next(f for f in schema.fields if f.name == "email")
assert email_field.indexed is True
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify GraphQL schema regeneration was called
processor.generate_graphql_schema.assert_called_once()
def test_cql_query_building_basic(self):
"""Test basic CQL query construction"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to capture the query
mock_result = []
processor.session.execute.return_value = mock_result
# Create test schema
schema = RowSchema(
name="test_table",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string", indexed=True),
Field(name="status", type="string")
]
)
# Test query building
asyncio = pytest.importorskip("asyncio")
async def run_test():
await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="test_table",
row_schema=schema,
filters={"name": "John", "invalid_filter": "ignored"},
limit=10
)
# Run the async test
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_test())
finally:
loop.close()
# Verify Cassandra connection and query execution
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify the query structure (can't easily test exact query without complex mocking)
call_args = processor.session.execute.call_args
query = call_args[0][0] # First positional argument is the query
params = call_args[0][1] # Second positional argument is parameters
# Basic query structure checks
assert "SELECT * FROM test_user.o_test_table" in query
assert "WHERE" in query
assert "collection = %s" in query
assert "LIMIT 10" in query
# Parameters should include collection and name filter
assert "test_collection" in params
assert "John" in params
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
"""Test GraphQL execution context setup"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { id name } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value'] # keyword argument
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["collection"] == "test_collection"
# Verify result structure
assert "data" in result
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
@pytest.mark.asyncio
async def test_error_handling_graphql_errors(self):
"""Test GraphQL error handling and conversion"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error instead of MagicMock
class MockError:
def __init__(self, message, path, extensions):
self.message = message
self.path = path
self.extensions = extensions
def __str__(self):
return self.message
mock_error = MockError(
message="Field 'invalid_field' doesn't exist",
path=["customers", "0", "invalid_field"],
extensions={"code": "FIELD_NOT_FOUND"}
)
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
# Verify error handling
assert "errors" in result
assert len(result["errors"]) == 1
error = result["errors"][0]
assert error["message"] == "Field 'invalid_field' doesn't exist"
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
def test_schema_generation_basic_structure(self):
"""Test basic GraphQL schema generation structure"""
processor = MagicMock()
processor.schemas = {
"customer": RowSchema(
name="customer",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
processor.graphql_types = {}
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Test individual type creation (avoiding the full schema generation which has annotation issues)
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
processor.graphql_types["customer"] = graphql_type
# Verify type was created
assert len(processor.graphql_types) == 1
assert "customer" in processor.graphql_types
assert processor.graphql_types["customer"] is not None
@pytest.mark.asyncio
async def test_message_processing_success(self):
"""Test successful message processing flow"""
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock successful query result
processor.execute_graphql_query.return_value = {
"data": {"customers": [{"id": "1", "name": "John"}]},
"errors": [],
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
}
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name } }',
variables={},
operation_name=None
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-123"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
query='{ customers { id name } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
# Verify response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert response_call.error is None
assert '"customers"' in response_call.data # JSON encoded
assert len(response_call.errors) == 0
@pytest.mark.asyncio
async def test_message_processing_error(self):
"""Test error handling during message processing"""
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock query execution error
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
variables={},
operation_name=None
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-456"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify error response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify error response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert response_call.error is not None
assert response_call.error.type == "objects-query-error"
assert "No schema available" in response_call.error.message
assert response_call.data is None
class TestCQLQueryGeneration:
"""Test CQL query generation logic in isolation"""
def test_partition_key_inclusion(self):
"""Test that collection is always included in queries"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Mock the query building (simplified version)
keyspace = processor.sanitize_name("test_user")
table = processor.sanitize_table("test_table")
query = f"SELECT * FROM {keyspace}.{table}"
where_clauses = ["collection = %s"]
assert "collection = %s" in where_clauses
assert keyspace == "test_user"
assert table == "o_test_table"
def test_indexed_field_filtering(self):
"""Test that only indexed or primary key fields can be filtered"""
# Create schema with mixed field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", primary=True),
Field(name="indexed_field", type="string", indexed=True),
Field(name="normal_field", type="string", indexed=False),
Field(name="another_field", type="string")
]
)
filters = {
"id": "test123", # Primary key - should be included
"indexed_field": "value", # Indexed - should be included
"normal_field": "ignored", # Not indexed - should be ignored
"another_field": "also_ignored" # Not indexed - should be ignored
}
# Simulate the filtering logic from the processor
valid_filters = []
for field_name, value in filters.items():
if value is not None:
schema_field = next((f for f in schema.fields if f.name == field_name), None)
if schema_field and (schema_field.indexed or schema_field.primary):
valid_filters.append((field_name, value))
# Only id and indexed_field should be included
assert len(valid_filters) == 2
field_names = [f[0] for f in valid_filters]
assert "id" in field_names
assert "indexed_field" in field_names
assert "normal_field" not in field_names
assert "another_field" not in field_names
class TestGraphQLSchemaGeneration:
"""Test GraphQL schema generation in detail"""
def test_field_type_annotations(self):
"""Test that GraphQL types have correct field annotations"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create schema with various field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", required=True, primary=True),
Field(name="count", type="integer", required=True),
Field(name="price", type="float", required=False),
Field(name="active", type="boolean", required=False),
Field(name="optional_text", type="string", required=False)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test", schema)
# Verify type was created successfully
assert graphql_type is not None
def test_basic_type_creation(self):
"""Test that GraphQL types are created correctly"""
processor = MagicMock()
processor.schemas = {
"customer": RowSchema(
name="customer",
fields=[Field(name="id", type="string", primary=True)]
)
}
processor.graphql_types = {}
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create GraphQL type directly
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
processor.graphql_types["customer"] = graphql_type
# Verify customer type was created
assert "customer" in processor.graphql_types
assert processor.graphql_types["customer"] is not None

View file

@ -70,7 +70,7 @@ class TestCassandraQueryProcessor:
assert result.is_uri is False
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_spo_query(self, mock_trustgraph):
"""Test querying triples with subject, predicate, and object specified"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -83,7 +83,7 @@ class TestCassandraQueryProcessor:
processor = Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
graph_host='localhost'
cassandra_host='localhost'
)
# Create query request with all SPO values
@ -98,16 +98,15 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
# Verify TrustGraph was created with correct parameters
# Verify KnowledgeGraph was created with correct parameters
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
keyspace='test_user',
table='test_collection'
keyspace='test_user'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
'test_subject', 'test_predicate', 'test_object', limit=100
'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100
)
# Verify result contains the queried triple
@ -122,9 +121,9 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=taskgroup_mock)
assert processor.graph_host == ['localhost']
assert processor.username is None
assert processor.password is None
assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None
assert processor.cassandra_password is None
assert processor.table is None
def test_processor_initialization_with_custom_params(self):
@ -133,18 +132,18 @@ class TestCassandraQueryProcessor:
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='cassandra.example.com',
graph_username='queryuser',
graph_password='querypass'
cassandra_host='cassandra.example.com',
cassandra_username='queryuser',
cassandra_password='querypass'
)
assert processor.graph_host == ['cassandra.example.com']
assert processor.username == 'queryuser'
assert processor.password == 'querypass'
assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'queryuser'
assert processor.cassandra_password == 'querypass'
assert processor.table is None
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_sp_pattern(self, mock_trustgraph):
"""Test SP query pattern (subject and predicate, no object)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -170,14 +169,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_s_pattern(self, mock_trustgraph):
"""Test S query pattern (subject only)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -203,14 +202,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_p_pattern(self, mock_trustgraph):
"""Test P query pattern (predicate only)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -236,14 +235,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_o_pattern(self, mock_trustgraph):
"""Test O query pattern (object only)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -269,14 +268,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
"""Test query pattern with no constraints (get all)"""
from trustgraph.schema import TriplesQueryRequest
@ -303,7 +302,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_all.assert_called_once_with(limit=1000)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.value == 'all_subject'
assert result[0].p.value == 'all_predicate'
@ -325,12 +324,12 @@ class TestCassandraQueryProcessor:
# Verify our specific arguments were added
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'localhost'
assert hasattr(args, 'graph_username')
assert args.graph_username is None
assert hasattr(args, 'graph_password')
assert args.graph_password is None
assert hasattr(args, 'cassandra_host')
assert args.cassandra_host == 'cassandra' # Updated to new parameter name and default
assert hasattr(args, 'cassandra_username')
assert args.cassandra_username is None
assert hasattr(args, 'cassandra_password')
assert args.cassandra_password is None
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
@ -341,16 +340,16 @@ class TestCassandraQueryProcessor:
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
# Test parsing with custom values (new cassandra_* arguments)
args = parser.parse_args([
'--graph-host', 'query.cassandra.com',
'--graph-username', 'queryuser',
'--graph-password', 'querypass'
'--cassandra-host', 'query.cassandra.com',
'--cassandra-username', 'queryuser',
'--cassandra-password', 'querypass'
])
assert args.graph_host == 'query.cassandra.com'
assert args.graph_username == 'queryuser'
assert args.graph_password == 'querypass'
assert args.cassandra_host == 'query.cassandra.com'
assert args.cassandra_username == 'queryuser'
assert args.cassandra_password == 'querypass'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
@ -361,10 +360,10 @@ class TestCassandraQueryProcessor:
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'short.query.com'])
# Test parsing with cassandra arguments (no short form)
args = parser.parse_args(['--cassandra-host', 'short.query.com'])
assert args.graph_host == 'short.query.com'
assert args.cassandra_host == 'short.query.com'
@patch('trustgraph.query.triples.cassandra.service.Processor.launch')
def test_run_function(self, mock_launch):
@ -376,7 +375,7 @@ class TestCassandraQueryProcessor:
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_with_authentication(self, mock_trustgraph):
"""Test querying with username and password authentication"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -387,8 +386,8 @@ class TestCassandraQueryProcessor:
processor = Processor(
taskgroup=MagicMock(),
graph_username='authuser',
graph_password='authpass'
cassandra_username='authuser',
cassandra_password='authpass'
)
query = TriplesQueryRequest(
@ -402,17 +401,16 @@ class TestCassandraQueryProcessor:
await processor.query_triples(query)
# Verify TrustGraph was created with authentication
# Verify KnowledgeGraph was created with authentication
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
hosts=['cassandra'], # Updated default
keyspace='test_user',
table='test_collection',
username='authuser',
password='authpass'
)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_reuse(self, mock_trustgraph):
"""Test that TrustGraph is reused for same table"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -441,7 +439,7 @@ class TestCassandraQueryProcessor:
assert mock_trustgraph.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_switching(self, mock_trustgraph):
"""Test table switching creates new TrustGraph"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -463,7 +461,7 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples(query1)
assert processor.table == ('user1', 'collection1')
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
@ -476,13 +474,13 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples(query2)
assert processor.table == ('user2', 'collection2')
assert processor.table == 'user2'
# Verify TrustGraph was created twice
assert mock_trustgraph.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_exception_handling(self, mock_trustgraph):
"""Test exception handling during query execution"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -506,7 +504,7 @@ class TestCassandraQueryProcessor:
await processor.query_triples(query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_multiple_results(self, mock_trustgraph):
"""Test query returning multiple results"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -536,4 +534,203 @@ class TestCassandraQueryProcessor:
assert len(result) == 2
assert result[0].o.value == 'object1'
assert result[1].o.value == 'object2'
assert result[1].o.value == 'object2'
class TestCassandraQueryPerformanceOptimizations:
"""Test cases for multi-table performance optimizations in query service"""
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_po_query_optimization(self, mock_trustgraph):
"""Test that get_po queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_tg_instance.get_po.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
# PO query pattern (predicate + object, find subjects)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
limit=50
)
result = await processor.query_triples(query)
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', limit=50
)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_os_query_optimization(self, mock_trustgraph):
"""Test that get_os queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_tg_instance.get_os.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
# OS query pattern (object + subject, find predicates)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=None,
o=Value(value='test_object', is_uri=False),
limit=25
)
result = await processor.query_triples(query)
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', limit=25
)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph):
"""Test that all query patterns route to their optimal tables"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
# Mock empty results for all queries
mock_tg_instance.get_all.return_value = []
mock_tg_instance.get_s.return_value = []
mock_tg_instance.get_p.return_value = []
mock_tg_instance.get_o.return_value = []
mock_tg_instance.get_sp.return_value = []
mock_tg_instance.get_po.return_value = []
mock_tg_instance.get_os.return_value = []
mock_tg_instance.get_spo.return_value = []
processor = Processor(taskgroup=MagicMock())
# Test each query pattern
test_patterns = [
# (s, p, o, expected_method)
(None, None, None, 'get_all'), # All triples
('s1', None, None, 'get_s'), # Subject only
(None, 'p1', None, 'get_p'), # Predicate only
(None, None, 'o1', 'get_o'), # Object only
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'get_os'), # Object + Subject
('s1', 'p1', 'o1', 'get_spo'), # All three
]
for s, p, o, expected_method in test_patterns:
# Reset mock call counts
mock_tg_instance.reset_mock()
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value=s, is_uri=False) if s else None,
p=Value(value=p, is_uri=False) if p else None,
o=Value(value=o, is_uri=False) if o else None,
limit=10
)
await processor.query_triples(query)
# Verify the correct method was called
method = getattr(mock_tg_instance, expected_method)
assert method.called, f"Expected {expected_method} to be called for pattern s={s}, p={p}, o={o}"
def test_legacy_vs_optimized_mode_configuration(self):
"""Test that environment variable controls query optimization mode"""
taskgroup_mock = MagicMock()
# Test optimized mode (default)
with patch.dict('os.environ', {}, clear=True):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
# Test legacy mode
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
# Test explicit optimized mode
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
processor = Processor(taskgroup=taskgroup_mock)
# Mode is determined in KnowledgeGraph initialization
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph):
"""Test the performance-critical PO query that eliminates ALLOW FILTERING"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
# Mock multiple subjects for the same predicate-object pair
mock_results = []
for i in range(5):
mock_result = MagicMock()
mock_result.s = f'subject_{i}'
mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results
processor = Processor(taskgroup=MagicMock())
# This is the query pattern that was slow with ALLOW FILTERING
query = TriplesQueryRequest(
user='large_dataset_user',
collection='massive_collection',
s=None,
p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True),
o=Value(value='http://example.com/Person', is_uri=True),
limit=1000
)
result = await processor.query_triples(query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with(
'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person',
limit=1000
)
# Verify all results were returned
assert len(result) == 5
for i, triple in enumerate(result):
assert triple.s.value == f'subject_{i}'
assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.is_uri is True
assert triple.o.value == 'http://example.com/Person'
assert triple.o.is_uri is True