Increase storage test coverage (#435)

* Fixing storage and adding tests

* PR pipeline only runs quick tests
This commit is contained in:
cybermaggedon 2025-07-15 09:33:35 +01:00 committed by GitHub
parent 4daa54abaf
commit f37decea2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 7606 additions and 754 deletions

View file

@ -0,0 +1,387 @@
"""
Tests for Milvus document embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.doc_embeddings.milvus.write import Processor
from trustgraph.schema import ChunkEmbeddings
class TestMilvusDocEmbeddingsStorageProcessor:
"""Test cases for Milvus document embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [chunk1, chunk2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') as mock_doc_vectors:
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-de-storage',
store_uri='http://localhost:19530'
)
return processor
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_doc_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
@pytest.mark.asyncio
async def test_store_document_embeddings_single_chunk(self, processor):
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called for each vector
expected_calls = [
([0.1, 0.2, 0.3], "Test document content"),
([0.4, 0.5, 0.6], "Test document content"),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message)
# Verify insert was called for each vector of each chunk
expected_calls = [
# Chunk 1 vectors
([0.1, 0.2, 0.3], "This is the first document chunk"),
([0.4, 0.5, 0.6], "This is the first document chunk"),
# Chunk 2 vectors
([0.7, 0.8, 0.9], "This is the second document chunk"),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for None chunk
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
"""Test storing document embeddings with mix of valid and invalid chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings(
chunk=b"Valid document content",
vectors=[[0.1, 0.2, 0.3]]
)
empty_chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.4, 0.5, 0.6]]
)
none_chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [valid_chunk, empty_chunk, none_chunk]
await processor.store_document_embeddings(message)
# Verify only valid chunk was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Valid document content"
)
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
await processor.store_document_embeddings(message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
([0.1, 0.2], "Document with mixed dimensions"),
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"),
([0.7, 0.8, 0.9], "Document with mixed dimensions"),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀"
)
@pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor):
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify large content was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], large_content
)
@pytest.mark.asyncio
async def test_store_document_embeddings_whitespace_only_chunk(self, processor):
"""Test storing document embeddings with whitespace-only chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b" \n\t ",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify whitespace content was inserted (not filtered out)
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], " \n\t "
)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.storage.doc_embeddings.milvus.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.doc_embeddings.milvus.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
)

View file

@ -0,0 +1,536 @@
"""
Tests for Pinecone document embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
import uuid
from trustgraph.storage.doc_embeddings.pinecone.write import Processor
from trustgraph.schema import ChunkEmbeddings
class TestPineconeDocEmbeddingsStorageProcessor:
"""Test cases for Pinecone document embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [chunk1, chunk2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
processor = Processor(
taskgroup=MagicMock(),
id='test-pinecone-de-storage',
api_key='test-api-key'
)
return processor
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
"""Test processor initialization with default parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
assert processor.pinecone == mock_pinecone
assert processor.api_key == 'env-api-key'
assert processor.cloud == 'aws'
assert processor.region == 'us-east-1'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
"""Test processor initialization with custom parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='custom-api-key',
cloud='gcp',
region='us-west1'
)
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
assert processor.api_key == 'custom-api-key'
assert processor.cloud == 'gcp'
assert processor.region == 'us-west1'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
"""Test processor initialization with custom URL (GRPC mode)"""
mock_pinecone = MagicMock()
mock_pinecone_grpc_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='test-api-key',
url='https://custom-host.pinecone.io'
)
mock_pinecone_grpc_class.assert_called_once_with(
api_key='test-api-key',
host='https://custom-host.pinecone.io'
)
assert processor.pinecone == mock_pinecone
assert processor.url == 'https://custom-host.pinecone.io'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
"""Test processor initialization fails with missing API key"""
taskgroup_mock = MagicMock()
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
Processor(taskgroup=taskgroup_mock)
@pytest.mark.asyncio
async def test_store_document_embeddings_single_chunk(self, processor):
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.chunks = [chunk]
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message)
# Verify index name and operations
expected_index_name = "d-test_user-test_collection-3"
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
assert mock_index.upsert.call_count == 2
# Check first vector upsert
first_call = mock_index.upsert.call_args_list[0]
first_vectors = first_call[1]['vectors']
assert len(first_vectors) == 1
assert first_vectors[0]['id'] == 'id1'
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
assert first_vectors[0]['metadata']['doc'] == "Test document content"
# Check second vector upsert
second_call = mock_index.upsert.call_args_list[1]
second_vectors = second_call[1]['vectors']
assert len(second_vectors) == 1
assert second_vectors[0]['id'] == 'id2'
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
assert second_vectors[0]['metadata']['doc'] == "Test document content"
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(mock_message)
# Verify upsert was called for each vector (3 total)
assert mock_index.upsert.call_count == 3
# Verify document content in metadata
calls = mock_index.upsert.call_args_list
assert calls[0][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
assert calls[1][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist initially
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify index creation was called
expected_index_name = "d-test_user-test_collection-3"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called for empty chunk
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called for None chunk
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_decoded_chunk(self, processor):
"""Test storing document embeddings with chunk that decodes to empty string"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"", # Empty bytes
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called for empty decoded chunk
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.chunks = [chunk]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
elif name.endswith("-3"):
return mock_index_3d
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(message)
# Verify different indexes were used for different dimensions
assert processor.pinecone.Index.call_count == 3
mock_index_2d.upsert.assert_called_once()
mock_index_4d.upsert.assert_called_once()
mock_index_3d.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and creation fails
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
await processor.store_document_embeddings(message)
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and never becomes ready
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_document_embeddings(message)
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and stored
call_args = mock_index.upsert.call_args
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
assert stored_doc == "Document with Unicode: éñ中文🚀"
@pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor):
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify large content was stored
call_args = mock_index.upsert.call_args
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
assert stored_doc == large_content
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
args = parser.parse_args([])
assert hasattr(args, 'api_key')
assert args.api_key == 'not-specified' # Default value when no env var
assert hasattr(args, 'url')
assert args.url is None
assert hasattr(args, 'cloud')
assert args.cloud == 'aws'
assert hasattr(args, 'region')
assert args.region == 'us-east-1'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--api-key', 'custom-api-key',
'--url', 'https://custom-host.pinecone.io',
'--cloud', 'gcp',
'--region', 'us-west1'
])
assert args.api_key == 'custom-api-key'
assert args.url == 'https://custom-host.pinecone.io'
assert args.cloud == 'gcp'
assert args.region == 'us-west1'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args([
'-a', 'short-api-key',
'-u', 'https://short-host.pinecone.io'
])
assert args.api_key == 'short-api-key'
assert args.url == 'https://short-host.pinecone.io'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.doc_embeddings.pinecone.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts document chunks/vector pairs and writes them to a Pinecone store.\n"
)

View file

@ -0,0 +1,354 @@
"""
Tests for Milvus graph embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.graph_embeddings.milvus.write import Processor
from trustgraph.schema import Value, EntityEmbeddings
class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test cases for Milvus graph embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Value(value='http://example.com/entity1', is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value='literal entity', is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') as mock_entity_vectors:
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-ge-storage',
store_uri='http://localhost:19530'
)
return processor
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
@pytest.mark.asyncio
async def test_store_graph_embeddings_single_entity(self, processor):
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify insert was called for each vector
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/entity'),
([0.4, 0.5, 0.6], 'http://example.com/entity'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message)
# Verify insert was called for each vector of each entity
expected_calls = [
# Entity 1 vectors
([0.1, 0.2, 0.3], 'http://example.com/entity1'),
([0.4, 0.5, 0.6], 'http://example.com/entity1'),
# Entity 2 vectors
([0.7, 0.8, 0.9], 'literal entity'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_none_entity_value(self, processor):
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_mixed_valid_invalid_entities(self, processor):
"""Test storing graph embeddings with mix of valid and invalid entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
entity=Value(value='http://example.com/valid', is_uri=True),
vectors=[[0.1, 0.2, 0.3]]
)
empty_entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
vectors=[[0.4, 0.5, 0.6]]
)
none_entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [valid_entity, empty_entity, none_entity]
await processor.store_graph_embeddings(message)
# Verify only valid entity was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], 'http://example.com/valid'
)
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entities_list(self, processor):
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
await processor.store_graph_embeddings(message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
([0.1, 0.2], 'http://example.com/entity'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
([0.7, 0.8, 0.9], 'http://example.com/entity'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_uri_and_literal_entities(self, processor):
"""Test storing graph embeddings for both URI and literal entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
entity=Value(value='http://example.com/uri_entity', is_uri=True),
vectors=[[0.1, 0.2, 0.3]]
)
literal_entity = EntityEmbeddings(
entity=Value(value='literal entity text', is_uri=False),
vectors=[[0.4, 0.5, 0.6]]
)
message.entities = [uri_entity, literal_entity]
await processor.store_graph_embeddings(message)
# Verify both entities were inserted
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/uri_entity'),
([0.4, 0.5, 0.6], 'literal entity text'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.storage.graph_embeddings.milvus.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.graph_embeddings.milvus.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
)

View file

@ -0,0 +1,460 @@
"""
Tests for Pinecone graph embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
import uuid
from trustgraph.storage.graph_embeddings.pinecone.write import Processor
from trustgraph.schema import EntityEmbeddings, Value
class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test cases for Pinecone graph embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings
entity1 = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
processor = Processor(
taskgroup=MagicMock(),
id='test-pinecone-ge-storage',
api_key='test-api-key'
)
return processor
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
@patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
"""Test processor initialization with default parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
assert processor.pinecone == mock_pinecone
assert processor.api_key == 'env-api-key'
assert processor.cloud == 'aws'
assert processor.region == 'us-east-1'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
"""Test processor initialization with custom parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='custom-api-key',
cloud='gcp',
region='us-west1'
)
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
assert processor.api_key == 'custom-api-key'
assert processor.cloud == 'gcp'
assert processor.region == 'us-west1'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
"""Test processor initialization with custom URL (GRPC mode)"""
mock_pinecone = MagicMock()
mock_pinecone_grpc_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='test-api-key',
url='https://custom-host.pinecone.io'
)
mock_pinecone_grpc_class.assert_called_once_with(
api_key='test-api-key',
host='https://custom-host.pinecone.io'
)
assert processor.pinecone == mock_pinecone
assert processor.url == 'https://custom-host.pinecone.io'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
"""Test processor initialization fails with missing API key"""
taskgroup_mock = MagicMock()
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
Processor(taskgroup=taskgroup_mock)
@pytest.mark.asyncio
async def test_store_graph_embeddings_single_entity(self, processor):
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_graph_embeddings(message)
# Verify index name and operations
expected_index_name = "t-test_user-test_collection-3"
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
assert mock_index.upsert.call_count == 2
# Check first vector upsert
first_call = mock_index.upsert.call_args_list[0]
first_vectors = first_call[1]['vectors']
assert len(first_vectors) == 1
assert first_vectors[0]['id'] == 'id1'
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
assert first_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
# Check second vector upsert
second_call = mock_index.upsert.call_args_list[1]
second_vectors = second_call[1]['vectors']
assert len(second_vectors) == 1
assert second_vectors[0]['id'] == 'id2'
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
assert second_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(mock_message)
# Verify upsert was called for each vector (3 total)
assert mock_index.upsert.call_count == 3
# Verify entity values in metadata
calls = mock_index.upsert.call_args_list
assert calls[0][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
assert calls[1][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist initially
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
# Verify index creation was called
expected_index_name = "t-test_user-test_collection-3"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no upsert was called for empty entity
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_none_entity_value(self, processor):
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no upsert was called for None entity
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.entities = [entity]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
elif name.endswith("-3"):
return mock_index_3d
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
# Verify different indexes were used for different dimensions
assert processor.pinecone.Index.call_count == 3
mock_index_2d.upsert.assert_called_once()
mock_index_4d.upsert.assert_called_once()
mock_index_3d.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entities_list(self, processor):
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[]
)
message.entities = [entity]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist and creation fails
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
await processor.store_graph_embeddings(message)
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist and never becomes ready
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_graph_embeddings(message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added by parsing empty args
args = parser.parse_args([])
assert hasattr(args, 'api_key')
assert args.api_key == 'not-specified' # Default value when no env var
assert hasattr(args, 'url')
assert args.url is None
assert hasattr(args, 'cloud')
assert args.cloud == 'aws'
assert hasattr(args, 'region')
assert args.region == 'us-east-1'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--api-key', 'custom-api-key',
'--url', 'https://custom-host.pinecone.io',
'--cloud', 'gcp',
'--region', 'us-west1'
])
assert args.api_key == 'custom-api-key'
assert args.url == 'https://custom-host.pinecone.io'
assert args.cloud == 'gcp'
assert args.region == 'us-west1'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args([
'-a', 'short-api-key',
'-u', 'https://short-host.pinecone.io'
])
assert args.api_key == 'short-api-key'
assert args.url == 'https://short-host.pinecone.io'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.graph_embeddings.pinecone.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Pinecone store.\n"
)

View file

@ -0,0 +1,436 @@
"""
Tests for FalkorDB triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.falkordb.write import Processor
from trustgraph.schema import Value, Triple
class TestFalkorDBStorageProcessor:
"""Test cases for FalkorDB storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
message.triples = [triple]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.triples.falkordb.write.FalkorDB') as mock_falkordb:
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
return Processor(
taskgroup=MagicMock(),
id='test-falkordb-storage',
graph_url='falkor://localhost:6379',
database='test_db'
)
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'falkordb'
mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
mock_client.select_graph.assert_called_once_with('falkordb')
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
def test_processor_initialization_with_custom_params(self, mock_falkordb):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(
taskgroup=taskgroup_mock,
graph_url='falkor://custom:6379',
database='custom_db'
)
assert processor.db == 'custom_db'
mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
mock_client.select_graph.assert_called_once_with('custom_db')
def test_create_node(self, processor):
"""Test node creation"""
test_uri = 'http://example.com/node'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
params={
"uri": test_uri,
},
)
def test_create_literal(self, processor):
"""Test literal creation"""
test_value = 'test literal value'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
params={
"value": test_value,
},
)
def test_relate_node(self, processor):
"""Test node-to-node relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
dest_uri = 'http://example.com/dest'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri)
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
params={
"src": src_uri,
"dest": dest_uri,
"uri": pred_uri,
},
)
def test_relate_literal(self, processor):
"""Test node-to-literal relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
literal_value = 'literal destination'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value)
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
params={
"src": src_uri,
"dest": literal_value,
"uri": pred_uri,
},
)
@pytest.mark.asyncio
async def test_store_triples_with_uri_object(self, processor):
"""Test storing triple with URI object"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
)
message.triples = [triple]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
# Create object node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}),
]
assert processor.io.query.call_count == 3
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.io.query.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, processor, mock_message):
"""Test storing triple with literal object"""
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(mock_message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
# Create literal object
(("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}),
]
assert processor.io.query.call_count == 3
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.io.query.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
"""Test storing multiple triples"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
# Verify first triple operations
first_triple_calls = processor.io.query.call_args_list[0:3]
assert first_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject1"
assert first_triple_calls[1][1]["params"]["value"] == "literal object1"
assert first_triple_calls[2][1]["params"]["src"] == "http://example.com/subject1"
# Verify second triple operations
second_triple_calls = processor.io.query.call_args_list[3:6]
assert second_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject2"
assert second_triple_calls[1][1]["params"]["uri"] == "http://example.com/object2"
assert second_triple_calls[2][1]["params"]["src"] == "http://example.com/subject2"
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
"""Test storing empty triples list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Verify no queries were made
processor.io.query.assert_not_called()
@pytest.mark.asyncio
async def test_store_triples_mixed_objects(self, processor):
"""Test storing triples with mixed URI and literal objects"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
# Verify first triple creates literal
assert "Literal" in processor.io.query.call_args_list[1][0][0]
assert processor.io.query.call_args_list[1][1]["params"]["value"] == "literal object"
# Verify second triple creates node
assert "Node" in processor.io.query.call_args_list[4][0][0]
assert processor.io.query.call_args_list[4][1]["params"]["uri"] == "http://example.com/object2"
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_url')
assert args.graph_url == 'falkor://falkordb:6379'
assert hasattr(args, 'database')
assert args.database == 'falkordb'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-url', 'falkor://custom:6379',
'--database', 'custom_db'
])
assert args.graph_url == 'falkor://custom:6379'
assert args.database == 'custom_db'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'falkor://short:6379'])
assert args.graph_url == 'falkor://short:6379'
@patch('trustgraph.storage.triples.falkordb.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.falkordb.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to FalkorDB graph.\n"
)
def test_create_node_with_special_characters(self, processor):
"""Test node creation with special characters in URI"""
test_uri = 'http://example.com/node with spaces & symbols'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
params={
"uri": test_uri,
},
)
def test_create_literal_with_special_characters(self, processor):
"""Test literal creation with special characters"""
test_value = 'literal with "quotes" and \n newlines'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
params={
"value": test_value,
},
)

View file

@ -0,0 +1,441 @@
"""
Tests for Memgraph triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
from trustgraph.schema import Value, Triple
class TestMemgraphStorageProcessor:
"""Test cases for Memgraph storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
message.triples = [triple]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') as mock_graph_db:
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
return Processor(
taskgroup=MagicMock(),
id='test-memgraph-storage',
graph_host='bolt://localhost:7687',
username='test_user',
password='test_pass',
database='test_db'
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'memgraph'
mock_graph_db.driver.assert_called_once_with(
'bolt://memgraph:7687',
auth=('memgraph', 'password')
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='custom_user',
password='custom_pass',
database='custom_db'
)
assert processor.db == 'custom_db'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('custom_user', 'custom_pass')
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_indexes_success(self, mock_graph_db):
"""Test successful index creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
# Verify index creation calls
expected_calls = [
"CREATE INDEX ON :Node",
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)"
]
assert mock_session.run.call_count == len(expected_calls)
for i, expected_call in enumerate(expected_calls):
actual_call = mock_session.run.call_args_list[i][0][0]
assert actual_call == expected_call
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_indexes_with_exceptions(self, mock_graph_db):
"""Test index creation with exceptions (should be ignored)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
# Make all index creation calls raise exceptions
mock_session.run.side_effect = Exception("Index already exists")
# Should not raise an exception
processor = Processor(taskgroup=taskgroup_mock)
# Verify all index creation calls were attempted
assert mock_session.run.call_count == 4
def test_create_node(self, processor):
"""Test node creation"""
test_uri = 'http://example.com/node'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.create_node(test_uri)
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
uri=test_uri,
database_=processor.db
)
def test_create_literal(self, processor):
"""Test literal creation"""
test_value = 'test literal value'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.create_literal(test_value)
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
value=test_value,
database_=processor.db
)
def test_relate_node(self, processor):
"""Test node-to-node relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
dest_uri = 'http://example.com/dest'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 5
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri)
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src_uri, dest=dest_uri, uri=pred_uri,
database_=processor.db
)
def test_relate_literal(self, processor):
"""Test node-to-literal relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
literal_value = 'literal destination'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 5
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value)
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src_uri, dest=literal_value, uri=pred_uri,
database_=processor.db
)
def test_create_triple_with_uri_object(self, processor):
"""Test triple creation with URI object"""
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
)
processor.create_triple(mock_tx, triple)
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
# Create object node
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}),
# Create relationship
("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'})
]
assert mock_tx.run.call_count == 3
for i, (expected_query, expected_params) in enumerate(expected_calls):
actual_call = mock_tx.run.call_args_list[i]
assert actual_call[0][0] == expected_query
assert actual_call[1] == expected_params
def test_create_triple_with_literal_object(self, processor):
"""Test triple creation with literal object"""
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
processor.create_triple(mock_tx, triple)
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
# Create literal object
("MERGE (n:Literal {value: $value})", {'value': 'literal object'}),
# Create relationship
("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'})
]
assert mock_tx.run.call_count == 3
for i, (expected_query, expected_params) in enumerate(expected_calls):
actual_call = mock_tx.run.call_args_list[i]
assert actual_call[0][0] == expected_query
assert actual_call[1] == expected_params
@pytest.mark.asyncio
async def test_store_triples_single_triple(self, processor, mock_message):
"""Test storing a single triple"""
mock_session = MagicMock()
processor.io.session.return_value.__enter__.return_value = mock_session
# Reset the mock to clear the initialization call
processor.io.session.reset_mock()
await processor.store_triples(mock_message)
# Verify session was created with correct database
processor.io.session.assert_called_once_with(database=processor.db)
# Verify execute_write was called once per triple
mock_session.execute_write.assert_called_once()
# Verify the triple was passed to create_triple
call_args = mock_session.execute_write.call_args
assert call_args[0][0] == processor.create_triple
assert call_args[0][1] == mock_message.triples[0]
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
"""Test storing multiple triples"""
mock_session = MagicMock()
processor.io.session.return_value.__enter__.return_value = mock_session
# Reset the mock to clear the initialization call
processor.io.session.reset_mock()
# Create message with multiple triples
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
await processor.store_triples(message)
# Verify session was called twice (once per triple)
assert processor.io.session.call_count == 2
# Verify execute_write was called once per triple
assert mock_session.execute_write.call_count == 2
# Verify each triple was processed
call_args_list = mock_session.execute_write.call_args_list
assert call_args_list[0][0][1] == triple1
assert call_args_list[1][0][1] == triple2
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
"""Test storing empty triples list"""
mock_session = MagicMock()
processor.io.session.return_value.__enter__.return_value = mock_session
# Reset the mock to clear the initialization call
processor.io.session.reset_mock()
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()
# Verify no execute_write calls were made
mock_session.execute_write.assert_not_called()
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
assert args.username == 'memgraph'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'memgraph'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'custom_user',
'--password', 'custom_pass',
'--database', 'custom_db'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'custom_user'
assert args.password == 'custom_pass'
assert args.database == 'custom_db'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.storage.triples.memgraph.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.memgraph.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to Memgraph.\n"
)

View file

@ -0,0 +1,548 @@
"""
Tests for Neo4j triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.neo4j.write import Processor
class TestNeo4jStorageProcessor:
"""Test cases for Neo4j storage processor"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'neo4j'
mock_graph_db.driver.assert_called_once_with(
'bolt://neo4j:7687',
auth=('neo4j', 'password')
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='testuser',
password='testpass',
database='testdb'
)
assert processor.db == 'testdb'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('testuser', 'testpass')
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_indexes_success(self, mock_graph_db):
"""Test successful index creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
# Verify index creation queries were executed
expected_calls = [
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)"
]
assert mock_session.run.call_count == 3
for expected_query in expected_calls:
mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_indexes_with_exceptions(self, mock_graph_db):
"""Test index creation with exceptions (should be ignored)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Make session.run raise exceptions
mock_session.run.side_effect = Exception("Index already exists")
# Should not raise exception - they should be caught and ignored
processor = Processor(taskgroup=taskgroup_mock)
# Should have tried to create all 3 indexes despite exceptions
assert mock_session.run.call_count == 3
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_node(self, mock_graph_db):
"""Test node creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test create_node
processor.create_node("http://example.com/node")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri})",
uri="http://example.com/node",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_literal(self, mock_graph_db):
"""Test literal creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test create_literal
processor.create_literal("literal value")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value})",
value="literal value",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_relate_node(self, mock_graph_db):
"""Test node-to-node relationship creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test relate_node
processor.relate_node(
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/object"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_relate_literal(self, mock_graph_db):
"""Test node-to-literal relationship creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test relate_literal
processor.relate_literal(
"http://example.com/subject",
"http://example.com/predicate",
"literal value"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src="http://example.com/subject",
dest="literal value",
uri="http://example.com/predicate",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_handle_triples_with_uri_object(self, mock_graph_db):
"""Test handling triples message with URI object"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triple with URI object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
# Create mock message
mock_message = MagicMock()
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri})",
{"uri": "http://example.com/subject", "database_": "neo4j"}
),
# Object node creation
(
"MERGE (n:Node {uri: $uri})",
{"uri": "http://example.com/object", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "http://example.com/object",
"uri": "http://example.com/predicate",
"database_": "neo4j"
}
)
]
assert mock_driver.execute_query.call_count == 3
for expected_query, expected_params in expected_calls:
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, mock_graph_db):
"""Test handling triples message with literal object"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triple with literal object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "literal value"
triple.o.is_uri = False
# Create mock message
mock_message = MagicMock()
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
# Verify relate_literal was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri})",
{"uri": "http://example.com/subject", "database_": "neo4j"}
),
# Literal creation
(
"MERGE (n:Literal {value: $value})",
{"value": "literal value", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "literal value",
"uri": "http://example.com/predicate",
"database_": "neo4j"
}
)
]
assert mock_driver.execute_query.call_count == 3
for expected_query, expected_params in expected_calls:
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_multiple_triples(self, mock_graph_db):
"""Test handling message with multiple triples"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triples
triple1 = MagicMock()
triple1.s.value = "http://example.com/subject1"
triple1.p.value = "http://example.com/predicate1"
triple1.o.value = "http://example.com/object1"
triple1.o.is_uri = True
triple2 = MagicMock()
triple2.s.value = "http://example.com/subject2"
triple2.p.value = "http://example.com/predicate2"
triple2.o.value = "literal value"
triple2.o.is_uri = False
# Create mock message
mock_message = MagicMock()
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
# Triple2: 1 node + 1 literal + 1 relationship = 3 calls
# Total: 6 calls
assert mock_driver.execute_query.call_count == 6
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_empty_triples(self, mock_graph_db):
"""Test handling message with no triples"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.triples = []
await processor.store_triples(mock_message)
# Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization
mock_driver.execute_query.assert_not_called()
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://neo4j:7687'
assert hasattr(args, 'username')
assert args.username == 'neo4j'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'neo4j'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph_host', 'bolt://custom:7687',
'--username', 'testuser',
'--password', 'testpass',
'--database', 'testdb'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'testuser'
assert args.password == 'testpass'
assert args.database == 'testdb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.storage.triples.neo4j.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.neo4j.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to Neo4j graph.\n"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_special_characters(self, mock_graph_db):
"""Test handling triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create triple with special characters
triple = MagicMock()
triple.s.value = "http://example.com/subject with spaces"
triple.p.value = "http://example.com/predicate:with/symbols"
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
triple.o.is_uri = False
mock_message = MagicMock()
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri})",
uri="http://example.com/subject with spaces",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value})",
value='literal with "quotes" and unicode: ñáéíóú',
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src="http://example.com/subject with spaces",
dest='literal with "quotes" and unicode: ñáéíóú',
uri="http://example.com/predicate:with/symbols",
database_="neo4j"
)