mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Test suite executed from CI pipeline (#433)
* Test strategy & test cases * Unit tests * Integration tests
This commit is contained in:
parent
9c7a070681
commit
2f7fddd206
101 changed files with 17811 additions and 1 deletions
148
tests/unit/test_query/conftest.py
Normal file
148
tests/unit/test_query/conftest.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
Shared fixtures for query tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_query_config():
|
||||
"""Base configuration for query processors"""
|
||||
return {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-query-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_query_config(base_query_config):
|
||||
"""Configuration for Qdrant query processors"""
|
||||
return base_query_config | {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Mock Qdrant client"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.query_points.return_value = []
|
||||
return mock_client
|
||||
|
||||
|
||||
# Graph embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_request():
|
||||
"""Mock graph embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_vectors():
|
||||
"""Mock graph embeddings request with multiple vectors"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_query_response():
|
||||
"""Mock graph embeddings query response from Qdrant"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_uri_response():
|
||||
"""Mock graph embeddings query response with URIs"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'http://example.com/entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'regular entity'}
|
||||
return [mock_point1, mock_point2, mock_point3]
|
||||
|
||||
|
||||
# Document embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_request():
|
||||
"""Mock document embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings request with multiple vectors"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_query_response():
|
||||
"""Mock document embeddings query response from Qdrant"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_utf8_response():
|
||||
"""Mock document embeddings query response with UTF-8 content"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_empty_query_response():
|
||||
"""Mock empty query response"""
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_large_query_response():
|
||||
"""Mock large query response with many results"""
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
return mock_points
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mixed_dimension_vectors():
|
||||
"""Mock request with vectors of different dimensions"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
return mock_message
|
||||
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,542 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.doc_embeddings.qdrant.service
|
||||
Testing document embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-query-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-query-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with single vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 'd_test_user_test_collection_3'
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
limit=5, # Direct limit, no multiplication
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected documents
|
||||
assert len(result) == 2
|
||||
# Results should be strings (document chunks)
|
||||
assert isinstance(result[0], str)
|
||||
assert isinstance(result[1], str)
|
||||
# Verify content
|
||||
assert result[0] == 'first document chunk'
|
||||
assert result[1] == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from vector 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from vector 2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'doc': 'another document from vector 2'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 'd_multi_user_multi_collection_2'
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify results from both vectors are combined
|
||||
assert len(result) == 3
|
||||
assert 'document from vector 1' in result
|
||||
assert 'document from vector 2' in result
|
||||
assert 'another document from vector 2' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with many results
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = mock_points
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with exact limit (no multiplication)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 3 # Direct limit
|
||||
|
||||
# Verify result contains all returned documents (limit applied by Qdrant)
|
||||
assert len(result) == 10 # All results returned by mock
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with empty results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock empty query response
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from 2D vector'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from 3D vector'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
assert 'document from 2D vector' in result
|
||||
assert 'document from 3D vector' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_utf8_encoding(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with UTF-8 content"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with UTF-8 content
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'utf8_user'
|
||||
mock_message.collection = 'utf8_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify UTF-8 content works correctly
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in result
|
||||
assert 'Chinese text: 你好世界' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings handles Qdrant errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock Qdrant error
|
||||
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': 'document chunk'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0
|
||||
|
||||
# Result should contain all returned documents
|
||||
assert len(result) == 1
|
||||
assert result[0] == 'document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_large_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with large limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with fewer results than limit
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document 2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with large limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 1000 # Large limit
|
||||
mock_message.user = 'large_user'
|
||||
mock_message.collection = 'large_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should query with full limit
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 1000
|
||||
|
||||
# Result should contain all available documents
|
||||
assert len(result) == 2
|
||||
assert 'document 1' in result
|
||||
assert 'document 2' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_missing_payload(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with missing payload data"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with missing 'doc' key
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'valid document'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {} # Missing 'doc' key
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'payload_user'
|
||||
mock_message.collection = 'payload_collection'
|
||||
|
||||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['doc']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,537 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.graph_embeddings.qdrant.service
|
||||
Testing graph embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant graph embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-graph-query-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-graph-query-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_http_uri(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with HTTP URI"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('http://example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'http://example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_https_uri(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('https://secure.example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_regular_string(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with regular string (non-URI)"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('regular entity name')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'regular entity name'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == False
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with single vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 't_test_user_test_collection_3'
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
limit=10, # limit * 2 for deduplication
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected entities
|
||||
assert len(result) == 2
|
||||
assert all(hasattr(entity, 'value') for entity in result)
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'entity3'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1, mock_point2]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 't_multi_user_multi_collection_2'
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with more results than limit
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'entity': f'entity{i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = mock_points
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with limit * 2
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 6 # 3 * 2
|
||||
|
||||
# Verify result is limited to requested limit
|
||||
assert len(result) == 3
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with empty results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock empty query response
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with different vector dimensions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity2d'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity3d'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2'
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3'
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert 'entity2d' in entity_values
|
||||
assert 'entity3d' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_uri_detection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with URI detection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with URIs and regular strings
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'http://example.com/entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'regular entity'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'uri_user'
|
||||
mock_message.collection = 'uri_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.value for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings handles Qdrant errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock Qdrant error
|
||||
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response - even with zero limit, Qdrant might return results
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'entity': 'entity1'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0 # 0 * 2 = 0
|
||||
|
||||
# With zero limit, the logic still adds one entity before checking the limit
|
||||
# So it returns one result (current behavior, not ideal but actual)
|
||||
assert len(result) == 1
|
||||
assert result[0].value == 'entity1'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
|
|
@ -0,0 +1,539 @@
|
|||
"""
|
||||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
|
||||
|
||||
class TestCassandraQueryProcessor:
|
||||
"""Test cases for Cassandra query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_spo_query(self, mock_trustgraph):
|
||||
"""Test querying triples with subject, predicate, and object specified"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
)
|
||||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify TrustGraph was created with correct parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user',
|
||||
table='test_collection'
|
||||
)
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
'test_subject', 'test_predicate', 'test_object', limit=100
|
||||
)
|
||||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.graph_host == ['localhost']
|
||||
assert processor.username is None
|
||||
assert processor.password is None
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_host='cassandra.example.com',
|
||||
graph_username='queryuser',
|
||||
graph_password='querypass'
|
||||
)
|
||||
|
||||
assert processor.graph_host == ['cassandra.example.com']
|
||||
assert processor.username == 'queryuser'
|
||||
assert processor.password == 'querypass'
|
||||
assert processor.table is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_trustgraph):
|
||||
"""Test SP query pattern (subject and predicate, no object)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph and response
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_trustgraph):
|
||||
"""Test S query pattern (subject only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_trustgraph):
|
||||
"""Test P query pattern (predicate only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_trustgraph):
|
||||
"""Test O query pattern (object only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=75
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
|
||||
"""Test query pattern with no constraints (get all)"""
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with(limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
assert result[0].p.value == 'all_predicate'
|
||||
assert result[0].o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once_with(parser)
|
||||
|
||||
# Verify our specific arguments were added
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'localhost'
|
||||
assert hasattr(args, 'graph_username')
|
||||
assert args.graph_username is None
|
||||
assert hasattr(args, 'graph_password')
|
||||
assert args.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'query.cassandra.com',
|
||||
'--graph-username', 'queryuser',
|
||||
'--graph-password', 'querypass'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'query.cassandra.com'
|
||||
assert args.graph_username == 'queryuser'
|
||||
assert args.graph_password == 'querypass'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'short.query.com'])
|
||||
|
||||
assert args.graph_host == 'short.query.com'
|
||||
|
||||
@patch('trustgraph.query.triples.cassandra.service.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.query.triples.cassandra.service import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_trustgraph):
|
||||
"""Test querying with username and password authentication"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
graph_username='authuser',
|
||||
graph_password='authpass'
|
||||
)
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
# Verify TrustGraph was created with authentication
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user',
|
||||
table='test_collection',
|
||||
username='authuser',
|
||||
password='authpass'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_trustgraph):
|
||||
"""Test that TrustGraph is reused for same table"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_table_switching(self, mock_trustgraph):
|
||||
"""Test table switching creates new TrustGraph"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query1)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query2)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_trustgraph):
|
||||
"""Test exception handling during query execution"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_trustgraph):
|
||||
"""Test query returning multiple results"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple results
|
||||
mock_result1 = MagicMock()
|
||||
mock_result1.o = 'object1'
|
||||
mock_result2 = MagicMock()
|
||||
mock_result2.o = 'object2'
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
assert result[1].o.value == 'object2'
|
||||
Loading…
Add table
Add a link
Reference in a new issue