diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index feb4e52f..63732269 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -48,8 +48,8 @@ jobs: - name: Unit tests run: pytest tests/unit - - name: Integration tests - run: pytest tests/integration + - name: Integration tests (cut out the long-running tests) + run: pytest tests/integration -m 'not slow' - name: Contract tests run: pytest tests/contract diff --git a/tests/integration/cassandra_test_helper.py b/tests/integration/cassandra_test_helper.py new file mode 100644 index 00000000..17cc6df6 --- /dev/null +++ b/tests/integration/cassandra_test_helper.py @@ -0,0 +1,112 @@ +""" +Helper for managing Cassandra containers in integration tests +Alternative to testcontainers for Fedora/Podman compatibility +""" + +import subprocess +import time +import socket +from contextlib import contextmanager +from cassandra.cluster import Cluster +from cassandra.policies import RetryPolicy + + +class CassandraTestContainer: + """Simple Cassandra container manager using Podman""" + + def __init__(self, image="docker.io/library/cassandra:4.1", port=9042): + self.image = image + self.port = port + self.container_name = f"test-cassandra-{int(time.time())}" + self.container_id = None + + def start(self): + """Start Cassandra container""" + # Remove any existing container with same name + subprocess.run([ + "podman", "rm", "-f", self.container_name + ], capture_output=True) + + # Start new container with faster startup options + result = subprocess.run([ + "podman", "run", "-d", + "--name", self.container_name, + "-p", f"{self.port}:9042", + "-e", "JVM_OPTS=-Dcassandra.skip_wait_for_gossip_to_settle=0", + self.image + ], capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Failed to start container: {result.stderr}") + + self.container_id = result.stdout.strip() + + # Wait for Cassandra to be ready + 
self._wait_for_ready() + return self + + def stop(self): + """Stop and remove container""" + import time + if self.container_name: + # Small delay before stopping to ensure connections are closed + time.sleep(0.5) + subprocess.run([ + "podman", "rm", "-f", self.container_name + ], capture_output=True) + + def get_connection_host_port(self): + """Get host and port for connection""" + return "localhost", self.port + + def _wait_for_ready(self, timeout=120): + """Wait for Cassandra to be ready for CQL queries""" + start_time = time.time() + + print(f"Waiting for Cassandra to be ready on port {self.port}...") + + while time.time() - start_time < timeout: + try: + # First check if port is open + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex(("localhost", self.port)) + sock.close() + + if result == 0: + # Port is open, now try to connect with Cassandra driver + try: + cluster = Cluster(['localhost'], port=self.port) + cluster.connect_timeout = 5 + session = cluster.connect() + + # Try a simple query to verify Cassandra is ready + session.execute("SELECT release_version FROM system.local") + session.shutdown() + cluster.shutdown() + + print("Cassandra is ready!") + return + + except Exception as e: + print(f"Cassandra not ready yet: {e}") + pass + + except Exception as e: + print(f"Connection check failed: {e}") + pass + + time.sleep(3) + + raise RuntimeError(f"Cassandra not ready after {timeout} seconds") + + +@contextmanager +def cassandra_container(image="docker.io/library/cassandra:4.1", port=9042): + """Context manager for Cassandra container""" + container = CassandraTestContainer(image, port) + try: + container.start() + yield container + finally: + container.stop() diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 61b9b1a8..0f47077c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -383,4 +383,22 @@ def sample_kg_triples(): # Test markers for 
integration tests -pytestmark = pytest.mark.integration \ No newline at end of file +pytestmark = pytest.mark.integration + + +def pytest_sessionfinish(session, exitstatus): + """ + Called after whole test run finished, right before returning the exit status. + + This hook is used to ensure Cassandra driver threads have time to shut down + properly before pytest exits, preventing "cannot schedule new futures after + shutdown" errors. + """ + import time + import gc + + # Force garbage collection to clean up any remaining objects + gc.collect() + + # Give Cassandra driver threads more time to clean up + time.sleep(2) \ No newline at end of file diff --git a/tests/integration/test_cassandra_integration.py b/tests/integration/test_cassandra_integration.py new file mode 100644 index 00000000..ce9d7fd3 --- /dev/null +++ b/tests/integration/test_cassandra_integration.py @@ -0,0 +1,411 @@ +""" +Cassandra integration tests using Podman containers + +These tests verify end-to-end functionality of Cassandra storage and query processors +with real database instances. Compatible with Fedora Linux and Podman. + +Uses a single container for all tests to minimize startup time. 
+""" + +import pytest +import asyncio +import time +from unittest.mock import MagicMock + +from .cassandra_test_helper import cassandra_container +from trustgraph.direct.cassandra import TrustGraph +from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor +from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor +from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest + + +@pytest.mark.integration +@pytest.mark.slow +class TestCassandraIntegration: + """Integration tests for Cassandra using a single shared container""" + + @pytest.fixture(scope="class") + def cassandra_shared_container(self): + """Class-level fixture: single Cassandra container for all tests""" + with cassandra_container() as container: + yield container + + def setup_method(self): + """Track all created clients for cleanup""" + self.clients_to_close = [] + + def teardown_method(self): + """Clean up all Cassandra connections""" + import gc + + for client in self.clients_to_close: + try: + client.close() + except Exception: + pass # Ignore errors during cleanup + + # Clear the list and force garbage collection + self.clients_to_close.clear() + gc.collect() + + # Small delay to let threads finish + time.sleep(0.5) + + @pytest.mark.asyncio + async def test_complete_cassandra_integration(self, cassandra_shared_container): + """Complete integration test covering all Cassandra functionality""" + container = cassandra_shared_container + host, port = container.get_connection_host_port() + + print("=" * 60) + print("RUNNING COMPLETE CASSANDRA INTEGRATION TEST") + print("=" * 60) + + # ===================================================== + # Test 1: Basic TrustGraph Operations + # ===================================================== + print("\n1. 
Testing basic TrustGraph operations...") + + client = TrustGraph( + hosts=[host], + keyspace="test_basic", + table="test_table" + ) + self.clients_to_close.append(client) + + # Insert test data + client.insert("http://example.org/alice", "knows", "http://example.org/bob") + client.insert("http://example.org/alice", "age", "25") + client.insert("http://example.org/bob", "age", "30") + + # Test get_all + all_results = list(client.get_all(limit=10)) + assert len(all_results) == 3 + print(f"✓ Stored and retrieved {len(all_results)} triples") + + # Test get_s (subject query) + alice_results = list(client.get_s("http://example.org/alice", limit=10)) + assert len(alice_results) == 2 + alice_predicates = [r.p for r in alice_results] + assert "knows" in alice_predicates + assert "age" in alice_predicates + print("✓ Subject queries working") + + # Test get_p (predicate query) + age_results = list(client.get_p("age", limit=10)) + assert len(age_results) == 2 + age_subjects = [r.s for r in age_results] + assert "http://example.org/alice" in age_subjects + assert "http://example.org/bob" in age_subjects + print("✓ Predicate queries working") + + # ===================================================== + # Test 2: Storage Processor Integration + # ===================================================== + print("\n2. 
Testing storage processor integration...") + + storage_processor = StorageProcessor( + taskgroup=MagicMock(), + hosts=[host], + keyspace="test_storage", + table="test_triples" + ) + # Track the TrustGraph instance that will be created + self.storage_processor = storage_processor + + # Create test message + storage_message = Triples( + metadata=Metadata(user="testuser", collection="testcol"), + triples=[ + Triple( + s=Value(value="http://example.org/person1", is_uri=True), + p=Value(value="http://example.org/name", is_uri=True), + o=Value(value="Alice Smith", is_uri=False) + ), + Triple( + s=Value(value="http://example.org/person1", is_uri=True), + p=Value(value="http://example.org/age", is_uri=True), + o=Value(value="25", is_uri=False) + ), + Triple( + s=Value(value="http://example.org/person1", is_uri=True), + p=Value(value="http://example.org/department", is_uri=True), + o=Value(value="Engineering", is_uri=False) + ) + ] + ) + + # Store triples via processor + await storage_processor.store_triples(storage_message) + # Track the created TrustGraph instance + if hasattr(storage_processor, 'tg'): + self.clients_to_close.append(storage_processor.tg) + + # Verify data was stored + storage_results = list(storage_processor.tg.get_s("http://example.org/person1", limit=10)) + assert len(storage_results) == 3 + + predicates = [row.p for row in storage_results] + objects = [row.o for row in storage_results] + + assert "http://example.org/name" in predicates + assert "http://example.org/age" in predicates + assert "http://example.org/department" in predicates + assert "Alice Smith" in objects + assert "25" in objects + assert "Engineering" in objects + print("✓ Storage processor working") + + # ===================================================== + # Test 3: Query Processor Integration + # ===================================================== + print("\n3. 
Testing query processor integration...") + + query_processor = QueryProcessor( + taskgroup=MagicMock(), + hosts=[host], + keyspace="test_query", + table="test_triples" + ) + + # Use same storage processor for the query keyspace + query_storage_processor = StorageProcessor( + taskgroup=MagicMock(), + hosts=[host], + keyspace="test_query", + table="test_triples" + ) + + # Store test data for querying + query_test_message = Triples( + metadata=Metadata(user="testuser", collection="testcol"), + triples=[ + Triple( + s=Value(value="http://example.org/alice", is_uri=True), + p=Value(value="http://example.org/knows", is_uri=True), + o=Value(value="http://example.org/bob", is_uri=True) + ), + Triple( + s=Value(value="http://example.org/alice", is_uri=True), + p=Value(value="http://example.org/age", is_uri=True), + o=Value(value="30", is_uri=False) + ), + Triple( + s=Value(value="http://example.org/bob", is_uri=True), + p=Value(value="http://example.org/knows", is_uri=True), + o=Value(value="http://example.org/charlie", is_uri=True) + ) + ] + ) + await query_storage_processor.store_triples(query_test_message) + + # Debug: Check what was actually stored + print("Debug: Checking what was stored for Alice...") + direct_results = list(query_storage_processor.tg.get_s("http://example.org/alice", limit=10)) + print(f"Direct TrustGraph results: {len(direct_results)}") + for result in direct_results: + print(f" S=http://example.org/alice, P={result.p}, O={result.o}") + + # Test S query (find all relationships for Alice) + s_query = TriplesQueryRequest( + s=Value(value="http://example.org/alice", is_uri=True), + p=None, # None for wildcard + o=None, # None for wildcard + limit=10, + user="testuser", + collection="testcol" + ) + s_results = await query_processor.query_triples(s_query) + print(f"Query processor results: {len(s_results)}") + for result in s_results: + print(f" S={result.s.value}, P={result.p.value}, O={result.o.value}") + assert len(s_results) == 2 + + s_predicates = 
[t.p.value for t in s_results] + assert "http://example.org/knows" in s_predicates + assert "http://example.org/age" in s_predicates + print("✓ Subject queries via processor working") + + # Test P query (find all "knows" relationships) + p_query = TriplesQueryRequest( + s=None, # None for wildcard + p=Value(value="http://example.org/knows", is_uri=True), + o=None, # None for wildcard + limit=10, + user="testuser", + collection="testcol" + ) + p_results = await query_processor.query_triples(p_query) + print(p_results) + assert len(p_results) == 2 # Alice knows Bob, Bob knows Charlie + + p_subjects = [t.s.value for t in p_results] + assert "http://example.org/alice" in p_subjects + assert "http://example.org/bob" in p_subjects + print("✓ Predicate queries via processor working") + + # ===================================================== + # Test 4: Concurrent Operations + # ===================================================== + print("\n4. Testing concurrent operations...") + + concurrent_processor = StorageProcessor( + taskgroup=MagicMock(), + hosts=[host], + keyspace="test_concurrent", + table="test_triples" + ) + + # Create multiple coroutines for concurrent storage + async def store_person_data(person_id, name, age, department): + message = Triples( + metadata=Metadata(user="concurrent_test", collection="people"), + triples=[ + Triple( + s=Value(value=f"http://example.org/{person_id}", is_uri=True), + p=Value(value="http://example.org/name", is_uri=True), + o=Value(value=name, is_uri=False) + ), + Triple( + s=Value(value=f"http://example.org/{person_id}", is_uri=True), + p=Value(value="http://example.org/age", is_uri=True), + o=Value(value=str(age), is_uri=False) + ), + Triple( + s=Value(value=f"http://example.org/{person_id}", is_uri=True), + p=Value(value="http://example.org/department", is_uri=True), + o=Value(value=department, is_uri=False) + ) + ] + ) + await concurrent_processor.store_triples(message) + + # Store data for multiple people concurrently + 
people_data = [ + ("person1", "John Doe", 25, "Engineering"), + ("person2", "Jane Smith", 30, "Marketing"), + ("person3", "Bob Wilson", 35, "Engineering"), + ("person4", "Alice Brown", 28, "Sales"), + ] + + # Run storage operations concurrently + store_tasks = [store_person_data(pid, name, age, dept) for pid, name, age, dept in people_data] + await asyncio.gather(*store_tasks) + # Track the created TrustGraph instance + if hasattr(concurrent_processor, 'tg'): + self.clients_to_close.append(concurrent_processor.tg) + + # Verify all names were stored + name_results = list(concurrent_processor.tg.get_p("http://example.org/name", limit=10)) + assert len(name_results) == 4 + + stored_names = [r.o for r in name_results] + expected_names = ["John Doe", "Jane Smith", "Bob Wilson", "Alice Brown"] + + for name in expected_names: + assert name in stored_names + + # Verify department data + dept_results = list(concurrent_processor.tg.get_p("http://example.org/department", limit=10)) + assert len(dept_results) == 4 + + stored_depts = [r.o for r in dept_results] + assert "Engineering" in stored_depts + assert "Marketing" in stored_depts + assert "Sales" in stored_depts + print("✓ Concurrent operations working") + + # ===================================================== + # Test 5: Complex Queries and Data Integrity + # ===================================================== + print("\n5. 
Testing complex queries and data integrity...") + + complex_processor = StorageProcessor( + taskgroup=MagicMock(), + hosts=[host], + keyspace="test_complex", + table="test_triples" + ) + + # Create a knowledge graph about a company + company_graph = Triples( + metadata=Metadata(user="integration_test", collection="company"), + triples=[ + # People and their types + Triple( + s=Value(value="http://company.org/alice", is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://company.org/Employee", is_uri=True) + ), + Triple( + s=Value(value="http://company.org/bob", is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://company.org/Employee", is_uri=True) + ), + # Relationships + Triple( + s=Value(value="http://company.org/alice", is_uri=True), + p=Value(value="http://company.org/reportsTo", is_uri=True), + o=Value(value="http://company.org/bob", is_uri=True) + ), + Triple( + s=Value(value="http://company.org/alice", is_uri=True), + p=Value(value="http://company.org/worksIn", is_uri=True), + o=Value(value="http://company.org/engineering", is_uri=True) + ), + # Personal info + Triple( + s=Value(value="http://company.org/alice", is_uri=True), + p=Value(value="http://company.org/fullName", is_uri=True), + o=Value(value="Alice Johnson", is_uri=False) + ), + Triple( + s=Value(value="http://company.org/alice", is_uri=True), + p=Value(value="http://company.org/email", is_uri=True), + o=Value(value="alice@company.org", is_uri=False) + ), + ] + ) + + # Store the company knowledge graph + await complex_processor.store_triples(company_graph) + # Track the created TrustGraph instance + if hasattr(complex_processor, 'tg'): + self.clients_to_close.append(complex_processor.tg) + + # Verify all Alice's data + alice_data = list(complex_processor.tg.get_s("http://company.org/alice", limit=20)) + assert len(alice_data) == 5 + + alice_predicates = [r.p for r in 
alice_data] + expected_predicates = [ + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + "http://company.org/reportsTo", + "http://company.org/worksIn", + "http://company.org/fullName", + "http://company.org/email" + ] + for pred in expected_predicates: + assert pred in alice_predicates + + # Test type-based queries + employee_results = list(complex_processor.tg.get_p("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", limit=10)) + print(employee_results) + assert len(employee_results) == 2 + + employees = [r.s for r in employee_results] + assert "http://company.org/alice" in employees + assert "http://company.org/bob" in employees + print("✓ Complex queries and data integrity working") + + # ===================================================== + # Summary + # ===================================================== + print("\n" + "=" * 60) + print("✅ ALL CASSANDRA INTEGRATION TESTS PASSED!") + print("✅ Basic operations: PASSED") + print("✅ Storage processor: PASSED") + print("✅ Query processor: PASSED") + print("✅ Concurrent operations: PASSED") + print("✅ Complex queries: PASSED") + print("=" * 60) diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py new file mode 100644 index 00000000..10ea54d2 --- /dev/null +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -0,0 +1,456 @@ +""" +Tests for Milvus document embeddings query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.doc_embeddings.milvus.service import Processor +from trustgraph.schema import DocumentEmbeddingsRequest + + +class TestMilvusDocEmbeddingsQueryProcessor: + """Test cases for Milvus document embeddings query processor""" + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') as mock_doc_vectors: + mock_vecstore = MagicMock() + 
mock_doc_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=MagicMock(), + id='test-milvus-de-query', + store_uri='http://localhost:19530' + ) + + return processor + + @pytest.fixture + def mock_query_request(self): + """Create a mock query request for testing""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=10 + ) + return query + + @patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') + def test_processor_initialization_with_defaults(self, mock_doc_vectors): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_doc_vectors.return_value = mock_vecstore + + processor = Processor(taskgroup=taskgroup_mock) + + mock_doc_vectors.assert_called_once_with('http://localhost:19530') + assert processor.vecstore == mock_vecstore + + @patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') + def test_processor_initialization_with_custom_params(self, mock_doc_vectors): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_doc_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=taskgroup_mock, + store_uri='http://custom-milvus:19530' + ) + + mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530') + assert processor.vecstore == mock_vecstore + + @pytest.mark.asyncio + async def test_query_document_embeddings_single_vector(self, processor): + """Test querying document embeddings with a single vector""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search results + mock_results = [ + {"entity": {"doc": "First document chunk"}}, + {"entity": {"doc": "Second document chunk"}}, + {"entity": {"doc": "Third document chunk"}}, + ] + 
processor.vecstore.search.return_value = mock_results + + result = await processor.query_document_embeddings(query) + + # Verify search was called with correct parameters + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5) + + # Verify results are document chunks + assert len(result) == 3 + assert result[0] == "First document chunk" + assert result[1] == "Second document chunk" + assert result[2] == "Third document chunk" + + @pytest.mark.asyncio + async def test_query_document_embeddings_multiple_vectors(self, processor): + """Test querying document embeddings with multiple vectors""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=3 + ) + + # Mock search results - different results for each vector + mock_results_1 = [ + {"entity": {"doc": "Document from first vector"}}, + {"entity": {"doc": "Another doc from first vector"}}, + ] + mock_results_2 = [ + {"entity": {"doc": "Document from second vector"}}, + ] + processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] + + result = await processor.query_document_embeddings(query) + + # Verify search was called twice with correct parameters + expected_calls = [ + (([0.1, 0.2, 0.3],), {"limit": 3}), + (([0.4, 0.5, 0.6],), {"limit": 3}), + ] + assert processor.vecstore.search.call_count == 2 + for i, (expected_args, expected_kwargs) in enumerate(expected_calls): + actual_call = processor.vecstore.search.call_args_list[i] + assert actual_call[0] == expected_args + assert actual_call[1] == expected_kwargs + + # Verify results from all vectors are combined + assert len(result) == 3 + assert "Document from first vector" in result + assert "Another doc from first vector" in result + assert "Document from second vector" in result + + @pytest.mark.asyncio + async def test_query_document_embeddings_with_limit(self, processor): + """Test querying document embeddings respects limit parameter""" + 
query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=2 + ) + + # Mock search results - more results than limit + mock_results = [ + {"entity": {"doc": "Document 1"}}, + {"entity": {"doc": "Document 2"}}, + {"entity": {"doc": "Document 3"}}, + {"entity": {"doc": "Document 4"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_document_embeddings(query) + + # Verify search was called with the specified limit + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2) + + # Verify all results are returned (Milvus handles limit internally) + assert len(result) == 4 + + @pytest.mark.asyncio + async def test_query_document_embeddings_empty_vectors(self, processor): + """Test querying document embeddings with empty vectors list""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[], + limit=5 + ) + + result = await processor.query_document_embeddings(query) + + # Verify no search was called + processor.vecstore.search.assert_not_called() + + # Verify empty results + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_document_embeddings_empty_search_results(self, processor): + """Test querying document embeddings with empty search results""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock empty search results + processor.vecstore.search.return_value = [] + + result = await processor.query_document_embeddings(query) + + # Verify search was called + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5) + + # Verify empty results + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_document_embeddings_unicode_documents(self, processor): + """Test querying document embeddings with Unicode document content""" + query = 
DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search results with Unicode content + mock_results = [ + {"entity": {"doc": "Document with Unicode: éñ中文🚀"}}, + {"entity": {"doc": "Regular ASCII document"}}, + {"entity": {"doc": "Document with émojis: 😀🎉"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_document_embeddings(query) + + # Verify Unicode content is preserved + assert len(result) == 3 + assert "Document with Unicode: éñ中文🚀" in result + assert "Regular ASCII document" in result + assert "Document with émojis: 😀🎉" in result + + @pytest.mark.asyncio + async def test_query_document_embeddings_large_documents(self, processor): + """Test querying document embeddings with large document content""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search results with large content + large_doc = "A" * 10000 # 10KB of content + mock_results = [ + {"entity": {"doc": large_doc}}, + {"entity": {"doc": "Small document"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_document_embeddings(query) + + # Verify large content is preserved + assert len(result) == 2 + assert large_doc in result + assert "Small document" in result + + @pytest.mark.asyncio + async def test_query_document_embeddings_special_characters(self, processor): + """Test querying document embeddings with special characters in documents""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search results with special characters + mock_results = [ + {"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}}, + {"entity": {"doc": "Document with\nnewlines\tand\ttabs"}}, + {"entity": {"doc": "Document with special chars: @#$%^&*()"}}, + ] + 
processor.vecstore.search.return_value = mock_results + + result = await processor.query_document_embeddings(query) + + # Verify special characters are preserved + assert len(result) == 3 + assert "Document with \"quotes\" and 'apostrophes'" in result + assert "Document with\nnewlines\tand\ttabs" in result + assert "Document with special chars: @#$%^&*()" in result + + @pytest.mark.asyncio + async def test_query_document_embeddings_zero_limit(self, processor): + """Test querying document embeddings with zero limit""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=0 + ) + + result = await processor.query_document_embeddings(query) + + # Verify no search was called (optimization for zero limit) + processor.vecstore.search.assert_not_called() + + # Verify empty results due to zero limit + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_document_embeddings_negative_limit(self, processor): + """Test querying document embeddings with negative limit""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=-1 + ) + + result = await processor.query_document_embeddings(query) + + # Verify no search was called (optimization for negative limit) + processor.vecstore.search.assert_not_called() + + # Verify empty results due to negative limit + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_document_embeddings_exception_handling(self, processor): + """Test exception handling during query processing""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search to raise exception + processor.vecstore.search.side_effect = Exception("Milvus connection failed") + + # Should raise the exception + with pytest.raises(Exception, match="Milvus connection failed"): + await 
processor.query_document_embeddings(query) + + @pytest.mark.asyncio + async def test_query_document_embeddings_different_vector_dimensions(self, processor): + """Test querying document embeddings with different vector dimensions""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6], # 4D vector + [0.7, 0.8, 0.9] # 3D vector + ], + limit=5 + ) + + # Mock search results for each vector + mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}] + mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}] + mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}] + processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] + + result = await processor.query_document_embeddings(query) + + # Verify all vectors were searched + assert processor.vecstore.search.call_count == 3 + + # Verify results from all dimensions + assert len(result) == 3 + assert "Document from 2D vector" in result + assert "Document from 4D vector" in result + assert "Document from 3D vector" in result + + @pytest.mark.asyncio + async def test_query_document_embeddings_duplicate_documents(self, processor): + """Test querying document embeddings with duplicate documents in results""" + query = DocumentEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=5 + ) + + # Mock search results with duplicates across vectors + mock_results_1 = [ + {"entity": {"doc": "Document A"}}, + {"entity": {"doc": "Document B"}}, + ] + mock_results_2 = [ + {"entity": {"doc": "Document B"}}, # Duplicate + {"entity": {"doc": "Document C"}}, + ] + processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] + + result = await processor.query_document_embeddings(query) + + # Note: Unlike graph embeddings, doc embeddings don't deduplicate + # This preserves ranking and allows multiple 
occurrences + assert len(result) == 4 + assert result.count("Document B") == 2 # Should appear twice + assert "Document A" in result + assert "Document C" in result + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'store_uri') + assert args.store_uri == 'http://localhost:19530' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--store-uri', 'http://custom-milvus:19530' + ]) + + assert args.store_uri == 'http://custom-milvus:19530' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-t', 'http://short-milvus:19530']) + + assert args.store_uri == 'http://short-milvus:19530' + + @patch('trustgraph.query.doc_embeddings.milvus.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run 
function calls Processor.launch with correct parameters""" + from trustgraph.query.doc_embeddings.milvus.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n" + ) \ No newline at end of file diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py new file mode 100644 index 00000000..92551587 --- /dev/null +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -0,0 +1,558 @@ +""" +Tests for Pinecone document embeddings query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.doc_embeddings.pinecone.service import Processor + + +class TestPineconeDocEmbeddingsQueryProcessor: + """Test cases for Pinecone document embeddings query processor""" + + @pytest.fixture + def mock_query_message(self): + """Create a mock query message for testing""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6] + ] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') as mock_pinecone_class: + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + + processor = Processor( + taskgroup=MagicMock(), + id='test-pinecone-de-query', + api_key='test-api-key' + ) + + return processor + + @patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') + @patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'env-api-key') + def test_processor_initialization_with_defaults(self, mock_pinecone_class): + """Test processor initialization with default parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = 
mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor(taskgroup=taskgroup_mock) + + mock_pinecone_class.assert_called_once_with(api_key='env-api-key') + assert processor.pinecone == mock_pinecone + assert processor.api_key == 'env-api-key' + + @patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') + def test_processor_initialization_with_custom_params(self, mock_pinecone_class): + """Test processor initialization with custom parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( + taskgroup=taskgroup_mock, + api_key='custom-api-key' + ) + + mock_pinecone_class.assert_called_once_with(api_key='custom-api-key') + assert processor.api_key == 'custom-api-key' + + @patch('trustgraph.query.doc_embeddings.pinecone.service.PineconeGRPC') + def test_processor_initialization_with_url(self, mock_pinecone_grpc_class): + """Test processor initialization with custom URL (GRPC mode)""" + mock_pinecone = MagicMock() + mock_pinecone_grpc_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( + taskgroup=taskgroup_mock, + api_key='test-api-key', + url='https://custom-host.pinecone.io' + ) + + mock_pinecone_grpc_class.assert_called_once_with( + api_key='test-api-key', + host='https://custom-host.pinecone.io' + ) + assert processor.pinecone == mock_pinecone + assert processor.url == 'https://custom-host.pinecone.io' + + @patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'not-specified') + def test_processor_initialization_missing_api_key(self): + """Test processor initialization fails with missing API key""" + taskgroup_mock = MagicMock() + + with pytest.raises(RuntimeError, match="Pinecone API key must be specified"): + Processor(taskgroup=taskgroup_mock) + + @pytest.mark.asyncio + async def test_query_document_embeddings_single_vector(self, processor): + """Test querying document embeddings with 
a single vector""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 3 + message.user = 'test_user' + message.collection = 'test_collection' + + # Mock index and query results + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'doc': 'First document chunk'}), + MagicMock(metadata={'doc': 'Second document chunk'}), + MagicMock(metadata={'doc': 'Third document chunk'}) + ] + mock_index.query.return_value = mock_results + + chunks = await processor.query_document_embeddings(message) + + # Verify index was accessed correctly + expected_index_name = "d-test_user-test_collection-3" + processor.pinecone.Index.assert_called_once_with(expected_index_name) + + # Verify query parameters + mock_index.query.assert_called_once_with( + vector=[0.1, 0.2, 0.3], + top_k=3, + include_values=False, + include_metadata=True + ) + + # Verify results + assert len(chunks) == 3 + assert chunks[0] == 'First document chunk' + assert chunks[1] == 'Second document chunk' + assert chunks[2] == 'Third document chunk' + + @pytest.mark.asyncio + async def test_query_document_embeddings_multiple_vectors(self, processor, mock_query_message): + """Test querying document embeddings with multiple vectors""" + # Mock index and query results + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # First query results + mock_results1 = MagicMock() + mock_results1.matches = [ + MagicMock(metadata={'doc': 'Document chunk 1'}), + MagicMock(metadata={'doc': 'Document chunk 2'}) + ] + + # Second query results + mock_results2 = MagicMock() + mock_results2.matches = [ + MagicMock(metadata={'doc': 'Document chunk 3'}), + MagicMock(metadata={'doc': 'Document chunk 4'}) + ] + + mock_index.query.side_effect = [mock_results1, mock_results2] + + chunks = await processor.query_document_embeddings(mock_query_message) + + # Verify both queries were made + 
assert mock_index.query.call_count == 2 + + # Verify results from both queries + assert len(chunks) == 4 + assert 'Document chunk 1' in chunks + assert 'Document chunk 2' in chunks + assert 'Document chunk 3' in chunks + assert 'Document chunk 4' in chunks + + @pytest.mark.asyncio + async def test_query_document_embeddings_limit_handling(self, processor): + """Test that query respects the limit parameter""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 2 + message.user = 'test_user' + message.collection = 'test_collection' + + # Mock index with many results + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'doc': f'Document chunk {i}'}) for i in range(10) + ] + mock_index.query.return_value = mock_results + + chunks = await processor.query_document_embeddings(message) + + # Verify limit is passed to query + mock_index.query.assert_called_once() + call_args = mock_index.query.call_args + assert call_args[1]['top_k'] == 2 + + # Results should contain all returned chunks (limit is applied by Pinecone) + assert len(chunks) == 10 + + @pytest.mark.asyncio + async def test_query_document_embeddings_zero_limit(self, processor): + """Test querying with zero limit returns empty results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 0 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + chunks = await processor.query_document_embeddings(message) + + # Verify no query was made and empty result returned + mock_index.query.assert_not_called() + assert chunks == [] + + @pytest.mark.asyncio + async def test_query_document_embeddings_negative_limit(self, processor): + """Test querying with negative limit returns empty results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + 
message.limit = -1 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + chunks = await processor.query_document_embeddings(message) + + # Verify no query was made and empty result returned + mock_index.query.assert_not_called() + assert chunks == [] + + @pytest.mark.asyncio + async def test_query_document_embeddings_different_vector_dimensions(self, processor): + """Test querying with vectors of different dimensions""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6] # 4D vector + ] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index_2d = MagicMock() + mock_index_4d = MagicMock() + + def mock_index_side_effect(name): + if name.endswith("-2"): + return mock_index_2d + elif name.endswith("-4"): + return mock_index_4d + + processor.pinecone.Index.side_effect = mock_index_side_effect + + # Mock results for different dimensions + mock_results_2d = MagicMock() + mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})] + mock_index_2d.query.return_value = mock_results_2d + + mock_results_4d = MagicMock() + mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})] + mock_index_4d.query.return_value = mock_results_4d + + chunks = await processor.query_document_embeddings(message) + + # Verify different indexes were used + assert processor.pinecone.Index.call_count == 2 + mock_index_2d.query.assert_called_once() + mock_index_4d.query.assert_called_once() + + # Verify results from both dimensions + assert 'Document from 2D index' in chunks + assert 'Document from 4D index' in chunks + + @pytest.mark.asyncio + async def test_query_document_embeddings_empty_vectors_list(self, processor): + """Test querying with empty vectors list""" + message = MagicMock() + message.vectors = [] + message.limit = 5 + message.user = 'test_user' + 
message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + chunks = await processor.query_document_embeddings(message) + + # Verify no queries were made and empty result returned + processor.pinecone.Index.assert_not_called() + mock_index.query.assert_not_called() + assert chunks == [] + + @pytest.mark.asyncio + async def test_query_document_embeddings_no_results(self, processor): + """Test querying when index returns no results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [] + mock_index.query.return_value = mock_results + + chunks = await processor.query_document_embeddings(message) + + # Verify empty results + assert chunks == [] + + @pytest.mark.asyncio + async def test_query_document_embeddings_unicode_content(self, processor): + """Test querying document embeddings with Unicode content results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 2 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'doc': 'Document with Unicode: éñ中文🚀'}), + MagicMock(metadata={'doc': 'Regular ASCII document'}) + ] + mock_index.query.return_value = mock_results + + chunks = await processor.query_document_embeddings(message) + + # Verify Unicode content is properly handled + assert len(chunks) == 2 + assert 'Document with Unicode: éñ中文🚀' in chunks + assert 'Regular ASCII document' in chunks + + @pytest.mark.asyncio + async def test_query_document_embeddings_large_content(self, processor): + """Test querying document embeddings with large content results""" + 
message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 1 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # Create a large document content + large_content = "A" * 10000 # 10KB of content + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'doc': large_content}) + ] + mock_index.query.return_value = mock_results + + chunks = await processor.query_document_embeddings(message) + + # Verify large content is properly handled + assert len(chunks) == 1 + assert chunks[0] == large_content + + @pytest.mark.asyncio + async def test_query_document_embeddings_mixed_content_types(self, processor): + """Test querying document embeddings with mixed content types""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'doc': 'Short text'}), + MagicMock(metadata={'doc': 'A' * 1000}), # Long text + MagicMock(metadata={'doc': 'Text with numbers: 123 and symbols: @#$'}), + MagicMock(metadata={'doc': ' Whitespace text '}), + MagicMock(metadata={'doc': ''}) # Empty string + ] + mock_index.query.return_value = mock_results + + chunks = await processor.query_document_embeddings(message) + + # Verify all content types are properly handled + assert len(chunks) == 5 + assert 'Short text' in chunks + assert 'A' * 1000 in chunks + assert 'Text with numbers: 123 and symbols: @#$' in chunks + assert ' Whitespace text ' in chunks + assert '' in chunks + + @pytest.mark.asyncio + async def test_query_document_embeddings_exception_handling(self, processor): + """Test that exceptions are properly raised""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 5 + 
message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + mock_index.query.side_effect = Exception("Query failed") + + with pytest.raises(Exception, match="Query failed"): + await processor.query_document_embeddings(message) + + @pytest.mark.asyncio + async def test_query_document_embeddings_index_access_failure(self, processor): + """Test handling of index access failure""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + processor.pinecone.Index.side_effect = Exception("Index access failed") + + with pytest.raises(Exception, match="Index access failed"): + await processor.query_document_embeddings(message) + + @pytest.mark.asyncio + async def test_query_document_embeddings_vector_accumulation(self, processor): + """Test that results from multiple vectors are properly accumulated""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9] + ] + message.limit = 2 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # Each query returns different results + mock_results1 = MagicMock() + mock_results1.matches = [ + MagicMock(metadata={'doc': 'Doc from vector 1.1'}), + MagicMock(metadata={'doc': 'Doc from vector 1.2'}) + ] + + mock_results2 = MagicMock() + mock_results2.matches = [ + MagicMock(metadata={'doc': 'Doc from vector 2.1'}) + ] + + mock_results3 = MagicMock() + mock_results3.matches = [ + MagicMock(metadata={'doc': 'Doc from vector 3.1'}), + MagicMock(metadata={'doc': 'Doc from vector 3.2'}), + MagicMock(metadata={'doc': 'Doc from vector 3.3'}) + ] + + mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3] + + chunks = await processor.query_document_embeddings(message) + + # Verify all queries 
were made + assert mock_index.query.call_count == 3 + + # Verify all results are accumulated + assert len(chunks) == 6 + assert 'Doc from vector 1.1' in chunks + assert 'Doc from vector 1.2' in chunks + assert 'Doc from vector 2.1' in chunks + assert 'Doc from vector 3.1' in chunks + assert 'Doc from vector 3.2' in chunks + assert 'Doc from vector 3.3' in chunks + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + args = parser.parse_args([]) + + assert hasattr(args, 'api_key') + assert args.api_key == 'not-specified' # Default value when no env var + assert hasattr(args, 'url') + assert args.url is None + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--api-key', 'custom-api-key', + '--url', 'https://custom-host.pinecone.io' + ]) + + assert args.api_key == 'custom-api-key' + assert args.url == 'https://custom-host.pinecone.io' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'): + 
Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args([ + '-a', 'short-api-key', + '-u', 'https://short-host.pinecone.io' + ]) + + assert args.api_key == 'short-api-key' + assert args.url == 'https://short-host.pinecone.io' + + @patch('trustgraph.query.doc_embeddings.pinecone.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.query.doc_embeddings.pinecone.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nDocument embeddings query service. Input is vector, output is an array\nof chunks. Pinecone implementation.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py new file mode 100644 index 00000000..5fbb74d5 --- /dev/null +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -0,0 +1,484 @@ +""" +Tests for Milvus graph embeddings query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.graph_embeddings.milvus.service import Processor +from trustgraph.schema import Value, GraphEmbeddingsRequest + + +class TestMilvusGraphEmbeddingsQueryProcessor: + """Test cases for Milvus graph embeddings query processor""" + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors') as mock_entity_vectors: + mock_vecstore = MagicMock() + mock_entity_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=MagicMock(), + id='test-milvus-ge-query', + store_uri='http://localhost:19530' + ) + + return processor + + @pytest.fixture + def mock_query_request(self): + """Create a mock query request for testing""" + query = GraphEmbeddingsRequest( + user='test_user', + 
collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=10 + ) + return query + + @patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors') + def test_processor_initialization_with_defaults(self, mock_entity_vectors): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_entity_vectors.return_value = mock_vecstore + + processor = Processor(taskgroup=taskgroup_mock) + + mock_entity_vectors.assert_called_once_with('http://localhost:19530') + assert processor.vecstore == mock_vecstore + + @patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors') + def test_processor_initialization_with_custom_params(self, mock_entity_vectors): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_entity_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=taskgroup_mock, + store_uri='http://custom-milvus:19530' + ) + + mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530') + assert processor.vecstore == mock_vecstore + + def test_create_value_with_http_uri(self, processor): + """Test create_value with HTTP URI""" + result = processor.create_value("http://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "http://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_https_uri(self, processor): + """Test create_value with HTTPS URI""" + result = processor.create_value("https://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "https://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_literal(self, processor): + """Test create_value with literal value""" + result = processor.create_value("just a literal string") + + assert isinstance(result, Value) + assert result.value == "just a literal string" + assert 
result.is_uri is False + + def test_create_value_with_empty_string(self, processor): + """Test create_value with empty string""" + result = processor.create_value("") + + assert isinstance(result, Value) + assert result.value == "" + assert result.is_uri is False + + def test_create_value_with_partial_uri(self, processor): + """Test create_value with string that looks like URI but isn't complete""" + result = processor.create_value("http") + + assert isinstance(result, Value) + assert result.value == "http" + assert result.is_uri is False + + def test_create_value_with_ftp_uri(self, processor): + """Test create_value with FTP URI (should not be detected as URI)""" + result = processor.create_value("ftp://example.com/file") + + assert isinstance(result, Value) + assert result.value == "ftp://example.com/file" + assert result.is_uri is False + + @pytest.mark.asyncio + async def test_query_graph_embeddings_single_vector(self, processor): + """Test querying graph embeddings with a single vector""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search results + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + {"entity": {"entity": "literal entity"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_graph_embeddings(query) + + # Verify search was called with correct parameters + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10) + + # Verify results are converted to Value objects + assert len(result) == 3 + assert isinstance(result[0], Value) + assert result[0].value == "http://example.com/entity1" + assert result[0].is_uri is True + assert isinstance(result[1], Value) + assert result[1].value == "http://example.com/entity2" + assert result[1].is_uri is True + assert isinstance(result[2], Value) + assert result[2].value == 
"literal entity" + assert result[2].is_uri is False + + @pytest.mark.asyncio + async def test_query_graph_embeddings_multiple_vectors(self, processor): + """Test querying graph embeddings with multiple vectors""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=3 + ) + + # Mock search results - different results for each vector + mock_results_1 = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + ] + mock_results_2 = [ + {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate + {"entity": {"entity": "http://example.com/entity3"}}, + ] + processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] + + result = await processor.query_graph_embeddings(query) + + # Verify search was called twice with correct parameters + expected_calls = [ + (([0.1, 0.2, 0.3],), {"limit": 6}), + (([0.4, 0.5, 0.6],), {"limit": 6}), + ] + assert processor.vecstore.search.call_count == 2 + for i, (expected_args, expected_kwargs) in enumerate(expected_calls): + actual_call = processor.vecstore.search.call_args_list[i] + assert actual_call[0] == expected_args + assert actual_call[1] == expected_kwargs + + # Verify results are deduplicated and limited + assert len(result) == 3 + entity_values = [r.value for r in result] + assert "http://example.com/entity1" in entity_values + assert "http://example.com/entity2" in entity_values + assert "http://example.com/entity3" in entity_values + + @pytest.mark.asyncio + async def test_query_graph_embeddings_with_limit(self, processor): + """Test querying graph embeddings respects limit parameter""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=2 + ) + + # Mock search results - more results than limit + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": 
"http://example.com/entity2"}}, + {"entity": {"entity": "http://example.com/entity3"}}, + {"entity": {"entity": "http://example.com/entity4"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_graph_embeddings(query) + + # Verify search was called with 2*limit for better deduplication + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4) + + # Verify results are limited to the requested limit + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_deduplication(self, processor): + """Test that duplicate entities are properly deduplicated""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=5 + ) + + # Mock search results with duplicates + mock_results_1 = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + ] + mock_results_2 = [ + {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate + {"entity": {"entity": "http://example.com/entity1"}}, # Duplicate + {"entity": {"entity": "http://example.com/entity3"}}, # New + ] + processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] + + result = await processor.query_graph_embeddings(query) + + # Verify duplicates are removed + assert len(result) == 3 + entity_values = [r.value for r in result] + assert len(set(entity_values)) == 3 # All unique + assert "http://example.com/entity1" in entity_values + assert "http://example.com/entity2" in entity_values + assert "http://example.com/entity3" in entity_values + + @pytest.mark.asyncio + async def test_query_graph_embeddings_early_termination_on_limit(self, processor): + """Test that querying stops early when limit is reached""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=2 + ) + + # Mock search results 
- first vector returns enough results + mock_results_1 = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + {"entity": {"entity": "http://example.com/entity3"}}, + ] + processor.vecstore.search.return_value = mock_results_1 + + result = await processor.query_graph_embeddings(query) + + # Verify only first vector was searched (limit reached) + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4) + + # Verify results are limited + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_empty_vectors(self, processor): + """Test querying graph embeddings with empty vectors list""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[], + limit=5 + ) + + result = await processor.query_graph_embeddings(query) + + # Verify no search was called + processor.vecstore.search.assert_not_called() + + # Verify empty results + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_empty_search_results(self, processor): + """Test querying graph embeddings with empty search results""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock empty search results + processor.vecstore.search.return_value = [] + + result = await processor.query_graph_embeddings(query) + + # Verify search was called + processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10) + + # Verify empty results + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor): + """Test querying graph embeddings with mixed URI and literal results""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search results with mixed types + mock_results = [ + {"entity": 
{"entity": "http://example.com/uri_entity"}}, + {"entity": {"entity": "literal entity text"}}, + {"entity": {"entity": "https://example.com/another_uri"}}, + {"entity": {"entity": "another literal"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_graph_embeddings(query) + + # Verify all results are properly typed + assert len(result) == 4 + + # Check URI entities + uri_results = [r for r in result if r.is_uri] + assert len(uri_results) == 2 + uri_values = [r.value for r in uri_results] + assert "http://example.com/uri_entity" in uri_values + assert "https://example.com/another_uri" in uri_values + + # Check literal entities + literal_results = [r for r in result if not r.is_uri] + assert len(literal_results) == 2 + literal_values = [r.value for r in literal_results] + assert "literal entity text" in literal_values + assert "another literal" in literal_values + + @pytest.mark.asyncio + async def test_query_graph_embeddings_exception_handling(self, processor): + """Test exception handling during query processing""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + # Mock search to raise exception + processor.vecstore.search.side_effect = Exception("Milvus connection failed") + + # Should raise the exception + with pytest.raises(Exception, match="Milvus connection failed"): + await processor.query_graph_embeddings(query) + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were 
added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'store_uri') + assert args.store_uri == 'http://localhost:19530' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--store-uri', 'http://custom-milvus:19530' + ]) + + assert args.store_uri == 'http://custom-milvus:19530' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-t', 'http://short-milvus:19530']) + + assert args.store_uri == 'http://short-milvus:19530' + + @patch('trustgraph.query.graph_embeddings.milvus.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.query.graph_embeddings.milvus.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nGraph embeddings query service. 
Input is vector, output is list of\nentities\n" + ) + + @pytest.mark.asyncio + async def test_query_graph_embeddings_zero_limit(self, processor): + """Test querying graph embeddings with zero limit""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[[0.1, 0.2, 0.3]], + limit=0 + ) + + result = await processor.query_graph_embeddings(query) + + # Verify no search was called (optimization for zero limit) + processor.vecstore.search.assert_not_called() + + # Verify empty results due to zero limit + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_different_vector_dimensions(self, processor): + """Test querying graph embeddings with different vector dimensions""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vectors=[ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6], # 4D vector + [0.7, 0.8, 0.9] # 3D vector + ], + limit=5 + ) + + # Mock search results for each vector + mock_results_1 = [{"entity": {"entity": "entity_2d"}}] + mock_results_2 = [{"entity": {"entity": "entity_4d"}}] + mock_results_3 = [{"entity": {"entity": "entity_3d"}}] + processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] + + result = await processor.query_graph_embeddings(query) + + # Verify all vectors were searched + assert processor.vecstore.search.call_count == 3 + + # Verify results from all dimensions + assert len(result) == 3 + entity_values = [r.value for r in result] + assert "entity_2d" in entity_values + assert "entity_4d" in entity_values + assert "entity_3d" in entity_values \ No newline at end of file diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py new file mode 100644 index 00000000..5352e002 --- /dev/null +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -0,0 +1,507 @@ +""" +Tests for Pinecone graph embeddings 
query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.graph_embeddings.pinecone.service import Processor +from trustgraph.schema import Value + + +class TestPineconeGraphEmbeddingsQueryProcessor: + """Test cases for Pinecone graph embeddings query processor""" + + @pytest.fixture + def mock_query_message(self): + """Create a mock query message for testing""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6] + ] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone') as mock_pinecone_class: + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + + processor = Processor( + taskgroup=MagicMock(), + id='test-pinecone-ge-query', + api_key='test-api-key' + ) + + return processor + + @patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone') + @patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'env-api-key') + def test_processor_initialization_with_defaults(self, mock_pinecone_class): + """Test processor initialization with default parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor(taskgroup=taskgroup_mock) + + mock_pinecone_class.assert_called_once_with(api_key='env-api-key') + assert processor.pinecone == mock_pinecone + assert processor.api_key == 'env-api-key' + + @patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone') + def test_processor_initialization_with_custom_params(self, mock_pinecone_class): + """Test processor initialization with custom parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( 
+ taskgroup=taskgroup_mock, + api_key='custom-api-key' + ) + + mock_pinecone_class.assert_called_once_with(api_key='custom-api-key') + assert processor.api_key == 'custom-api-key' + + @patch('trustgraph.query.graph_embeddings.pinecone.service.PineconeGRPC') + def test_processor_initialization_with_url(self, mock_pinecone_grpc_class): + """Test processor initialization with custom URL (GRPC mode)""" + mock_pinecone = MagicMock() + mock_pinecone_grpc_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( + taskgroup=taskgroup_mock, + api_key='test-api-key', + url='https://custom-host.pinecone.io' + ) + + mock_pinecone_grpc_class.assert_called_once_with( + api_key='test-api-key', + host='https://custom-host.pinecone.io' + ) + assert processor.pinecone == mock_pinecone + assert processor.url == 'https://custom-host.pinecone.io' + + @patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'not-specified') + def test_processor_initialization_missing_api_key(self): + """Test processor initialization fails with missing API key""" + taskgroup_mock = MagicMock() + + with pytest.raises(RuntimeError, match="Pinecone API key must be specified"): + Processor(taskgroup=taskgroup_mock) + + def test_create_value_uri(self, processor): + """Test create_value method for URI entities""" + uri_entity = "http://example.org/entity" + value = processor.create_value(uri_entity) + + assert isinstance(value, Value) + assert value.value == uri_entity + assert value.is_uri == True + + def test_create_value_https_uri(self, processor): + """Test create_value method for HTTPS URI entities""" + uri_entity = "https://example.org/entity" + value = processor.create_value(uri_entity) + + assert isinstance(value, Value) + assert value.value == uri_entity + assert value.is_uri == True + + def test_create_value_literal(self, processor): + """Test create_value method for literal entities""" + literal_entity = "literal_entity" + value = 
processor.create_value(literal_entity) + + assert isinstance(value, Value) + assert value.value == literal_entity + assert value.is_uri == False + + @pytest.mark.asyncio + async def test_query_graph_embeddings_single_vector(self, processor): + """Test querying graph embeddings with a single vector""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 3 + message.user = 'test_user' + message.collection = 'test_collection' + + # Mock index and query results + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'entity': 'http://example.org/entity1'}), + MagicMock(metadata={'entity': 'entity2'}), + MagicMock(metadata={'entity': 'http://example.org/entity3'}) + ] + mock_index.query.return_value = mock_results + + entities = await processor.query_graph_embeddings(message) + + # Verify index was accessed correctly + expected_index_name = "t-test_user-test_collection-3" + processor.pinecone.Index.assert_called_once_with(expected_index_name) + + # Verify query parameters + mock_index.query.assert_called_once_with( + vector=[0.1, 0.2, 0.3], + top_k=6, # 2 * limit + include_values=False, + include_metadata=True + ) + + # Verify results + assert len(entities) == 3 + assert entities[0].value == 'http://example.org/entity1' + assert entities[0].is_uri == True + assert entities[1].value == 'entity2' + assert entities[1].is_uri == False + assert entities[2].value == 'http://example.org/entity3' + assert entities[2].is_uri == True + + @pytest.mark.asyncio + async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message): + """Test querying graph embeddings with multiple vectors""" + # Mock index and query results + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # First query results + mock_results1 = MagicMock() + mock_results1.matches = [ + MagicMock(metadata={'entity': 'entity1'}), + 
MagicMock(metadata={'entity': 'entity2'}) + ] + + # Second query results + mock_results2 = MagicMock() + mock_results2.matches = [ + MagicMock(metadata={'entity': 'entity2'}), # Duplicate + MagicMock(metadata={'entity': 'entity3'}) + ] + + mock_index.query.side_effect = [mock_results1, mock_results2] + + entities = await processor.query_graph_embeddings(mock_query_message) + + # Verify both queries were made + assert mock_index.query.call_count == 2 + + # Verify deduplication occurred + entity_values = [e.value for e in entities] + assert len(entity_values) == 3 + assert 'entity1' in entity_values + assert 'entity2' in entity_values + assert 'entity3' in entity_values + + @pytest.mark.asyncio + async def test_query_graph_embeddings_limit_handling(self, processor): + """Test that query respects the limit parameter""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 2 + message.user = 'test_user' + message.collection = 'test_collection' + + # Mock index with many results + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [ + MagicMock(metadata={'entity': f'entity{i}'}) for i in range(10) + ] + mock_index.query.return_value = mock_results + + entities = await processor.query_graph_embeddings(message) + + # Verify limit is respected + assert len(entities) == 2 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_zero_limit(self, processor): + """Test querying with zero limit returns empty results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 0 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + entities = await processor.query_graph_embeddings(message) + + # Verify no query was made and empty result returned + mock_index.query.assert_not_called() + assert entities == [] + + @pytest.mark.asyncio + async 
def test_query_graph_embeddings_negative_limit(self, processor): + """Test querying with negative limit returns empty results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = -1 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + entities = await processor.query_graph_embeddings(message) + + # Verify no query was made and empty result returned + mock_index.query.assert_not_called() + assert entities == [] + + @pytest.mark.asyncio + async def test_query_graph_embeddings_different_vector_dimensions(self, processor): + """Test querying with vectors of different dimensions""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6] # 4D vector + ] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index_2d = MagicMock() + mock_index_4d = MagicMock() + + def mock_index_side_effect(name): + if name.endswith("-2"): + return mock_index_2d + elif name.endswith("-4"): + return mock_index_4d + + processor.pinecone.Index.side_effect = mock_index_side_effect + + # Mock results for different dimensions + mock_results_2d = MagicMock() + mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] + mock_index_2d.query.return_value = mock_results_2d + + mock_results_4d = MagicMock() + mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})] + mock_index_4d.query.return_value = mock_results_4d + + entities = await processor.query_graph_embeddings(message) + + # Verify different indexes were used + assert processor.pinecone.Index.call_count == 2 + mock_index_2d.query.assert_called_once() + mock_index_4d.query.assert_called_once() + + # Verify results from both dimensions + entity_values = [e.value for e in entities] + assert 'entity_2d' in entity_values + assert 'entity_4d' in entity_values + + @pytest.mark.asyncio + async def 
test_query_graph_embeddings_empty_vectors_list(self, processor): + """Test querying with empty vectors list""" + message = MagicMock() + message.vectors = [] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + entities = await processor.query_graph_embeddings(message) + + # Verify no queries were made and empty result returned + processor.pinecone.Index.assert_not_called() + mock_index.query.assert_not_called() + assert entities == [] + + @pytest.mark.asyncio + async def test_query_graph_embeddings_no_results(self, processor): + """Test querying when index returns no results""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + mock_results = MagicMock() + mock_results.matches = [] + mock_index.query.return_value = mock_results + + entities = await processor.query_graph_embeddings(message) + + # Verify empty results + assert entities == [] + + @pytest.mark.asyncio + async def test_query_graph_embeddings_deduplication_across_vectors(self, processor): + """Test that deduplication works correctly across multiple vector queries""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6] + ] + message.limit = 3 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # Both queries return overlapping results + mock_results1 = MagicMock() + mock_results1.matches = [ + MagicMock(metadata={'entity': 'entity1'}), + MagicMock(metadata={'entity': 'entity2'}), + MagicMock(metadata={'entity': 'entity3'}), + MagicMock(metadata={'entity': 'entity4'}) + ] + + mock_results2 = MagicMock() + mock_results2.matches = [ + MagicMock(metadata={'entity': 
'entity2'}), # Duplicate + MagicMock(metadata={'entity': 'entity3'}), # Duplicate + MagicMock(metadata={'entity': 'entity5'}) + ] + + mock_index.query.side_effect = [mock_results1, mock_results2] + + entities = await processor.query_graph_embeddings(message) + + # Should get exactly 3 unique entities (respecting limit) + assert len(entities) == 3 + entity_values = [e.value for e in entities] + assert len(set(entity_values)) == 3 # All unique + + @pytest.mark.asyncio + async def test_query_graph_embeddings_early_termination_on_limit(self, processor): + """Test that querying stops early when limit is reached""" + message = MagicMock() + message.vectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9] + ] + message.limit = 2 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # First query returns enough results to meet limit + mock_results1 = MagicMock() + mock_results1.matches = [ + MagicMock(metadata={'entity': 'entity1'}), + MagicMock(metadata={'entity': 'entity2'}), + MagicMock(metadata={'entity': 'entity3'}) + ] + mock_index.query.return_value = mock_results1 + + entities = await processor.query_graph_embeddings(message) + + # Should only make one query since limit was reached + mock_index.query.assert_called_once() + assert len(entities) == 2 + + @pytest.mark.asyncio + async def test_query_graph_embeddings_exception_handling(self, processor): + """Test that exceptions are properly raised""" + message = MagicMock() + message.vectors = [[0.1, 0.2, 0.3]] + message.limit = 5 + message.user = 'test_user' + message.collection = 'test_collection' + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + mock_index.query.side_effect = Exception("Query failed") + + with pytest.raises(Exception, match="Query failed"): + await processor.query_graph_embeddings(message) + + def test_add_args_method(self): + """Test that add_args 
properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + args = parser.parse_args([]) + + assert hasattr(args, 'api_key') + assert args.api_key == 'not-specified' # Default value when no env var + assert hasattr(args, 'url') + assert args.url is None + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--api-key', 'custom-api-key', + '--url', 'https://custom-host.pinecone.io' + ]) + + assert args.api_key == 'custom-api-key' + assert args.url == 'https://custom-host.pinecone.io' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args([ + '-a', 'short-api-key', + '-u', 'https://short-host.pinecone.io' + ]) + + assert args.api_key == 'short-api-key' + assert args.url == 'https://short-host.pinecone.io' + + @patch('trustgraph.query.graph_embeddings.pinecone.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with 
correct parameters""" + from trustgraph.query.graph_embeddings.pinecone.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nGraph embeddings query service. Input is vector, output is list of\nentities. Pinecone implementation.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_query/test_triples_falkordb_query.py b/tests/unit/test_query/test_triples_falkordb_query.py new file mode 100644 index 00000000..3e7d07db --- /dev/null +++ b/tests/unit/test_query/test_triples_falkordb_query.py @@ -0,0 +1,556 @@ +""" +Tests for FalkorDB triples query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.triples.falkordb.service import Processor +from trustgraph.schema import Value, TriplesQueryRequest + + +class TestFalkorDBQueryProcessor: + """Test cases for FalkorDB query processor""" + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.triples.falkordb.service.FalkorDB'): + return Processor( + taskgroup=MagicMock(), + id='test-falkordb-query', + graph_url='falkor://localhost:6379' + ) + + def test_create_value_with_http_uri(self, processor): + """Test create_value with HTTP URI""" + result = processor.create_value("http://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "http://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_https_uri(self, processor): + """Test create_value with HTTPS URI""" + result = processor.create_value("https://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "https://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_literal(self, processor): + """Test create_value with literal value""" + result = processor.create_value("just a literal string") + + assert isinstance(result, Value) + assert result.value == "just a literal string" + 
assert result.is_uri is False + + def test_create_value_with_empty_string(self, processor): + """Test create_value with empty string""" + result = processor.create_value("") + + assert isinstance(result, Value) + assert result.value == "" + assert result.is_uri is False + + def test_create_value_with_partial_uri(self, processor): + """Test create_value with string that looks like URI but isn't complete""" + result = processor.create_value("http") + + assert isinstance(result, Value) + assert result.value == "http" + assert result.is_uri is False + + def test_create_value_with_ftp_uri(self, processor): + """Test create_value with FTP URI (should not be detected as URI)""" + result = processor.create_value("ftp://example.com/file") + + assert isinstance(result, Value) + assert result.value == "ftp://example.com/file" + assert result.is_uri is False + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + def test_processor_initialization_with_defaults(self, mock_falkordb): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + processor = Processor(taskgroup=taskgroup_mock) + + assert processor.db == 'falkordb' + mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379') + mock_client.select_graph.assert_called_once_with('falkordb') + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + def test_processor_initialization_with_custom_params(self, mock_falkordb): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + processor = Processor( + taskgroup=taskgroup_mock, + graph_url='falkor://custom:6379', + database='customdb' + ) + + assert 
processor.db == 'customdb' + mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379') + mock_client.select_graph.assert_called_once_with('customdb') + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_spo_query(self, mock_falkordb): + """Test SPO query (all values specified)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results - both queries return one record each + mock_result = MagicMock() + mock_result.result_set = [["record1"]] + mock_graph.query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify result contains the queried triple (appears twice - once from each query) + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal object" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_sp_query(self, mock_falkordb): + """Test SP query (subject and predicate specified)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results with different objects + mock_result1 = MagicMock() + 
mock_result1.result_set = [["literal result"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/uri_result"]] + + mock_graph.query.side_effect = [mock_result1, mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different objects + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal result" + + assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "http://example.com/uri_result" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_so_query(self, mock_falkordb): + """Test SO query (subject and object specified)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results with different predicates + mock_result1 = MagicMock() + mock_result1.result_set = [["http://example.com/pred1"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/pred2"]] + + mock_graph.query.side_effect = [mock_result1, mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + 
p=None, + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different predicates + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/pred1" + assert result[0].o.value == "literal object" + + assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/pred2" + assert result[1].o.value == "literal object" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_s_query(self, mock_falkordb): + """Test S query (subject only)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results with different predicate-object pairs + mock_result1 = MagicMock() + mock_result1.result_set = [["http://example.com/pred1", "literal1"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/pred2", "http://example.com/uri2"]] + + mock_graph.query.side_effect = [mock_result1, mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=None, + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different predicate-object pairs + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/pred1" + assert result[0].o.value == "literal1" + + 
assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/pred2" + assert result[1].o.value == "http://example.com/uri2" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_po_query(self, mock_falkordb): + """Test PO query (predicate and object specified)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results with different subjects + mock_result1 = MagicMock() + mock_result1.result_set = [["http://example.com/subj1"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/subj2"]] + + mock_graph.query.side_effect = [mock_result1, mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different subjects + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subj1" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal object" + + assert result[1].s.value == "http://example.com/subj2" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "literal object" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_p_query(self, mock_falkordb): + """Test P query (predicate only)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + 
mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results with different subject-object pairs + mock_result1 = MagicMock() + mock_result1.result_set = [["http://example.com/subj1", "literal1"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/subj2", "http://example.com/uri2"]] + + mock_graph.query.side_effect = [mock_result1, mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=Value(value="http://example.com/predicate", is_uri=True), + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different subject-object pairs + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subj1" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal1" + + assert result[1].s.value == "http://example.com/subj2" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "http://example.com/uri2" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_o_query(self, mock_falkordb): + """Test O query (object only)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results with different subject-predicate pairs + mock_result1 = MagicMock() + mock_result1.result_set = [["http://example.com/subj1", "http://example.com/pred1"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/subj2", "http://example.com/pred2"]] + + mock_graph.query.side_effect = [mock_result1, 
mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=None, + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different subject-predicate pairs + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subj1" + assert result[0].p.value == "http://example.com/pred1" + assert result[0].o.value == "literal object" + + assert result[1].s.value == "http://example.com/subj2" + assert result[1].p.value == "http://example.com/pred2" + assert result[1].o.value == "literal object" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_wildcard_query(self, mock_falkordb): + """Test wildcard query (no constraints)""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query results + mock_result1 = MagicMock() + mock_result1.result_set = [["http://example.com/s1", "http://example.com/p1", "literal1"]] + mock_result2 = MagicMock() + mock_result2.result_set = [["http://example.com/s2", "http://example.com/p2", "http://example.com/o2"]] + + mock_graph.query.side_effect = [mock_result1, mock_result2] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=None, + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_graph.query.call_count == 2 + + # Verify results contain different triples + assert len(result) == 2 + assert 
result[0].s.value == "http://example.com/s1" + assert result[0].p.value == "http://example.com/p1" + assert result[0].o.value == "literal1" + + assert result[1].s.value == "http://example.com/s2" + assert result[1].p.value == "http://example.com/p2" + assert result[1].o.value == "http://example.com/o2" + + @patch('trustgraph.query.triples.falkordb.service.FalkorDB') + @pytest.mark.asyncio + async def test_query_triples_exception_handling(self, mock_falkordb): + """Test exception handling during query processing""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + # Mock query to raise exception + mock_graph.query.side_effect = Exception("Database connection failed") + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=None, + o=None, + limit=100 + ) + + # Should raise the exception + with pytest.raises(Exception, match="Database connection failed"): + await processor.query_triples(query) + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_url') + assert args.graph_url == 'falkor://falkordb:6379' + assert hasattr(args, 'database') + assert args.database == 'falkordb' + + def 
test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--graph-url', 'falkor://custom:6379', + '--database', 'querydb' + ]) + + assert args.graph_url == 'falkor://custom:6379' + assert args.database == 'querydb' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-g', 'falkor://short:6379']) + + assert args.graph_url == 'falkor://short:6379' + + @patch('trustgraph.query.triples.falkordb.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.query.triples.falkordb.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nTriples query service for FalkorDB.\nInput is a (s, p, o) triple, some values may be null. 
Output is a list of\ntriples.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_query/test_triples_memgraph_query.py b/tests/unit/test_query/test_triples_memgraph_query.py new file mode 100644 index 00000000..bd394ae4 --- /dev/null +++ b/tests/unit/test_query/test_triples_memgraph_query.py @@ -0,0 +1,568 @@ +""" +Tests for Memgraph triples query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.triples.memgraph.service import Processor +from trustgraph.schema import Value, TriplesQueryRequest + + +class TestMemgraphQueryProcessor: + """Test cases for Memgraph query processor""" + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.triples.memgraph.service.GraphDatabase'): + return Processor( + taskgroup=MagicMock(), + id='test-memgraph-query', + graph_host='bolt://localhost:7687' + ) + + def test_create_value_with_http_uri(self, processor): + """Test create_value with HTTP URI""" + result = processor.create_value("http://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "http://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_https_uri(self, processor): + """Test create_value with HTTPS URI""" + result = processor.create_value("https://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "https://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_literal(self, processor): + """Test create_value with literal value""" + result = processor.create_value("just a literal string") + + assert isinstance(result, Value) + assert result.value == "just a literal string" + assert result.is_uri is False + + def test_create_value_with_empty_string(self, processor): + """Test create_value with empty string""" + result = processor.create_value("") + + assert isinstance(result, Value) + assert result.value == "" + assert 
result.is_uri is False + + def test_create_value_with_partial_uri(self, processor): + """Test create_value with string that looks like URI but isn't complete""" + result = processor.create_value("http") + + assert isinstance(result, Value) + assert result.value == "http" + assert result.is_uri is False + + def test_create_value_with_ftp_uri(self, processor): + """Test create_value with FTP URI (should not be detected as URI)""" + result = processor.create_value("ftp://example.com/file") + + assert isinstance(result, Value) + assert result.value == "ftp://example.com/file" + assert result.is_uri is False + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + def test_processor_initialization_with_defaults(self, mock_graph_db): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=taskgroup_mock) + + assert processor.db == 'memgraph' + mock_graph_db.driver.assert_called_once_with( + 'bolt://memgraph:7687', + auth=('memgraph', 'password') + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + def test_processor_initialization_with_custom_params(self, mock_graph_db): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor( + taskgroup=taskgroup_mock, + graph_host='bolt://custom:7687', + username='queryuser', + password='querypass', + database='customdb' + ) + + assert processor.db == 'customdb' + mock_graph_db.driver.assert_called_once_with( + 'bolt://custom:7687', + auth=('queryuser', 'querypass') + ) + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_spo_query(self, mock_graph_db): + """Test SPO query (all values specified)""" + taskgroup_mock = MagicMock() + mock_driver 
= MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results - both queries return one record each + mock_records = [MagicMock()] + mock_driver.execute_query.return_value = (mock_records, None, None) + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify result contains the queried triple (appears twice - once from each query) + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal object" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_sp_query(self, mock_graph_db): + """Test SP query (subject and predicate specified)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different objects + mock_record1 = MagicMock() + mock_record1.data.return_value = {"dest": "literal result"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"dest": "http://example.com/uri_result"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", 
is_uri=True), + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different objects + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal result" + + assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "http://example.com/uri_result" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_so_query(self, mock_graph_db): + """Test SO query (subject and object specified)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different predicates + mock_record1 = MagicMock() + mock_record1.data.return_value = {"rel": "http://example.com/pred1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"rel": "http://example.com/pred2"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=None, + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different predicates + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/pred1" + assert 
result[0].o.value == "literal object" + + assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/pred2" + assert result[1].o.value == "literal object" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_s_query(self, mock_graph_db): + """Test S query (subject only)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different predicate-object pairs + mock_record1 = MagicMock() + mock_record1.data.return_value = {"rel": "http://example.com/pred1", "dest": "literal1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"rel": "http://example.com/pred2", "dest": "http://example.com/uri2"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=None, + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different predicate-object pairs + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/pred1" + assert result[0].o.value == "literal1" + + assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/pred2" + assert result[1].o.value == "http://example.com/uri2" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_po_query(self, mock_graph_db): + """Test PO query (predicate and object 
specified)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different subjects + mock_record1 = MagicMock() + mock_record1.data.return_value = {"src": "http://example.com/subj1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"src": "http://example.com/subj2"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different subjects + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subj1" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal object" + + assert result[1].s.value == "http://example.com/subj2" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "literal object" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_p_query(self, mock_graph_db): + """Test P query (predicate only)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different subject-object pairs + mock_record1 = MagicMock() + mock_record1.data.return_value = {"src": "http://example.com/subj1", "dest": "literal1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"src": "http://example.com/subj2", "dest": "http://example.com/uri2"} + + 
mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=Value(value="http://example.com/predicate", is_uri=True), + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different subject-object pairs + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subj1" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal1" + + assert result[1].s.value == "http://example.com/subj2" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "http://example.com/uri2" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_o_query(self, mock_graph_db): + """Test O query (object only)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different subject-predicate pairs + mock_record1 = MagicMock() + mock_record1.data.return_value = {"src": "http://example.com/subj1", "rel": "http://example.com/pred1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"src": "http://example.com/subj2", "rel": "http://example.com/pred2"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=None, + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + 
result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different subject-predicate pairs + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subj1" + assert result[0].p.value == "http://example.com/pred1" + assert result[0].o.value == "literal object" + + assert result[1].s.value == "http://example.com/subj2" + assert result[1].p.value == "http://example.com/pred2" + assert result[1].o.value == "literal object" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_wildcard_query(self, mock_graph_db): + """Test wildcard query (no constraints)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results + mock_record1 = MagicMock() + mock_record1.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=None, + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different triples + assert len(result) == 2 + assert result[0].s.value == "http://example.com/s1" + assert result[0].p.value == "http://example.com/p1" + assert result[0].o.value == "literal1" + + assert result[1].s.value == 
"http://example.com/s2" + assert result[1].p.value == "http://example.com/p2" + assert result[1].o.value == "http://example.com/o2" + + @patch('trustgraph.query.triples.memgraph.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_exception_handling(self, mock_graph_db): + """Test exception handling during query processing""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock execute_query to raise exception + mock_driver.execute_query.side_effect = Exception("Database connection failed") + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=None, + o=None, + limit=100 + ) + + # Should raise the exception + with pytest.raises(Exception, match="Database connection failed"): + await processor.query_triples(query) + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_host') + assert args.graph_host == 'bolt://memgraph:7687' + assert hasattr(args, 'username') + assert args.username == 'memgraph' + assert hasattr(args, 'password') + assert args.password == 'password' + assert hasattr(args, 'database') + assert args.database == 'memgraph' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse 
import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--graph-host', 'bolt://custom:7687', + '--username', 'queryuser', + '--password', 'querypass', + '--database', 'querydb' + ]) + + assert args.graph_host == 'bolt://custom:7687' + assert args.username == 'queryuser' + assert args.password == 'querypass' + assert args.database == 'querydb' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-g', 'bolt://short:7687']) + + assert args.graph_host == 'bolt://short:7687' + + @patch('trustgraph.query.triples.memgraph.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.query.triples.memgraph.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nTriples query service for memgraph.\nInput is a (s, p, o) triple, some values may be null. 
Output is a list of\ntriples.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_query/test_triples_neo4j_query.py b/tests/unit/test_query/test_triples_neo4j_query.py new file mode 100644 index 00000000..320aed54 --- /dev/null +++ b/tests/unit/test_query/test_triples_neo4j_query.py @@ -0,0 +1,338 @@ +""" +Tests for Neo4j triples query service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.query.triples.neo4j.service import Processor +from trustgraph.schema import Value, TriplesQueryRequest + + +class TestNeo4jQueryProcessor: + """Test cases for Neo4j query processor""" + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.query.triples.neo4j.service.GraphDatabase'): + return Processor( + taskgroup=MagicMock(), + id='test-neo4j-query', + graph_host='bolt://localhost:7687' + ) + + def test_create_value_with_http_uri(self, processor): + """Test create_value with HTTP URI""" + result = processor.create_value("http://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "http://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_https_uri(self, processor): + """Test create_value with HTTPS URI""" + result = processor.create_value("https://example.com/resource") + + assert isinstance(result, Value) + assert result.value == "https://example.com/resource" + assert result.is_uri is True + + def test_create_value_with_literal(self, processor): + """Test create_value with literal value""" + result = processor.create_value("just a literal string") + + assert isinstance(result, Value) + assert result.value == "just a literal string" + assert result.is_uri is False + + def test_create_value_with_empty_string(self, processor): + """Test create_value with empty string""" + result = processor.create_value("") + + assert isinstance(result, Value) + assert result.value == "" + assert result.is_uri is False + + 
def test_create_value_with_partial_uri(self, processor): + """Test create_value with string that looks like URI but isn't complete""" + result = processor.create_value("http") + + assert isinstance(result, Value) + assert result.value == "http" + assert result.is_uri is False + + def test_create_value_with_ftp_uri(self, processor): + """Test create_value with FTP URI (should not be detected as URI)""" + result = processor.create_value("ftp://example.com/file") + + assert isinstance(result, Value) + assert result.value == "ftp://example.com/file" + assert result.is_uri is False + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + def test_processor_initialization_with_defaults(self, mock_graph_db): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor(taskgroup=taskgroup_mock) + + assert processor.db == 'neo4j' + mock_graph_db.driver.assert_called_once_with( + 'bolt://neo4j:7687', + auth=('neo4j', 'password') + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + def test_processor_initialization_with_custom_params(self, mock_graph_db): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + processor = Processor( + taskgroup=taskgroup_mock, + graph_host='bolt://custom:7687', + username='queryuser', + password='querypass', + database='customdb' + ) + + assert processor.db == 'customdb' + mock_graph_db.driver.assert_called_once_with( + 'bolt://custom:7687', + auth=('queryuser', 'querypass') + ) + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_spo_query(self, mock_graph_db): + """Test SPO query (all values specified)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + 
mock_graph_db.driver.return_value = mock_driver + + # Mock query results - both queries return one record each + mock_records = [MagicMock()] + mock_driver.execute_query.return_value = (mock_records, None, None) + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=Value(value="literal object", is_uri=False), + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify result contains the queried triple (appears twice - once from each query) + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal object" + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_sp_query(self, mock_graph_db): + """Test SP query (subject and predicate specified)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results with different objects + mock_record1 = MagicMock() + mock_record1.data.return_value = {"dest": "literal result"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"dest": "http://example.com/uri_result"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=Value(value="http://example.com/predicate", is_uri=True), + o=None, + 
limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different objects + assert len(result) == 2 + assert result[0].s.value == "http://example.com/subject" + assert result[0].p.value == "http://example.com/predicate" + assert result[0].o.value == "literal result" + + assert result[1].s.value == "http://example.com/subject" + assert result[1].p.value == "http://example.com/predicate" + assert result[1].o.value == "http://example.com/uri_result" + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_wildcard_query(self, mock_graph_db): + """Test wildcard query (no constraints)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock query results + mock_record1 = MagicMock() + mock_record1.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"} + mock_record2 = MagicMock() + mock_record2.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"} + + mock_driver.execute_query.side_effect = [ + ([mock_record1], None, None), # Literal query + ([mock_record2], None, None) # URI query + ] + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=None, + p=None, + o=None, + limit=100 + ) + + result = await processor.query_triples(query) + + # Verify both literal and URI queries were executed + assert mock_driver.execute_query.call_count == 2 + + # Verify results contain different triples + assert len(result) == 2 + assert result[0].s.value == "http://example.com/s1" + assert result[0].p.value == "http://example.com/p1" + assert result[0].o.value == "literal1" + + assert result[1].s.value 
== "http://example.com/s2" + assert result[1].p.value == "http://example.com/p2" + assert result[1].o.value == "http://example.com/o2" + + @patch('trustgraph.query.triples.neo4j.service.GraphDatabase') + @pytest.mark.asyncio + async def test_query_triples_exception_handling(self, mock_graph_db): + """Test exception handling during query processing""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + + # Mock execute_query to raise exception + mock_driver.execute_query.side_effect = Exception("Database connection failed") + + processor = Processor(taskgroup=taskgroup_mock) + + # Create query request + query = TriplesQueryRequest( + user='test_user', + collection='test_collection', + s=Value(value="http://example.com/subject", is_uri=True), + p=None, + o=None, + limit=100 + ) + + # Should raise the exception + with pytest.raises(Exception, match="Database connection failed"): + await processor.query_triples(query) + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_host') + assert args.graph_host == 'bolt://neo4j:7687' + assert hasattr(args, 'username') + assert args.username == 'neo4j' + assert hasattr(args, 'password') + assert args.password == 'password' + assert hasattr(args, 'database') + assert args.database == 'neo4j' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import 
ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--graph-host', 'bolt://custom:7687', + '--username', 'queryuser', + '--password', 'querypass', + '--database', 'querydb' + ]) + + assert args.graph_host == 'bolt://custom:7687' + assert args.username == 'queryuser' + assert args.password == 'querypass' + assert args.database == 'querydb' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-g', 'bolt://short:7687']) + + assert args.graph_host == 'bolt://short:7687' + + @patch('trustgraph.query.triples.neo4j.service.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.query.triples.neo4j.service import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nTriples query service for neo4j.\nInput is a (s, p, o) triple, some values may be null. 
Output is a list of\ntriples.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py new file mode 100644 index 00000000..5e6bcfb9 --- /dev/null +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -0,0 +1,387 @@ +""" +Tests for Milvus document embeddings storage service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.storage.doc_embeddings.milvus.write import Processor +from trustgraph.schema import ChunkEmbeddings + + +class TestMilvusDocEmbeddingsStorageProcessor: + """Test cases for Milvus document embeddings storage processor""" + + @pytest.fixture + def mock_message(self): + """Create a mock message for testing""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create test document embeddings + chunk1 = ChunkEmbeddings( + chunk=b"This is the first document chunk", + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + chunk2 = ChunkEmbeddings( + chunk=b"This is the second document chunk", + vectors=[[0.7, 0.8, 0.9]] + ) + message.chunks = [chunk1, chunk2] + + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') as mock_doc_vectors: + mock_vecstore = MagicMock() + mock_doc_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=MagicMock(), + id='test-milvus-de-storage', + store_uri='http://localhost:19530' + ) + + return processor + + @patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') + def test_processor_initialization_with_defaults(self, mock_doc_vectors): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_doc_vectors.return_value = mock_vecstore + + 
processor = Processor(taskgroup=taskgroup_mock) + + mock_doc_vectors.assert_called_once_with('http://localhost:19530') + assert processor.vecstore == mock_vecstore + + @patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') + def test_processor_initialization_with_custom_params(self, mock_doc_vectors): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_doc_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=taskgroup_mock, + store_uri='http://custom-milvus:19530' + ) + + mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530') + assert processor.vecstore == mock_vecstore + + @pytest.mark.asyncio + async def test_store_document_embeddings_single_chunk(self, processor): + """Test storing document embeddings for a single chunk""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Test document content", + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify insert was called for each vector + expected_calls = [ + ([0.1, 0.2, 0.3], "Test document content"), + ([0.4, 0.5, 0.6], "Test document content"), + ] + + assert processor.vecstore.insert.call_count == 2 + for i, (expected_vec, expected_doc) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_doc + + @pytest.mark.asyncio + async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): + """Test storing document embeddings for multiple chunks""" + await processor.store_document_embeddings(mock_message) + + # Verify insert was called for each vector of each chunk + expected_calls = [ + # Chunk 1 vectors + ([0.1, 0.2, 0.3], "This 
is the first document chunk"), + ([0.4, 0.5, 0.6], "This is the first document chunk"), + # Chunk 2 vectors + ([0.7, 0.8, 0.9], "This is the second document chunk"), + ] + + assert processor.vecstore.insert.call_count == 3 + for i, (expected_vec, expected_doc) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_doc + + @pytest.mark.asyncio + async def test_store_document_embeddings_empty_chunk(self, processor): + """Test storing document embeddings with empty chunk (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify no insert was called for empty chunk + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_none_chunk(self, processor): + """Test storing document embeddings with None chunk (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=None, + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify no insert was called for None chunk + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor): + """Test storing document embeddings with mix of valid and invalid chunks""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + valid_chunk = ChunkEmbeddings( + chunk=b"Valid document content", + 
vectors=[[0.1, 0.2, 0.3]] + ) + empty_chunk = ChunkEmbeddings( + chunk=b"", + vectors=[[0.4, 0.5, 0.6]] + ) + none_chunk = ChunkEmbeddings( + chunk=None, + vectors=[[0.7, 0.8, 0.9]] + ) + message.chunks = [valid_chunk, empty_chunk, none_chunk] + + await processor.store_document_embeddings(message) + + # Verify only valid chunk was inserted + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], "Valid document content" + ) + + @pytest.mark.asyncio + async def test_store_document_embeddings_empty_chunks_list(self, processor): + """Test storing document embeddings with empty chunks list""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + message.chunks = [] + + await processor.store_document_embeddings(message) + + # Verify no insert was called + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_chunk_with_no_vectors(self, processor): + """Test storing document embeddings for chunk with no vectors""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Document with no vectors", + vectors=[] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify no insert was called (no vectors to insert) + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_different_vector_dimensions(self, processor): + """Test storing document embeddings with different vector dimensions""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Document with mixed dimensions", + vectors=[ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6], # 4D vector + [0.7, 0.8, 
0.9] # 3D vector + ] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify all vectors were inserted regardless of dimension + expected_calls = [ + ([0.1, 0.2], "Document with mixed dimensions"), + ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"), + ([0.7, 0.8, 0.9], "Document with mixed dimensions"), + ] + + assert processor.vecstore.insert.call_count == 3 + for i, (expected_vec, expected_doc) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_doc + + @pytest.mark.asyncio + async def test_store_document_embeddings_unicode_content(self, processor): + """Test storing document embeddings with Unicode content""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify Unicode content was properly decoded and inserted + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀" + ) + + @pytest.mark.asyncio + async def test_store_document_embeddings_large_chunks(self, processor): + """Test storing document embeddings with large document chunks""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create a large document chunk + large_content = "A" * 10000 # 10KB of content + chunk = ChunkEmbeddings( + chunk=large_content.encode('utf-8'), + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify large content was inserted + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], 
large_content + ) + + @pytest.mark.asyncio + async def test_store_document_embeddings_whitespace_only_chunk(self, processor): + """Test storing document embeddings with whitespace-only chunk""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b" \n\t ", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify whitespace content was inserted (not filtered out) + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], " \n\t " + ) + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'store_uri') + assert args.store_uri == 'http://localhost:19530' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--store-uri', 'http://custom-milvus:19530' + ]) + + assert args.store_uri == 'http://custom-milvus:19530' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from 
unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-t', 'http://short-milvus:19530']) + + assert args.store_uri == 'http://short-milvus:19530' + + @patch('trustgraph.storage.doc_embeddings.milvus.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.doc_embeddings.milvus.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nAccepts entity/vector pairs and writes them to a Milvus store.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py new file mode 100644 index 00000000..6c4ddb6b --- /dev/null +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -0,0 +1,536 @@ +""" +Tests for Pinecone document embeddings storage service +""" + +import pytest +from unittest.mock import MagicMock, patch +import uuid + +from trustgraph.storage.doc_embeddings.pinecone.write import Processor +from trustgraph.schema import ChunkEmbeddings + + +class TestPineconeDocEmbeddingsStorageProcessor: + """Test cases for Pinecone document embeddings storage processor""" + + @pytest.fixture + def mock_message(self): + """Create a mock message for testing""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create test document embeddings + chunk1 = ChunkEmbeddings( + chunk=b"This is the first document chunk", + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + chunk2 = ChunkEmbeddings( + chunk=b"This is the second document chunk", + vectors=[[0.7, 0.8, 0.9]] + ) + message.chunks = 
[chunk1, chunk2] + + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') as mock_pinecone_class: + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + + processor = Processor( + taskgroup=MagicMock(), + id='test-pinecone-de-storage', + api_key='test-api-key' + ) + + return processor + + @patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') + @patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'env-api-key') + def test_processor_initialization_with_defaults(self, mock_pinecone_class): + """Test processor initialization with default parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor(taskgroup=taskgroup_mock) + + mock_pinecone_class.assert_called_once_with(api_key='env-api-key') + assert processor.pinecone == mock_pinecone + assert processor.api_key == 'env-api-key' + assert processor.cloud == 'aws' + assert processor.region == 'us-east-1' + + @patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') + def test_processor_initialization_with_custom_params(self, mock_pinecone_class): + """Test processor initialization with custom parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( + taskgroup=taskgroup_mock, + api_key='custom-api-key', + cloud='gcp', + region='us-west1' + ) + + mock_pinecone_class.assert_called_once_with(api_key='custom-api-key') + assert processor.api_key == 'custom-api-key' + assert processor.cloud == 'gcp' + assert processor.region == 'us-west1' + + @patch('trustgraph.storage.doc_embeddings.pinecone.write.PineconeGRPC') + def test_processor_initialization_with_url(self, mock_pinecone_grpc_class): + """Test processor initialization with custom URL (GRPC 
mode)""" + mock_pinecone = MagicMock() + mock_pinecone_grpc_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( + taskgroup=taskgroup_mock, + api_key='test-api-key', + url='https://custom-host.pinecone.io' + ) + + mock_pinecone_grpc_class.assert_called_once_with( + api_key='test-api-key', + host='https://custom-host.pinecone.io' + ) + assert processor.pinecone == mock_pinecone + assert processor.url == 'https://custom-host.pinecone.io' + + @patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'not-specified') + def test_processor_initialization_missing_api_key(self): + """Test processor initialization fails with missing API key""" + taskgroup_mock = MagicMock() + + with pytest.raises(RuntimeError, match="Pinecone API key must be specified"): + Processor(taskgroup=taskgroup_mock) + + @pytest.mark.asyncio + async def test_store_document_embeddings_single_chunk(self, processor): + """Test storing document embeddings for a single chunk""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Test document content", + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + message.chunks = [chunk] + + # Mock index operations + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', side_effect=['id1', 'id2']): + await processor.store_document_embeddings(message) + + # Verify index name and operations + expected_index_name = "d-test_user-test_collection-3" + processor.pinecone.Index.assert_called_with(expected_index_name) + + # Verify upsert was called for each vector + assert mock_index.upsert.call_count == 2 + + # Check first vector upsert + first_call = mock_index.upsert.call_args_list[0] + first_vectors = first_call[1]['vectors'] + assert len(first_vectors) == 1 + assert 
first_vectors[0]['id'] == 'id1' + assert first_vectors[0]['values'] == [0.1, 0.2, 0.3] + assert first_vectors[0]['metadata']['doc'] == "Test document content" + + # Check second vector upsert + second_call = mock_index.upsert.call_args_list[1] + second_vectors = second_call[1]['vectors'] + assert len(second_vectors) == 1 + assert second_vectors[0]['id'] == 'id2' + assert second_vectors[0]['values'] == [0.4, 0.5, 0.6] + assert second_vectors[0]['metadata']['doc'] == "Test document content" + + @pytest.mark.asyncio + async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): + """Test storing document embeddings for multiple chunks""" + # Mock index operations + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): + await processor.store_document_embeddings(mock_message) + + # Verify upsert was called for each vector (3 total) + assert mock_index.upsert.call_count == 3 + + # Verify document content in metadata + calls = mock_index.upsert.call_args_list + assert calls[0][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk" + assert calls[1][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk" + assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk" + + @pytest.mark.asyncio + async def test_store_document_embeddings_index_creation(self, processor): + """Test automatic index creation when index doesn't exist""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Test document content", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + # Mock index doesn't exist initially + processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + 
processor.pinecone.Index.return_value = mock_index + + # Mock index creation + processor.pinecone.describe_index.return_value.status = {"ready": True} + + with patch('uuid.uuid4', return_value='test-id'): + await processor.store_document_embeddings(message) + + # Verify index creation was called + expected_index_name = "d-test_user-test_collection-3" + processor.pinecone.create_index.assert_called_once() + create_call = processor.pinecone.create_index.call_args + assert create_call[1]['name'] == expected_index_name + assert create_call[1]['dimension'] == 3 + assert create_call[1]['metric'] == "cosine" + + @pytest.mark.asyncio + async def test_store_document_embeddings_empty_chunk(self, processor): + """Test storing document embeddings with empty chunk (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_document_embeddings(message) + + # Verify no upsert was called for empty chunk + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_none_chunk(self, processor): + """Test storing document embeddings with None chunk (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=None, + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_document_embeddings(message) + + # Verify no upsert was called for None chunk + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_empty_decoded_chunk(self, 
processor): + """Test storing document embeddings with chunk that decodes to empty string""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"", # Empty bytes + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_document_embeddings(message) + + # Verify no upsert was called for empty decoded chunk + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_different_vector_dimensions(self, processor): + """Test storing document embeddings with different vector dimensions""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Document with mixed dimensions", + vectors=[ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6], # 4D vector + [0.7, 0.8, 0.9] # 3D vector + ] + ) + message.chunks = [chunk] + + mock_index_2d = MagicMock() + mock_index_4d = MagicMock() + mock_index_3d = MagicMock() + + def mock_index_side_effect(name): + if name.endswith("-2"): + return mock_index_2d + elif name.endswith("-4"): + return mock_index_4d + elif name.endswith("-3"): + return mock_index_3d + + processor.pinecone.Index.side_effect = mock_index_side_effect + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): + await processor.store_document_embeddings(message) + + # Verify different indexes were used for different dimensions + assert processor.pinecone.Index.call_count == 3 + mock_index_2d.upsert.assert_called_once() + mock_index_4d.upsert.assert_called_once() + mock_index_3d.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_store_document_embeddings_empty_chunks_list(self, processor): 
+ """Test storing document embeddings with empty chunks list""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + message.chunks = [] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_document_embeddings(message) + + # Verify no operations were performed + processor.pinecone.Index.assert_not_called() + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_chunk_with_no_vectors(self, processor): + """Test storing document embeddings for chunk with no vectors""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Document with no vectors", + vectors=[] + ) + message.chunks = [chunk] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_document_embeddings(message) + + # Verify no upsert was called (no vectors to insert) + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_document_embeddings_index_creation_failure(self, processor): + """Test handling of index creation failure""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Test document content", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + # Mock index doesn't exist and creation fails + processor.pinecone.has_index.return_value = False + processor.pinecone.create_index.side_effect = Exception("Index creation failed") + + with pytest.raises(Exception, match="Index creation failed"): + await processor.store_document_embeddings(message) + + @pytest.mark.asyncio + async def test_store_document_embeddings_index_creation_timeout(self, processor): + 
"""Test handling of index creation timeout""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk=b"Test document content", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + # Mock index doesn't exist and never becomes ready + processor.pinecone.has_index.return_value = False + processor.pinecone.describe_index.return_value.status = {"ready": False} + + with patch('time.sleep'): # Speed up the test + with pytest.raises(RuntimeError, match="Gave up waiting for index creation"): + await processor.store_document_embeddings(message) + + @pytest.mark.asyncio + async def test_store_document_embeddings_unicode_content(self, processor): + """Test storing document embeddings with Unicode content""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + chunk = ChunkEmbeddings( + chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', return_value='test-id'): + await processor.store_document_embeddings(message) + + # Verify Unicode content was properly decoded and stored + call_args = mock_index.upsert.call_args + stored_doc = call_args[1]['vectors'][0]['metadata']['doc'] + assert stored_doc == "Document with Unicode: éñ中文🚀" + + @pytest.mark.asyncio + async def test_store_document_embeddings_large_chunks(self, processor): + """Test storing document embeddings with large document chunks""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create a large document chunk + large_content = "A" * 10000 # 10KB of content + chunk = 
ChunkEmbeddings( + chunk=large_content.encode('utf-8'), + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', return_value='test-id'): + await processor.store_document_embeddings(message) + + # Verify large content was stored + call_args = mock_index.upsert.call_args + stored_doc = call_args[1]['vectors'][0]['metadata']['doc'] + assert stored_doc == large_content + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + args = parser.parse_args([]) + + assert hasattr(args, 'api_key') + assert args.api_key == 'not-specified' # Default value when no env var + assert hasattr(args, 'url') + assert args.url is None + assert hasattr(args, 'cloud') + assert args.cloud == 'aws' + assert hasattr(args, 'region') + assert args.region == 'us-east-1' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--api-key', 'custom-api-key', + '--url', 'https://custom-host.pinecone.io', + '--cloud', 'gcp', + '--region', 'us-west1' + ]) + + assert args.api_key == 'custom-api-key' + assert args.url == 
'https://custom-host.pinecone.io' + assert args.cloud == 'gcp' + assert args.region == 'us-west1' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args([ + '-a', 'short-api-key', + '-u', 'https://short-host.pinecone.io' + ]) + + assert args.api_key == 'short-api-key' + assert args.url == 'https://short-host.pinecone.io' + + @patch('trustgraph.storage.doc_embeddings.pinecone.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.doc_embeddings.pinecone.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nAccepts document chunks/vector pairs and writes them to a Pinecone store.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py new file mode 100644 index 00000000..ae300574 --- /dev/null +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -0,0 +1,354 @@ +""" +Tests for Milvus graph embeddings storage service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.storage.graph_embeddings.milvus.write import Processor +from trustgraph.schema import Value, EntityEmbeddings + + +class TestMilvusGraphEmbeddingsStorageProcessor: + """Test cases for Milvus graph embeddings storage processor""" + + @pytest.fixture + def mock_message(self): + """Create a mock message for testing""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 
'test_collection' + + # Create test entities with embeddings + entity1 = EntityEmbeddings( + entity=Value(value='http://example.com/entity1', is_uri=True), + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + entity2 = EntityEmbeddings( + entity=Value(value='literal entity', is_uri=False), + vectors=[[0.7, 0.8, 0.9]] + ) + message.entities = [entity1, entity2] + + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') as mock_entity_vectors: + mock_vecstore = MagicMock() + mock_entity_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=MagicMock(), + id='test-milvus-ge-storage', + store_uri='http://localhost:19530' + ) + + return processor + + @patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') + def test_processor_initialization_with_defaults(self, mock_entity_vectors): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_entity_vectors.return_value = mock_vecstore + + processor = Processor(taskgroup=taskgroup_mock) + + mock_entity_vectors.assert_called_once_with('http://localhost:19530') + assert processor.vecstore == mock_vecstore + + @patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') + def test_processor_initialization_with_custom_params(self, mock_entity_vectors): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_vecstore = MagicMock() + mock_entity_vectors.return_value = mock_vecstore + + processor = Processor( + taskgroup=taskgroup_mock, + store_uri='http://custom-milvus:19530' + ) + + mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530') + assert processor.vecstore == mock_vecstore + + @pytest.mark.asyncio + async def test_store_graph_embeddings_single_entity(self, processor): + """Test storing graph embeddings for a 
single entity""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value='http://example.com/entity', is_uri=True), + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + message.entities = [entity] + + await processor.store_graph_embeddings(message) + + # Verify insert was called for each vector + expected_calls = [ + ([0.1, 0.2, 0.3], 'http://example.com/entity'), + ([0.4, 0.5, 0.6], 'http://example.com/entity'), + ] + + assert processor.vecstore.insert.call_count == 2 + for i, (expected_vec, expected_entity) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_entity + + @pytest.mark.asyncio + async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): + """Test storing graph embeddings for multiple entities""" + await processor.store_graph_embeddings(mock_message) + + # Verify insert was called for each vector of each entity + expected_calls = [ + # Entity 1 vectors + ([0.1, 0.2, 0.3], 'http://example.com/entity1'), + ([0.4, 0.5, 0.6], 'http://example.com/entity1'), + # Entity 2 vectors + ([0.7, 0.8, 0.9], 'literal entity'), + ] + + assert processor.vecstore.insert.call_count == 3 + for i, (expected_vec, expected_entity) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_entity + + @pytest.mark.asyncio + async def test_store_graph_embeddings_empty_entity_value(self, processor): + """Test storing graph embeddings with empty entity value (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value='', 
is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + await processor.store_graph_embeddings(message) + + # Verify no insert was called for empty entity + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_none_entity_value(self, processor): + """Test storing graph embeddings with None entity value (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value=None, is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + await processor.store_graph_embeddings(message) + + # Verify no insert was called for None entity + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_mixed_valid_invalid_entities(self, processor): + """Test storing graph embeddings with mix of valid and invalid entities""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + valid_entity = EntityEmbeddings( + entity=Value(value='http://example.com/valid', is_uri=True), + vectors=[[0.1, 0.2, 0.3]] + ) + empty_entity = EntityEmbeddings( + entity=Value(value='', is_uri=False), + vectors=[[0.4, 0.5, 0.6]] + ) + none_entity = EntityEmbeddings( + entity=Value(value=None, is_uri=False), + vectors=[[0.7, 0.8, 0.9]] + ) + message.entities = [valid_entity, empty_entity, none_entity] + + await processor.store_graph_embeddings(message) + + # Verify only valid entity was inserted + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], 'http://example.com/valid' + ) + + @pytest.mark.asyncio + async def test_store_graph_embeddings_empty_entities_list(self, processor): + """Test storing graph embeddings with empty entities list""" + message = MagicMock() + message.metadata = 
MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + message.entities = [] + + await processor.store_graph_embeddings(message) + + # Verify no insert was called + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_entity_with_no_vectors(self, processor): + """Test storing graph embeddings for entity with no vectors""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value='http://example.com/entity', is_uri=True), + vectors=[] + ) + message.entities = [entity] + + await processor.store_graph_embeddings(message) + + # Verify no insert was called (no vectors to insert) + processor.vecstore.insert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_different_vector_dimensions(self, processor): + """Test storing graph embeddings with different vector dimensions""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value='http://example.com/entity', is_uri=True), + vectors=[ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6], # 4D vector + [0.7, 0.8, 0.9] # 3D vector + ] + ) + message.entities = [entity] + + await processor.store_graph_embeddings(message) + + # Verify all vectors were inserted regardless of dimension + expected_calls = [ + ([0.1, 0.2], 'http://example.com/entity'), + ([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'), + ([0.7, 0.8, 0.9], 'http://example.com/entity'), + ] + + assert processor.vecstore.insert.call_count == 3 + for i, (expected_vec, expected_entity) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_entity 
+ + @pytest.mark.asyncio + async def test_store_graph_embeddings_uri_and_literal_entities(self, processor): + """Test storing graph embeddings for both URI and literal entities""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + uri_entity = EntityEmbeddings( + entity=Value(value='http://example.com/uri_entity', is_uri=True), + vectors=[[0.1, 0.2, 0.3]] + ) + literal_entity = EntityEmbeddings( + entity=Value(value='literal entity text', is_uri=False), + vectors=[[0.4, 0.5, 0.6]] + ) + message.entities = [uri_entity, literal_entity] + + await processor.store_graph_embeddings(message) + + # Verify both entities were inserted + expected_calls = [ + ([0.1, 0.2, 0.3], 'http://example.com/uri_entity'), + ([0.4, 0.5, 0.6], 'literal entity text'), + ] + + assert processor.vecstore.insert.call_count == 2 + for i, (expected_vec, expected_entity) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_entity + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'store_uri') + assert args.store_uri == 'http://localhost:19530' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + 
parser = ArgumentParser() + + with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--store-uri', 'http://custom-milvus:19530' + ]) + + assert args.store_uri == 'http://custom-milvus:19530' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-t', 'http://short-milvus:19530']) + + assert args.store_uri == 'http://short-milvus:19530' + + @patch('trustgraph.storage.graph_embeddings.milvus.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.graph_embeddings.milvus.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nAccepts entity/vector pairs and writes them to a Milvus store.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py new file mode 100644 index 00000000..91e60057 --- /dev/null +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -0,0 +1,460 @@ +""" +Tests for Pinecone graph embeddings storage service +""" + +import pytest +from unittest.mock import MagicMock, patch +import uuid + +from trustgraph.storage.graph_embeddings.pinecone.write import Processor +from trustgraph.schema import EntityEmbeddings, Value + + +class TestPineconeGraphEmbeddingsStorageProcessor: + """Test cases for Pinecone graph embeddings storage processor""" + + @pytest.fixture + def 
mock_message(self): + """Create a mock message for testing""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create test entity embeddings + entity1 = EntityEmbeddings( + entity=Value(value="http://example.org/entity1", is_uri=True), + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + entity2 = EntityEmbeddings( + entity=Value(value="entity2", is_uri=False), + vectors=[[0.7, 0.8, 0.9]] + ) + message.entities = [entity1, entity2] + + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') as mock_pinecone_class: + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + + processor = Processor( + taskgroup=MagicMock(), + id='test-pinecone-ge-storage', + api_key='test-api-key' + ) + + return processor + + @patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') + @patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'env-api-key') + def test_processor_initialization_with_defaults(self, mock_pinecone_class): + """Test processor initialization with default parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor(taskgroup=taskgroup_mock) + + mock_pinecone_class.assert_called_once_with(api_key='env-api-key') + assert processor.pinecone == mock_pinecone + assert processor.api_key == 'env-api-key' + assert processor.cloud == 'aws' + assert processor.region == 'us-east-1' + + @patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') + def test_processor_initialization_with_custom_params(self, mock_pinecone_class): + """Test processor initialization with custom parameters""" + mock_pinecone = MagicMock() + mock_pinecone_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() 
+ + processor = Processor( + taskgroup=taskgroup_mock, + api_key='custom-api-key', + cloud='gcp', + region='us-west1' + ) + + mock_pinecone_class.assert_called_once_with(api_key='custom-api-key') + assert processor.api_key == 'custom-api-key' + assert processor.cloud == 'gcp' + assert processor.region == 'us-west1' + + @patch('trustgraph.storage.graph_embeddings.pinecone.write.PineconeGRPC') + def test_processor_initialization_with_url(self, mock_pinecone_grpc_class): + """Test processor initialization with custom URL (GRPC mode)""" + mock_pinecone = MagicMock() + mock_pinecone_grpc_class.return_value = mock_pinecone + taskgroup_mock = MagicMock() + + processor = Processor( + taskgroup=taskgroup_mock, + api_key='test-api-key', + url='https://custom-host.pinecone.io' + ) + + mock_pinecone_grpc_class.assert_called_once_with( + api_key='test-api-key', + host='https://custom-host.pinecone.io' + ) + assert processor.pinecone == mock_pinecone + assert processor.url == 'https://custom-host.pinecone.io' + + @patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'not-specified') + def test_processor_initialization_missing_api_key(self): + """Test processor initialization fails with missing API key""" + taskgroup_mock = MagicMock() + + with pytest.raises(RuntimeError, match="Pinecone API key must be specified"): + Processor(taskgroup=taskgroup_mock) + + @pytest.mark.asyncio + async def test_store_graph_embeddings_single_entity(self, processor): + """Test storing graph embeddings for a single entity""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="http://example.org/entity1", is_uri=True), + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ) + message.entities = [entity] + + # Mock index operations + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + 
processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', side_effect=['id1', 'id2']): + await processor.store_graph_embeddings(message) + + # Verify index name and operations + expected_index_name = "t-test_user-test_collection-3" + processor.pinecone.Index.assert_called_with(expected_index_name) + + # Verify upsert was called for each vector + assert mock_index.upsert.call_count == 2 + + # Check first vector upsert + first_call = mock_index.upsert.call_args_list[0] + first_vectors = first_call[1]['vectors'] + assert len(first_vectors) == 1 + assert first_vectors[0]['id'] == 'id1' + assert first_vectors[0]['values'] == [0.1, 0.2, 0.3] + assert first_vectors[0]['metadata']['entity'] == "http://example.org/entity1" + + # Check second vector upsert + second_call = mock_index.upsert.call_args_list[1] + second_vectors = second_call[1]['vectors'] + assert len(second_vectors) == 1 + assert second_vectors[0]['id'] == 'id2' + assert second_vectors[0]['values'] == [0.4, 0.5, 0.6] + assert second_vectors[0]['metadata']['entity'] == "http://example.org/entity1" + + @pytest.mark.asyncio + async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): + """Test storing graph embeddings for multiple entities""" + # Mock index operations + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): + await processor.store_graph_embeddings(mock_message) + + # Verify upsert was called for each vector (3 total) + assert mock_index.upsert.call_count == 3 + + # Verify entity values in metadata + calls = mock_index.upsert.call_args_list + assert calls[0][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1" + assert calls[1][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1" + assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2" + + @pytest.mark.asyncio + 
async def test_store_graph_embeddings_index_creation(self, processor): + """Test automatic index creation when index doesn't exist""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="test_entity", is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + # Mock index doesn't exist initially + processor.pinecone.has_index.return_value = False + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # Mock index creation + processor.pinecone.describe_index.return_value.status = {"ready": True} + + with patch('uuid.uuid4', return_value='test-id'): + await processor.store_graph_embeddings(message) + + # Verify index creation was called + expected_index_name = "t-test_user-test_collection-3" + processor.pinecone.create_index.assert_called_once() + create_call = processor.pinecone.create_index.call_args + assert create_call[1]['name'] == expected_index_name + assert create_call[1]['dimension'] == 3 + assert create_call[1]['metric'] == "cosine" + + @pytest.mark.asyncio + async def test_store_graph_embeddings_empty_entity_value(self, processor): + """Test storing graph embeddings with empty entity value (should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="", is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_graph_embeddings(message) + + # Verify no upsert was called for empty entity + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_none_entity_value(self, processor): + """Test storing graph embeddings with None entity value 
(should be skipped)""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value=None, is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_graph_embeddings(message) + + # Verify no upsert was called for None entity + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_different_vector_dimensions(self, processor): + """Test storing graph embeddings with different vector dimensions""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="test_entity", is_uri=False), + vectors=[ + [0.1, 0.2], # 2D vector + [0.3, 0.4, 0.5, 0.6], # 4D vector + [0.7, 0.8, 0.9] # 3D vector + ] + ) + message.entities = [entity] + + mock_index_2d = MagicMock() + mock_index_4d = MagicMock() + mock_index_3d = MagicMock() + + def mock_index_side_effect(name): + if name.endswith("-2"): + return mock_index_2d + elif name.endswith("-4"): + return mock_index_4d + elif name.endswith("-3"): + return mock_index_3d + + processor.pinecone.Index.side_effect = mock_index_side_effect + processor.pinecone.has_index.return_value = True + + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): + await processor.store_graph_embeddings(message) + + # Verify different indexes were used for different dimensions + assert processor.pinecone.Index.call_count == 3 + mock_index_2d.upsert.assert_called_once() + mock_index_4d.upsert.assert_called_once() + mock_index_3d.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_empty_entities_list(self, processor): + """Test storing graph embeddings with empty entities 
list""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + message.entities = [] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_graph_embeddings(message) + + # Verify no operations were performed + processor.pinecone.Index.assert_not_called() + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_entity_with_no_vectors(self, processor): + """Test storing graph embeddings for entity with no vectors""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="test_entity", is_uri=False), + vectors=[] + ) + message.entities = [entity] + + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + await processor.store_graph_embeddings(message) + + # Verify no upsert was called (no vectors to insert) + mock_index.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_store_graph_embeddings_index_creation_failure(self, processor): + """Test handling of index creation failure""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="test_entity", is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + # Mock index doesn't exist and creation fails + processor.pinecone.has_index.return_value = False + processor.pinecone.create_index.side_effect = Exception("Index creation failed") + + with pytest.raises(Exception, match="Index creation failed"): + await processor.store_graph_embeddings(message) + + @pytest.mark.asyncio + async def test_store_graph_embeddings_index_creation_timeout(self, processor): + """Test handling of index 
creation timeout""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + entity = EntityEmbeddings( + entity=Value(value="test_entity", is_uri=False), + vectors=[[0.1, 0.2, 0.3]] + ) + message.entities = [entity] + + # Mock index doesn't exist and never becomes ready + processor.pinecone.has_index.return_value = False + processor.pinecone.describe_index.return_value.status = {"ready": False} + + with patch('time.sleep'): # Speed up the test + with pytest.raises(RuntimeError, match="Gave up waiting for index creation"): + await processor.store_graph_embeddings(message) + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added by parsing empty args + args = parser.parse_args([]) + + assert hasattr(args, 'api_key') + assert args.api_key == 'not-specified' # Default value when no env var + assert hasattr(args, 'url') + assert args.url is None + assert hasattr(args, 'cloud') + assert args.cloud == 'aws' + assert hasattr(args, 'region') + assert args.region == 'us-east-1' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--api-key', 'custom-api-key', + 
'--url', 'https://custom-host.pinecone.io', + '--cloud', 'gcp', + '--region', 'us-west1' + ]) + + assert args.api_key == 'custom-api-key' + assert args.url == 'https://custom-host.pinecone.io' + assert args.cloud == 'gcp' + assert args.region == 'us-west1' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args([ + '-a', 'short-api-key', + '-u', 'https://short-host.pinecone.io' + ]) + + assert args.api_key == 'short-api-key' + assert args.url == 'https://short-host.pinecone.io' + + @patch('trustgraph.storage.graph_embeddings.pinecone.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.graph_embeddings.pinecone.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nAccepts entity/vector pairs and writes them to a Pinecone store.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_falkordb_storage.py b/tests/unit/test_storage/test_triples_falkordb_storage.py new file mode 100644 index 00000000..7d602b6f --- /dev/null +++ b/tests/unit/test_storage/test_triples_falkordb_storage.py @@ -0,0 +1,436 @@ +""" +Tests for FalkorDB triples storage service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.storage.triples.falkordb.write import Processor +from trustgraph.schema import Value, Triple + + +class TestFalkorDBStorageProcessor: + """Test cases for FalkorDB storage processor""" + + @pytest.fixture + def mock_message(self): + """Create a mock message for testing""" + message = MagicMock() + message.metadata = 
MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create a test triple + triple = Triple( + s=Value(value='http://example.com/subject', is_uri=True), + p=Value(value='http://example.com/predicate', is_uri=True), + o=Value(value='literal object', is_uri=False) + ) + message.triples = [triple] + + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.storage.triples.falkordb.write.FalkorDB') as mock_falkordb: + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + return Processor( + taskgroup=MagicMock(), + id='test-falkordb-storage', + graph_url='falkor://localhost:6379', + database='test_db' + ) + + @patch('trustgraph.storage.triples.falkordb.write.FalkorDB') + def test_processor_initialization_with_defaults(self, mock_falkordb): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + processor = Processor(taskgroup=taskgroup_mock) + + assert processor.db == 'falkordb' + mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379') + mock_client.select_graph.assert_called_once_with('falkordb') + + @patch('trustgraph.storage.triples.falkordb.write.FalkorDB') + def test_processor_initialization_with_custom_params(self, mock_falkordb): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_client = MagicMock() + mock_graph = MagicMock() + mock_falkordb.from_url.return_value = mock_client + mock_client.select_graph.return_value = mock_graph + + processor = Processor( + taskgroup=taskgroup_mock, + graph_url='falkor://custom:6379', + database='custom_db' + ) + + assert processor.db 
== 'custom_db' + mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379') + mock_client.select_graph.assert_called_once_with('custom_db') + + def test_create_node(self, processor): + """Test node creation""" + test_uri = 'http://example.com/node' + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + + processor.io.query.return_value = mock_result + + processor.create_node(test_uri) + + processor.io.query.assert_called_once_with( + "MERGE (n:Node {uri: $uri})", + params={ + "uri": test_uri, + }, + ) + + def test_create_literal(self, processor): + """Test literal creation""" + test_value = 'test literal value' + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + + processor.io.query.return_value = mock_result + + processor.create_literal(test_value) + + processor.io.query.assert_called_once_with( + "MERGE (n:Literal {value: $value})", + params={ + "value": test_value, + }, + ) + + def test_relate_node(self, processor): + """Test node-to-node relationship creation""" + src_uri = 'http://example.com/src' + pred_uri = 'http://example.com/pred' + dest_uri = 'http://example.com/dest' + + mock_result = MagicMock() + mock_result.nodes_created = 0 + mock_result.run_time_ms = 5 + + processor.io.query.return_value = mock_result + + processor.relate_node(src_uri, pred_uri, dest_uri) + + processor.io.query.assert_called_once_with( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Node {uri: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + params={ + "src": src_uri, + "dest": dest_uri, + "uri": pred_uri, + }, + ) + + def test_relate_literal(self, processor): + """Test node-to-literal relationship creation""" + src_uri = 'http://example.com/src' + pred_uri = 'http://example.com/pred' + literal_value = 'literal destination' + + mock_result = MagicMock() + mock_result.nodes_created = 0 + mock_result.run_time_ms = 5 + + processor.io.query.return_value = mock_result + + 
processor.relate_literal(src_uri, pred_uri, literal_value) + + processor.io.query.assert_called_once_with( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + params={ + "src": src_uri, + "dest": literal_value, + "uri": pred_uri, + }, + ) + + @pytest.mark.asyncio + async def test_store_triples_with_uri_object(self, processor): + """Test storing triple with URI object""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + triple = Triple( + s=Value(value='http://example.com/subject', is_uri=True), + p=Value(value='http://example.com/predicate', is_uri=True), + o=Value(value='http://example.com/object', is_uri=True) + ) + message.triples = [triple] + + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + processor.io.query.return_value = mock_result + + await processor.store_triples(message) + + # Verify queries were called in the correct order + expected_calls = [ + # Create subject node + (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}), + # Create object node + (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}), + # Create relationship + (("MATCH (src:Node {uri: $src}) " + "MATCH (dest:Node {uri: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}), + ] + + assert processor.io.query.call_count == 3 + for i, (expected_args, expected_kwargs) in enumerate(expected_calls): + actual_call = processor.io.query.call_args_list[i] + assert actual_call[0] == expected_args + assert actual_call[1] == expected_kwargs + + @pytest.mark.asyncio + async def test_store_triples_with_literal_object(self, processor, mock_message): + """Test storing triple with literal object""" + 
mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + processor.io.query.return_value = mock_result + + await processor.store_triples(mock_message) + + # Verify queries were called in the correct order + expected_calls = [ + # Create subject node + (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}), + # Create literal object + (("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}), + # Create relationship + (("MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}), + ] + + assert processor.io.query.call_count == 3 + for i, (expected_args, expected_kwargs) in enumerate(expected_calls): + actual_call = processor.io.query.call_args_list[i] + assert actual_call[0] == expected_args + assert actual_call[1] == expected_kwargs + + @pytest.mark.asyncio + async def test_store_triples_multiple_triples(self, processor): + """Test storing multiple triples""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + triple1 = Triple( + s=Value(value='http://example.com/subject1', is_uri=True), + p=Value(value='http://example.com/predicate1', is_uri=True), + o=Value(value='literal object1', is_uri=False) + ) + triple2 = Triple( + s=Value(value='http://example.com/subject2', is_uri=True), + p=Value(value='http://example.com/predicate2', is_uri=True), + o=Value(value='http://example.com/object2', is_uri=True) + ) + message.triples = [triple1, triple2] + + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + processor.io.query.return_value = mock_result + + await processor.store_triples(message) + + # Verify total number of queries (3 per triple) + assert 
processor.io.query.call_count == 6 + + # Verify first triple operations + first_triple_calls = processor.io.query.call_args_list[0:3] + assert first_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject1" + assert first_triple_calls[1][1]["params"]["value"] == "literal object1" + assert first_triple_calls[2][1]["params"]["src"] == "http://example.com/subject1" + + # Verify second triple operations + second_triple_calls = processor.io.query.call_args_list[3:6] + assert second_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject2" + assert second_triple_calls[1][1]["params"]["uri"] == "http://example.com/object2" + assert second_triple_calls[2][1]["params"]["src"] == "http://example.com/subject2" + + @pytest.mark.asyncio + async def test_store_triples_empty_list(self, processor): + """Test storing empty triples list""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + message.triples = [] + + await processor.store_triples(message) + + # Verify no queries were made + processor.io.query.assert_not_called() + + @pytest.mark.asyncio + async def test_store_triples_mixed_objects(self, processor): + """Test storing triples with mixed URI and literal objects""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + triple1 = Triple( + s=Value(value='http://example.com/subject1', is_uri=True), + p=Value(value='http://example.com/predicate1', is_uri=True), + o=Value(value='literal object', is_uri=False) + ) + triple2 = Triple( + s=Value(value='http://example.com/subject2', is_uri=True), + p=Value(value='http://example.com/predicate2', is_uri=True), + o=Value(value='http://example.com/object2', is_uri=True) + ) + message.triples = [triple1, triple2] + + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + 
processor.io.query.return_value = mock_result + + await processor.store_triples(message) + + # Verify total number of queries (3 per triple) + assert processor.io.query.call_count == 6 + + # Verify first triple creates literal + assert "Literal" in processor.io.query.call_args_list[1][0][0] + assert processor.io.query.call_args_list[1][1]["params"]["value"] == "literal object" + + # Verify second triple creates node + assert "Node" in processor.io.query.call_args_list[4][0][0] + assert processor.io.query.call_args_list[4][1]["params"]["uri"] == "http://example.com/object2" + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_url') + assert args.graph_url == 'falkor://falkordb:6379' + assert hasattr(args, 'database') + assert args.database == 'falkordb' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--graph-url', 'falkor://custom:6379', + '--database', 'custom_db' + ]) + + assert args.graph_url == 'falkor://custom:6379' + assert args.database == 'custom_db' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import 
ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-g', 'falkor://short:6379']) + + assert args.graph_url == 'falkor://short:6379' + + @patch('trustgraph.storage.triples.falkordb.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.triples.falkordb.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nGraph writer. Input is graph edge. Writes edges to FalkorDB graph.\n" + ) + + def test_create_node_with_special_characters(self, processor): + """Test node creation with special characters in URI""" + test_uri = 'http://example.com/node with spaces & symbols' + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + + processor.io.query.return_value = mock_result + + processor.create_node(test_uri) + + processor.io.query.assert_called_once_with( + "MERGE (n:Node {uri: $uri})", + params={ + "uri": test_uri, + }, + ) + + def test_create_literal_with_special_characters(self, processor): + """Test literal creation with special characters""" + test_value = 'literal with "quotes" and \n newlines' + mock_result = MagicMock() + mock_result.nodes_created = 1 + mock_result.run_time_ms = 10 + + processor.io.query.return_value = mock_result + + processor.create_literal(test_value) + + processor.io.query.assert_called_once_with( + "MERGE (n:Literal {value: $value})", + params={ + "value": test_value, + }, + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_memgraph_storage.py b/tests/unit/test_storage/test_triples_memgraph_storage.py new file mode 100644 index 00000000..83dfdbc4 --- /dev/null +++ 
b/tests/unit/test_storage/test_triples_memgraph_storage.py @@ -0,0 +1,441 @@ +""" +Tests for Memgraph triples storage service +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.storage.triples.memgraph.write import Processor +from trustgraph.schema import Value, Triple + + +class TestMemgraphStorageProcessor: + """Test cases for Memgraph storage processor""" + + @pytest.fixture + def mock_message(self): + """Create a mock message for testing""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + # Create a test triple + triple = Triple( + s=Value(value='http://example.com/subject', is_uri=True), + p=Value(value='http://example.com/predicate', is_uri=True), + o=Value(value='literal object', is_uri=False) + ) + message.triples = [triple] + + return message + + @pytest.fixture + def processor(self): + """Create a processor instance for testing""" + with patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') as mock_graph_db: + mock_driver = MagicMock() + mock_session = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_driver.session.return_value.__enter__.return_value = mock_session + + return Processor( + taskgroup=MagicMock(), + id='test-memgraph-storage', + graph_host='bolt://localhost:7687', + username='test_user', + password='test_pass', + database='test_db' + ) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_processor_initialization_with_defaults(self, mock_graph_db): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_session = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor(taskgroup=taskgroup_mock) + + assert processor.db == 'memgraph' + mock_graph_db.driver.assert_called_once_with( + 
'bolt://memgraph:7687', + auth=('memgraph', 'password') + ) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_processor_initialization_with_custom_params(self, mock_graph_db): + """Test processor initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_session = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor( + taskgroup=taskgroup_mock, + graph_host='bolt://custom:7687', + username='custom_user', + password='custom_pass', + database='custom_db' + ) + + assert processor.db == 'custom_db' + mock_graph_db.driver.assert_called_once_with( + 'bolt://custom:7687', + auth=('custom_user', 'custom_pass') + ) + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_create_indexes_success(self, mock_graph_db): + """Test successful index creation""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_session = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor(taskgroup=taskgroup_mock) + + # Verify index creation calls + expected_calls = [ + "CREATE INDEX ON :Node", + "CREATE INDEX ON :Node(uri)", + "CREATE INDEX ON :Literal", + "CREATE INDEX ON :Literal(value)" + ] + + assert mock_session.run.call_count == len(expected_calls) + for i, expected_call in enumerate(expected_calls): + actual_call = mock_session.run.call_args_list[i][0][0] + assert actual_call == expected_call + + @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') + def test_create_indexes_with_exceptions(self, mock_graph_db): + """Test index creation with exceptions (should be ignored)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_session = MagicMock() + mock_graph_db.driver.return_value = mock_driver + 
mock_driver.session.return_value.__enter__.return_value = mock_session + + # Make all index creation calls raise exceptions + mock_session.run.side_effect = Exception("Index already exists") + + # Should not raise an exception + processor = Processor(taskgroup=taskgroup_mock) + + # Verify all index creation calls were attempted + assert mock_session.run.call_count == 4 + + def test_create_node(self, processor): + """Test node creation""" + test_uri = 'http://example.com/node' + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + + processor.io.execute_query.return_value = mock_result + + processor.create_node(test_uri) + + processor.io.execute_query.assert_called_once_with( + "MERGE (n:Node {uri: $uri})", + uri=test_uri, + database_=processor.db + ) + + def test_create_literal(self, processor): + """Test literal creation""" + test_value = 'test literal value' + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + + processor.io.execute_query.return_value = mock_result + + processor.create_literal(test_value) + + processor.io.execute_query.assert_called_once_with( + "MERGE (n:Literal {value: $value})", + value=test_value, + database_=processor.db + ) + + def test_relate_node(self, processor): + """Test node-to-node relationship creation""" + src_uri = 'http://example.com/src' + pred_uri = 'http://example.com/pred' + dest_uri = 'http://example.com/dest' + + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 0 + mock_summary.result_available_after = 5 + mock_result.summary = mock_summary + + processor.io.execute_query.return_value = mock_result + + processor.relate_node(src_uri, pred_uri, dest_uri) + + processor.io.execute_query.assert_called_once_with( + "MATCH (src:Node 
{uri: $src}) " + "MATCH (dest:Node {uri: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + src=src_uri, dest=dest_uri, uri=pred_uri, + database_=processor.db + ) + + def test_relate_literal(self, processor): + """Test node-to-literal relationship creation""" + src_uri = 'http://example.com/src' + pred_uri = 'http://example.com/pred' + literal_value = 'literal destination' + + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 0 + mock_summary.result_available_after = 5 + mock_result.summary = mock_summary + + processor.io.execute_query.return_value = mock_result + + processor.relate_literal(src_uri, pred_uri, literal_value) + + processor.io.execute_query.assert_called_once_with( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + src=src_uri, dest=literal_value, uri=pred_uri, + database_=processor.db + ) + + def test_create_triple_with_uri_object(self, processor): + """Test triple creation with URI object""" + mock_tx = MagicMock() + + triple = Triple( + s=Value(value='http://example.com/subject', is_uri=True), + p=Value(value='http://example.com/predicate', is_uri=True), + o=Value(value='http://example.com/object', is_uri=True) + ) + + processor.create_triple(mock_tx, triple) + + # Verify transaction calls + expected_calls = [ + # Create subject node + ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}), + # Create object node + ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}), + # Create relationship + ("MATCH (src:Node {uri: $src}) " + "MATCH (dest:Node {uri: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'}) + ] + + assert mock_tx.run.call_count == 3 + for i, (expected_query, expected_params) in enumerate(expected_calls): + actual_call = mock_tx.run.call_args_list[i] + assert 
actual_call[0][0] == expected_query + assert actual_call[1] == expected_params + + def test_create_triple_with_literal_object(self, processor): + """Test triple creation with literal object""" + mock_tx = MagicMock() + + triple = Triple( + s=Value(value='http://example.com/subject', is_uri=True), + p=Value(value='http://example.com/predicate', is_uri=True), + o=Value(value='literal object', is_uri=False) + ) + + processor.create_triple(mock_tx, triple) + + # Verify transaction calls + expected_calls = [ + # Create subject node + ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}), + # Create literal object + ("MERGE (n:Literal {value: $value})", {'value': 'literal object'}), + # Create relationship + ("MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'}) + ] + + assert mock_tx.run.call_count == 3 + for i, (expected_query, expected_params) in enumerate(expected_calls): + actual_call = mock_tx.run.call_args_list[i] + assert actual_call[0][0] == expected_query + assert actual_call[1] == expected_params + + @pytest.mark.asyncio + async def test_store_triples_single_triple(self, processor, mock_message): + """Test storing a single triple""" + mock_session = MagicMock() + processor.io.session.return_value.__enter__.return_value = mock_session + + # Reset the mock to clear the initialization call + processor.io.session.reset_mock() + + await processor.store_triples(mock_message) + + # Verify session was created with correct database + processor.io.session.assert_called_once_with(database=processor.db) + + # Verify execute_write was called once per triple + mock_session.execute_write.assert_called_once() + + # Verify the triple was passed to create_triple + call_args = mock_session.execute_write.call_args + assert call_args[0][0] == processor.create_triple + assert call_args[0][1] == 
mock_message.triples[0] + + @pytest.mark.asyncio + async def test_store_triples_multiple_triples(self, processor): + """Test storing multiple triples""" + mock_session = MagicMock() + processor.io.session.return_value.__enter__.return_value = mock_session + + # Reset the mock to clear the initialization call + processor.io.session.reset_mock() + + # Create message with multiple triples + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + + triple1 = Triple( + s=Value(value='http://example.com/subject1', is_uri=True), + p=Value(value='http://example.com/predicate1', is_uri=True), + o=Value(value='literal object1', is_uri=False) + ) + triple2 = Triple( + s=Value(value='http://example.com/subject2', is_uri=True), + p=Value(value='http://example.com/predicate2', is_uri=True), + o=Value(value='http://example.com/object2', is_uri=True) + ) + message.triples = [triple1, triple2] + + await processor.store_triples(message) + + # Verify session was called twice (once per triple) + assert processor.io.session.call_count == 2 + + # Verify execute_write was called once per triple + assert mock_session.execute_write.call_count == 2 + + # Verify each triple was processed + call_args_list = mock_session.execute_write.call_args_list + assert call_args_list[0][0][1] == triple1 + assert call_args_list[1][0][1] == triple2 + + @pytest.mark.asyncio + async def test_store_triples_empty_list(self, processor): + """Test storing empty triples list""" + mock_session = MagicMock() + processor.io.session.return_value.__enter__.return_value = mock_session + + # Reset the mock to clear the initialization call + processor.io.session.reset_mock() + + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'test_user' + message.metadata.collection = 'test_collection' + message.triples = [] + + await processor.store_triples(message) + + # Verify no session calls were made (no 
triples to process) + processor.io.session.assert_not_called() + + # Verify no execute_write calls were made + mock_session.execute_write.assert_not_called() + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_host') + assert args.graph_host == 'bolt://memgraph:7687' + assert hasattr(args, 'username') + assert args.username == 'memgraph' + assert hasattr(args, 'password') + assert args.password == 'password' + assert hasattr(args, 'database') + assert args.database == 'memgraph' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--graph-host', 'bolt://custom:7687', + '--username', 'custom_user', + '--password', 'custom_pass', + '--database', 'custom_db' + ]) + + assert args.graph_host == 'bolt://custom:7687' + assert args.username == 'custom_user' + assert args.password == 'custom_pass' + assert args.database == 'custom_db' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with 
patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-g', 'bolt://short:7687']) + + assert args.graph_host == 'bolt://short:7687' + + @patch('trustgraph.storage.triples.memgraph.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.triples.memgraph.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nGraph writer. Input is graph edge. Writes edges to Memgraph.\n" + ) \ No newline at end of file diff --git a/tests/unit/test_storage/test_triples_neo4j_storage.py b/tests/unit/test_storage/test_triples_neo4j_storage.py new file mode 100644 index 00000000..a84706ee --- /dev/null +++ b/tests/unit/test_storage/test_triples_neo4j_storage.py @@ -0,0 +1,548 @@ +""" +Tests for Neo4j triples storage service +""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from trustgraph.storage.triples.neo4j.write import Processor + + +class TestNeo4jStorageProcessor: + """Test cases for Neo4j storage processor""" + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_processor_initialization_with_defaults(self, mock_graph_db): + """Test processor initialization with default parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor(taskgroup=taskgroup_mock) + + assert processor.db == 'neo4j' + mock_graph_db.driver.assert_called_once_with( + 'bolt://neo4j:7687', + auth=('neo4j', 'password') + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_processor_initialization_with_custom_params(self, mock_graph_db): + """Test processor 
initialization with custom parameters""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor( + taskgroup=taskgroup_mock, + graph_host='bolt://custom:7687', + username='testuser', + password='testpass', + database='testdb' + ) + + assert processor.db == 'testdb' + mock_graph_db.driver.assert_called_once_with( + 'bolt://custom:7687', + auth=('testuser', 'testpass') + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_create_indexes_success(self, mock_graph_db): + """Test successful index creation""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor(taskgroup=taskgroup_mock) + + # Verify index creation queries were executed + expected_calls = [ + "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", + "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", + "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)" + ] + + assert mock_session.run.call_count == 3 + for expected_query in expected_calls: + mock_session.run.assert_any_call(expected_query) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_create_indexes_with_exceptions(self, mock_graph_db): + """Test index creation with exceptions (should be ignored)""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Make session.run raise exceptions + mock_session.run.side_effect = Exception("Index already exists") + + # Should not raise exception - they should be caught and ignored + processor = Processor(taskgroup=taskgroup_mock) + + # 
Should have tried to create all 3 indexes despite exceptions + assert mock_session.run.call_count == 3 + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_create_node(self, mock_graph_db): + """Test node creation""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Test create_node + processor.create_node("http://example.com/node") + + mock_driver.execute_query.assert_called_with( + "MERGE (n:Node {uri: $uri})", + uri="http://example.com/node", + database_="neo4j" + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_create_literal(self, mock_graph_db): + """Test literal creation""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Test create_literal + processor.create_literal("literal value") + + mock_driver.execute_query.assert_called_with( + "MERGE (n:Literal {value: $value})", + value="literal value", + database_="neo4j" + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_relate_node(self, mock_graph_db): + """Test 
node-to-node relationship creation""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 0 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Test relate_node + processor.relate_node( + "http://example.com/subject", + "http://example.com/predicate", + "http://example.com/object" + ) + + mock_driver.execute_query.assert_called_with( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Node {uri: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + src="http://example.com/subject", + dest="http://example.com/object", + uri="http://example.com/predicate", + database_="neo4j" + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + def test_relate_literal(self, mock_graph_db): + """Test node-to-literal relationship creation""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 0 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Test relate_literal + processor.relate_literal( + "http://example.com/subject", + "http://example.com/predicate", + "literal value" + ) + + mock_driver.execute_query.assert_called_with( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel 
{uri: $uri}]->(dest)", + src="http://example.com/subject", + dest="literal value", + uri="http://example.com/predicate", + database_="neo4j" + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_handle_triples_with_uri_object(self, mock_graph_db): + """Test handling triples message with URI object""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Create mock triple with URI object + triple = MagicMock() + triple.s.value = "http://example.com/subject" + triple.p.value = "http://example.com/predicate" + triple.o.value = "http://example.com/object" + triple.o.is_uri = True + + # Create mock message + mock_message = MagicMock() + mock_message.triples = [triple] + + await processor.store_triples(mock_message) + + # Verify create_node was called for subject and object + # Verify relate_node was called + expected_calls = [ + # Subject node creation + ( + "MERGE (n:Node {uri: $uri})", + {"uri": "http://example.com/subject", "database_": "neo4j"} + ), + # Object node creation + ( + "MERGE (n:Node {uri: $uri})", + {"uri": "http://example.com/object", "database_": "neo4j"} + ), + # Relationship creation + ( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Node {uri: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + { + "src": "http://example.com/subject", + "dest": "http://example.com/object", + "uri": "http://example.com/predicate", + "database_": "neo4j" + } + ) + ] + + assert mock_driver.execute_query.call_count == 3 + for 
expected_query, expected_params in expected_calls: + mock_driver.execute_query.assert_any_call(expected_query, **expected_params) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_triples_with_literal_object(self, mock_graph_db): + """Test handling triples message with literal object""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Create mock triple with literal object + triple = MagicMock() + triple.s.value = "http://example.com/subject" + triple.p.value = "http://example.com/predicate" + triple.o.value = "literal value" + triple.o.is_uri = False + + # Create mock message + mock_message = MagicMock() + mock_message.triples = [triple] + + await processor.store_triples(mock_message) + + # Verify create_node was called for subject + # Verify create_literal was called for object + # Verify relate_literal was called + expected_calls = [ + # Subject node creation + ( + "MERGE (n:Node {uri: $uri})", + {"uri": "http://example.com/subject", "database_": "neo4j"} + ), + # Literal creation + ( + "MERGE (n:Literal {value: $value})", + {"value": "literal value", "database_": "neo4j"} + ), + # Relationship creation + ( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + { + "src": "http://example.com/subject", + "dest": "literal value", + "uri": "http://example.com/predicate", + "database_": "neo4j" + } + ) + ] + + assert mock_driver.execute_query.call_count == 3 + 
for expected_query, expected_params in expected_calls: + mock_driver.execute_query.assert_any_call(expected_query, **expected_params) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_multiple_triples(self, mock_graph_db): + """Test handling message with multiple triples""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Create mock triples + triple1 = MagicMock() + triple1.s.value = "http://example.com/subject1" + triple1.p.value = "http://example.com/predicate1" + triple1.o.value = "http://example.com/object1" + triple1.o.is_uri = True + + triple2 = MagicMock() + triple2.s.value = "http://example.com/subject2" + triple2.p.value = "http://example.com/predicate2" + triple2.o.value = "literal value" + triple2.o.is_uri = False + + # Create mock message + mock_message = MagicMock() + mock_message.triples = [triple1, triple2] + + await processor.store_triples(mock_message) + + # Should have processed both triples + # Triple1: 2 nodes + 1 relationship = 3 calls + # Triple2: 1 node + 1 literal + 1 relationship = 3 calls + # Total: 6 calls + assert mock_driver.execute_query.call_count == 6 + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_empty_triples(self, mock_graph_db): + """Test handling message with no triples""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + 
mock_driver.session.return_value.__enter__.return_value = mock_session + + processor = Processor(taskgroup=taskgroup_mock) + + # Create mock message with empty triples + mock_message = MagicMock() + mock_message.triples = [] + + await processor.store_triples(mock_message) + + # Should not have made any execute_query calls beyond index creation + # Only index creation calls should have been made during initialization + mock_driver.execute_query.assert_not_called() + + def test_add_args_method(self): + """Test that add_args properly configures argument parser""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + # Mock the parent class add_args method + with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args') as mock_parent_add_args: + Processor.add_args(parser) + + # Verify parent add_args was called + mock_parent_add_args.assert_called_once() + + # Verify our specific arguments were added + # Parse empty args to check defaults + args = parser.parse_args([]) + + assert hasattr(args, 'graph_host') + assert args.graph_host == 'bolt://neo4j:7687' + assert hasattr(args, 'username') + assert args.username == 'neo4j' + assert hasattr(args, 'password') + assert args.password == 'password' + assert hasattr(args, 'database') + assert args.database == 'neo4j' + + def test_add_args_with_custom_values(self): + """Test add_args with custom command line values""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with custom values + args = parser.parse_args([ + '--graph_host', 'bolt://custom:7687', + '--username', 'testuser', + '--password', 'testpass', + '--database', 'testdb' + ]) + + assert args.graph_host == 'bolt://custom:7687' + assert args.username == 'testuser' + assert args.password == 'testpass' + assert 
args.database == 'testdb' + + def test_add_args_short_form(self): + """Test add_args with short form arguments""" + from argparse import ArgumentParser + from unittest.mock import patch + + parser = ArgumentParser() + + with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'): + Processor.add_args(parser) + + # Test parsing with short form + args = parser.parse_args(['-g', 'bolt://short:7687']) + + assert args.graph_host == 'bolt://short:7687' + + @patch('trustgraph.storage.triples.neo4j.write.Processor.launch') + def test_run_function(self, mock_launch): + """Test the run function calls Processor.launch with correct parameters""" + from trustgraph.storage.triples.neo4j.write import run, default_ident + + run() + + mock_launch.assert_called_once_with( + default_ident, + "\nGraph writer. Input is graph edge. Writes edges to Neo4j graph.\n" + ) + + @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') + @pytest.mark.asyncio + async def test_store_triples_with_special_characters(self, mock_graph_db): + """Test handling triples with special characters and unicode""" + taskgroup_mock = MagicMock() + mock_driver = MagicMock() + mock_graph_db.driver.return_value = mock_driver + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Mock execute_query response + mock_result = MagicMock() + mock_summary = MagicMock() + mock_summary.counters.nodes_created = 1 + mock_summary.result_available_after = 10 + mock_result.summary = mock_summary + mock_driver.execute_query.return_value = mock_result + + processor = Processor(taskgroup=taskgroup_mock) + + # Create triple with special characters + triple = MagicMock() + triple.s.value = "http://example.com/subject with spaces" + triple.p.value = "http://example.com/predicate:with/symbols" + triple.o.value = 'literal with "quotes" and unicode: ñáéíóú' + triple.o.is_uri = False + + mock_message = MagicMock() + mock_message.triples = [triple] + + await 
processor.store_triples(mock_message) + + # Verify the triple was processed with special characters preserved + mock_driver.execute_query.assert_any_call( + "MERGE (n:Node {uri: $uri})", + uri="http://example.com/subject with spaces", + database_="neo4j" + ) + + mock_driver.execute_query.assert_any_call( + "MERGE (n:Literal {value: $value})", + value='literal with "quotes" and unicode: ñáéíóú', + database_="neo4j" + ) + + mock_driver.execute_query.assert_any_call( + "MATCH (src:Node {uri: $src}) " + "MATCH (dest:Literal {value: $dest}) " + "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + src="http://example.com/subject with spaces", + dest='literal with "quotes" and unicode: ñáéíóú', + uri="http://example.com/predicate:with/symbols", + database_="neo4j" + ) diff --git a/trustgraph-flow/trustgraph/direct/cassandra.py b/trustgraph-flow/trustgraph/direct/cassandra.py index 73f1f33a..f7ca7e5e 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra.py +++ b/trustgraph-flow/trustgraph/direct/cassandra.py @@ -3,6 +3,9 @@ from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from ssl import SSLContext, PROTOCOL_TLSv1_2 +# Global list to track clusters for cleanup +_active_clusters = [] + class TrustGraph: def __init__( @@ -24,6 +27,9 @@ class TrustGraph: else: self.cluster = Cluster(hosts) self.session = self.cluster.connect() + + # Track this cluster globally + _active_clusters.append(self.cluster) self.init() @@ -119,3 +125,13 @@ class TrustGraph: f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""", (s, p, o) ) + + def close(self): + """Close the Cassandra session and cluster connections properly""" + if hasattr(self, 'session') and self.session: + self.session.shutdown() + if hasattr(self, 'cluster') and self.cluster: + self.cluster.shutdown() + # Remove from global tracking + if self.cluster in _active_clusters: + _active_clusters.remove(self.cluster) diff --git 
a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 2fb416dd..0148a98d 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -5,94 +5,56 @@ of chunks """ from .... direct.milvus_doc_embeddings import DocVectors -from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse from .... schema import Error, Value -from .... schema import document_embeddings_request_queue -from .... schema import document_embeddings_response_queue -from .... base import ConsumerProducer +from .... base import DocumentEmbeddingsQueryService -module = "de-query" - -default_input_queue = document_embeddings_request_queue -default_output_queue = document_embeddings_response_queue -default_subscriber = module +default_ident = "de-query" default_store_uri = 'http://localhost:19530' -class Processor(ConsumerProducer): +class Processor(DocumentEmbeddingsQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": DocumentEmbeddingsRequest, - "output_schema": DocumentEmbeddingsResponse, "store_uri": store_uri, } ) self.vecstore = DocVectors(store_uri) - async def handle(self, msg): + async def query_document_embeddings(self, msg): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) + # Handle zero limit case + if msg.limit <= 0: + return [] chunks = [] - for vec in v.vectors: + for vec in 
msg.vectors: - resp = self.vecstore.search(vec, limit=v.limit) + resp = self.vecstore.search(vec, limit=msg.limit) for r in resp: chunk = r["entity"]["doc"] - chunk = chunk.encode("utf-8") chunks.append(chunk) - print("Send response...", flush=True) - r = DocumentEmbeddingsResponse(documents=chunks, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return chunks except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = DocumentEmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - documents=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + DocumentEmbeddingsQueryService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -102,5 +64,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 74c52055..8388a8ca 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -10,30 +10,21 @@ from pinecone.grpc import PineconeGRPC, GRPCClientConfig import uuid import os -from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse -from .... schema import Error, Value -from .... schema import document_embeddings_request_queue -from .... schema import document_embeddings_response_queue -from .... base import ConsumerProducer +from .... 
base import DocumentEmbeddingsQueryService -module = "de-query" - -default_input_queue = document_embeddings_request_queue -default_output_queue = document_embeddings_response_queue -default_subscriber = module +default_ident = "de-query" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") -class Processor(ConsumerProducer): +class Processor(DocumentEmbeddingsQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - self.url = params.get("url", None) self.api_key = params.get("api_key", default_api_key) + if self.api_key is None or self.api_key == "not-specified": + raise RuntimeError("Pinecone API key must be specified") + if self.url: self.pinecone = PineconeGRPC( @@ -47,88 +38,53 @@ class Processor(ConsumerProducer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": DocumentEmbeddingsRequest, - "output_schema": DocumentEmbeddingsResponse, "url": self.url, + "api_key": self.api_key, } ) - async def handle(self, msg): + async def query_document_embeddings(self, msg): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) + # Handle zero limit case + if msg.limit <= 0: + return [] chunks = [] - for vec in v.vectors: + for vec in msg.vectors: dim = len(vec) index_name = ( - "d-" + v.user + "-" + str(dim) + "d-" + msg.user + "-" + msg.collection + "-" + str(dim) ) index = self.pinecone.Index(index_name) results = index.query( - namespace=v.collection, vector=vec, - top_k=v.limit, + top_k=msg.limit, include_values=False, include_metadata=True ) - search_result = self.client.query_points( - collection_name=collection, - query=vec, - limit=v.limit, - with_payload=True, - ).points - for r in results.matches: 
doc = r.metadata["doc"] - chunks.add(doc) + chunks.append(doc) - print("Send response...", flush=True) - r = DocumentEmbeddingsResponse(documents=chunks, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return chunks except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = DocumentEmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - documents=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + DocumentEmbeddingsQueryService.add_args(parser) parser.add_argument( '-a', '--api-key', @@ -143,5 +99,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index d2cec084..7603f4d6 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -5,35 +5,21 @@ entities """ from .... direct.milvus_graph_embeddings import EntityVectors -from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse from .... schema import Error, Value -from .... schema import graph_embeddings_request_queue -from .... schema import graph_embeddings_response_queue -from .... base import ConsumerProducer +from .... 
base import GraphEmbeddingsQueryService -module = "ge-query" - -default_input_queue = graph_embeddings_request_queue -default_output_queue = graph_embeddings_response_queue -default_subscriber = module +default_ident = "ge-query" default_store_uri = 'http://localhost:19530' -class Processor(ConsumerProducer): +class Processor(GraphEmbeddingsQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": GraphEmbeddingsRequest, - "output_schema": GraphEmbeddingsResponse, "store_uri": store_uri, } ) @@ -46,29 +32,34 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_graph_embeddings(self, msg): try: - v = msg.value() + entity_set = set() + entities = [] - # Sender-produced ID - id = msg.properties()["id"] + # Handle zero limit case + if msg.limit <= 0: + return [] - print(f"Handling input {id}...", flush=True) + for vec in msg.vectors: - entities = set() - - for vec in v.vectors: - - resp = self.vecstore.search(vec, limit=v.limit) + resp = self.vecstore.search(vec, limit=msg.limit * 2) for r in resp: ent = r["entity"]["entity"] - entities.add(ent) + + # De-dupe entities + if ent not in entity_set: + entity_set.add(ent) + entities.append(ent) - # Convert set to list - entities = list(entities) + # Keep adding entities until limit + if len(entity_set) >= msg.limit: break + + # Keep adding entities until limit + if len(entity_set) >= msg.limit: break ents2 = [] @@ -78,36 +69,19 @@ class Processor(ConsumerProducer): entities = ents2 print("Send response...", flush=True) - r = GraphEmbeddingsResponse(entities=entities, 
error=None) - await self.send(r, properties={"id": id}) + return entities print("Done.", flush=True) except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = GraphEmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - entities=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + GraphEmbeddingsQueryService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -117,5 +91,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 942a1e69..94781fc1 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -10,30 +10,23 @@ from pinecone.grpc import PineconeGRPC, GRPCClientConfig import uuid import os -from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse from .... schema import Error, Value -from .... schema import graph_embeddings_request_queue -from .... schema import graph_embeddings_response_queue -from .... base import ConsumerProducer +from .... 
base import GraphEmbeddingsQueryService -module = "ge-query" - -default_input_queue = graph_embeddings_request_queue -default_output_queue = graph_embeddings_response_queue -default_subscriber = module +default_ident = "ge-query" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") -class Processor(ConsumerProducer): +class Processor(GraphEmbeddingsQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - self.url = params.get("url", None) self.api_key = params.get("api_key", default_api_key) + if self.api_key is None or self.api_key == "not-specified": + raise RuntimeError("Pinecone API key must be specified") + if self.url: self.pinecone = PineconeGRPC( @@ -47,12 +40,8 @@ class Processor(ConsumerProducer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": GraphEmbeddingsRequest, - "output_schema": GraphEmbeddingsResponse, "url": self.url, + "api_key": self.api_key, } ) @@ -62,26 +51,23 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_graph_embeddings(self, msg): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) + # Handle zero limit case + if msg.limit <= 0: + return [] entity_set = set() entities = [] - for vec in v.vectors: + for vec in msg.vectors: dim = len(vec) index_name = ( - "t-" + v.user + "-" + str(dim) + "t-" + msg.user + "-" + msg.collection + "-" + str(dim) ) index = self.pinecone.Index(index_name) @@ -89,9 +75,8 @@ class Processor(ConsumerProducer): # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) entities results = index.query( - namespace=v.collection, vector=vec, - 
top_k=v.limit * 2, + top_k=msg.limit * 2, include_values=False, include_metadata=True ) @@ -106,10 +91,10 @@ class Processor(ConsumerProducer): entities.append(ent) # Keep adding entities until limit - if len(entity_set) >= v.limit: break + if len(entity_set) >= msg.limit: break # Keep adding entities until limit - if len(entity_set) >= v.limit: break + if len(entity_set) >= msg.limit: break ents2 = [] @@ -118,37 +103,17 @@ class Processor(ConsumerProducer): entities = ents2 - print("Send response...", flush=True) - r = GraphEmbeddingsResponse(entities=entities, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return entities except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = GraphEmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - entities=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + GraphEmbeddingsQueryService.add_args(parser) parser.add_argument( '-a', '--api-key', @@ -163,5 +128,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index c62c28c1..2bbe5e2f 100755 --- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -9,37 +9,24 @@ from falkordb import FalkorDB from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple -from .... schema import triples_request_queue -from .... schema import triples_response_queue -from .... base import ConsumerProducer +from .... 
base import TriplesQueryService -module = "triples-query" - -default_input_queue = triples_request_queue -default_output_queue = triples_response_queue -default_subscriber = module +default_ident = "triples-query" default_graph_url = 'falkor://falkordb:6379' default_database = 'falkordb' -class Processor(ConsumerProducer): +class Processor(TriplesQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - graph_url = params.get("graph_host", default_graph_url) + graph_url = params.get("graph_url", default_graph_url) database = params.get("database", default_database) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TriplesQueryRequest, - "output_schema": TriplesQueryResponse, "graph_url": graph_url, + "database": database, } ) @@ -54,50 +41,45 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_triples(self, query): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - triples = [] - if v.s is not None: - if v.p is not None: - if v.o is not None: + if query.s is not None: + if query.p is not None: + if query.o is not None: # SPO records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " - "RETURN $src as src", + "RETURN $src as src " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, - "rel": v.p.value, - "value": v.o.value, + "src": query.s.value, + "rel": query.p.value, + "value": query.o.value, }, ).result_set for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) records = self.io.query( "MATCH 
(src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " - "RETURN $src as src", + "RETURN $src as src " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, - "rel": v.p.value, - "uri": v.o.value, + "src": query.s.value, + "rel": query.p.value, + "uri": query.o.value, }, ).result_set for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) else: @@ -105,116 +87,124 @@ class Processor(ConsumerProducer): records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " - "RETURN dest.value as dest", + "RETURN dest.value as dest " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, - "rel": v.p.value, + "src": query.s.value, + "rel": query.p.value, }, ).result_set for rec in records: - triples.append((v.s.value, v.p.value, rec[0])) + triples.append((query.s.value, query.p.value, rec[0])) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " - "RETURN dest.uri as dest", + "RETURN dest.uri as dest " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, - "rel": v.p.value, + "src": query.s.value, + "rel": query.p.value, }, ).result_set for rec in records: - triples.append((v.s.value, v.p.value, rec[0])) + triples.append((query.s.value, query.p.value, rec[0])) else: - if v.o is not None: + if query.o is not None: # SO records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " - "RETURN rel.uri as rel", + "RETURN rel.uri as rel " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, - "value": v.o.value, + "src": query.s.value, + "value": query.o.value, }, ).result_set for rec in records: - triples.append((v.s.value, rec[0], v.o.value)) + triples.append((query.s.value, rec[0], query.o.value)) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " - "RETURN rel.uri as rel", + "RETURN rel.uri as rel " + "LIMIT " + 
str(query.limit), params={ - "src": v.s.value, - "uri": v.o.value, + "src": query.s.value, + "uri": query.o.value, }, ).result_set for rec in records: - triples.append((v.s.value, rec[0], v.o.value)) + triples.append((query.s.value, rec[0], query.o.value)) else: # s records = self.io.query( - "match (src:node {uri: $src})-[rel:rel]->(dest:literal) " - "return rel.uri as rel, dest.value as dest", + "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " + "RETURN rel.uri as rel, dest.value as dest " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, + "src": query.s.value, }, ).result_set for rec in records: - triples.append((v.s.value, rec[0], rec[1])) + triples.append((query.s.value, rec[0], rec[1])) records = self.io.query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " - "RETURN rel.uri as rel, dest.uri as dest", + "RETURN rel.uri as rel, dest.uri as dest " + "LIMIT " + str(query.limit), params={ - "src": v.s.value, + "src": query.s.value, }, ).result_set for rec in records: - triples.append((v.s.value, rec[0], rec[1])) + triples.append((query.s.value, rec[0], rec[1])) else: - if v.p is not None: + if query.p is not None: - if v.o is not None: + if query.o is not None: # PO records = self.io.query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " - "RETURN src.uri as src", + "RETURN src.uri as src " + "LIMIT " + str(query.limit), params={ - "uri": v.p.value, - "value": v.o.value, + "uri": query.p.value, + "value": query.o.value, }, ).result_set for rec in records: - triples.append((rec[0], v.p.value, v.o.value)) + triples.append((rec[0], query.p.value, query.o.value)) records = self.io.query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) " - "RETURN src.uri as src", + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) " + "RETURN src.uri as src " + "LIMIT " + str(query.limit), params={ - "uri": v.p.value, - "dest": v.o.value, + "uri": query.p.value, + "dest": query.o.value, }, 
).result_set for rec in records: - triples.append((rec[0], v.p.value, v.o.value)) + triples.append((rec[0], query.p.value, query.o.value)) else: @@ -222,53 +212,57 @@ class Processor(ConsumerProducer): records = self.io.query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " - "RETURN src.uri as src, dest.value as dest", + "RETURN src.uri as src, dest.value as dest " + "LIMIT " + str(query.limit), params={ - "uri": v.p.value, + "uri": query.p.value, }, ).result_set for rec in records: - triples.append((rec[0], v.p.value, rec[1])) + triples.append((rec[0], query.p.value, rec[1])) records = self.io.query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " - "RETURN src.uri as src, dest.uri as dest", + "RETURN src.uri as src, dest.uri as dest " + "LIMIT " + str(query.limit), params={ - "uri": v.p.value, + "uri": query.p.value, }, ).result_set for rec in records: - triples.append((rec[0], v.p.value, rec[1])) + triples.append((rec[0], query.p.value, rec[1])) else: - if v.o is not None: + if query.o is not None: # O records = self.io.query( "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " - "RETURN src.uri as src, rel.uri as rel", + "RETURN src.uri as src, rel.uri as rel " + "LIMIT " + str(query.limit), params={ - "value": v.o.value, + "value": query.o.value, }, ).result_set for rec in records: - triples.append((rec[0], rec[1], v.o.value)) + triples.append((rec[0], rec[1], query.o.value)) records = self.io.query( "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " - "RETURN src.uri as src, rel.uri as rel", + "RETURN src.uri as src, rel.uri as rel " + "LIMIT " + str(query.limit), params={ - "uri": v.o.value, + "uri": query.o.value, }, ).result_set for rec in records: - triples.append((rec[0], rec[1], v.o.value)) + triples.append((rec[0], rec[1], query.o.value)) else: @@ -276,7 +270,8 @@ class Processor(ConsumerProducer): records = self.io.query( "MATCH (src:Node)-[rel:Rel]->(dest:Literal) " - "RETURN src.uri as src, rel.uri as rel, dest.value 
as dest", + "RETURN src.uri as src, rel.uri as rel, dest.value as dest " + "LIMIT " + str(query.limit), ).result_set for rec in records: @@ -284,7 +279,8 @@ class Processor(ConsumerProducer): records = self.io.query( "MATCH (src:Node)-[rel:Rel]->(dest:Node) " - "RETURN src.uri as src, rel.uri as rel, dest.uri as dest", + "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " + "LIMIT " + str(query.limit), ).result_set for rec in records: @@ -296,40 +292,20 @@ class Processor(ConsumerProducer): p=self.create_value(t[1]), o=self.create_value(t[2]) ) - for t in triples + for t in triples[:query.limit] ] - print("Send response...", flush=True) - r = TriplesQueryResponse(triples=triples, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return triples except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = TriplesQueryResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + TriplesQueryService.add_args(parser) parser.add_argument( '-g', '--graph-url', @@ -345,5 +321,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index 594c9130..bc75dd16 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -9,28 +9,19 @@ from neo4j import GraphDatabase from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple -from .... schema import triples_request_queue -from .... 
schema import triples_response_queue -from .... base import ConsumerProducer +from .... base import TriplesQueryService -module = "triples-query" - -default_input_queue = triples_request_queue -default_output_queue = triples_response_queue -default_subscriber = module +default_ident = "triples-query" default_graph_host = 'bolt://memgraph:7687' default_username = 'memgraph' default_password = 'password' default_database = 'memgraph' -class Processor(ConsumerProducer): +class Processor(TriplesQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) username = params.get("username", default_username) password = params.get("password", default_password) @@ -38,12 +29,9 @@ class Processor(ConsumerProducer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TriplesQueryRequest, - "output_schema": TriplesQueryResponse, "graph_host": graph_host, + "username": username, + "database": database, } ) @@ -58,46 +46,39 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_triples(self, query): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - triples = [] - if v.s is not None: - if v.p is not None: - if v.o is not None: + if query.s is not None: + if query.p is not None: + if query.o is not None: # SPO records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " "RETURN $src as src " - "LIMIT " + str(v.limit), - src=v.s.value, rel=v.p.value, value=v.o.value, + "LIMIT " + str(query.limit), + src=query.s.value, 
rel=query.p.value, value=query.o.value, database_=self.db, ) for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " "RETURN $src as src " - "LIMIT " + str(v.limit), - src=v.s.value, rel=v.p.value, uri=v.o.value, + "LIMIT " + str(query.limit), + src=query.s.value, rel=query.p.value, uri=query.o.value, database_=self.db, ) for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) else: @@ -106,56 +87,56 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " "RETURN dest.value as dest " - "LIMIT " + str(v.limit), - src=v.s.value, rel=v.p.value, + "LIMIT " + str(query.limit), + src=query.s.value, rel=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, v.p.value, data["dest"])) + triples.append((query.s.value, query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " "RETURN dest.uri as dest " - "LIMIT " + str(v.limit), - src=v.s.value, rel=v.p.value, + "LIMIT " + str(query.limit), + src=query.s.value, rel=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, v.p.value, data["dest"])) + triples.append((query.s.value, query.p.value, data["dest"])) else: - if v.o is not None: + if query.o is not None: # SO records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " "RETURN rel.uri as rel " - "LIMIT " + str(v.limit), - src=v.s.value, value=v.o.value, + "LIMIT " + str(query.limit), + src=query.s.value, value=query.o.value, database_=self.db, ) for 
rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], v.o.value)) + triples.append((query.s.value, data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " "RETURN rel.uri as rel " - "LIMIT " + str(v.limit), - src=v.s.value, uri=v.o.value, + "LIMIT " + str(query.limit), + src=query.s.value, uri=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], v.o.value)) + triples.append((query.s.value, data["rel"], query.o.value)) else: @@ -164,59 +145,59 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " "RETURN rel.uri as rel, dest.value as dest " - "LIMIT " + str(v.limit), - src=v.s.value, + "LIMIT " + str(query.limit), + src=query.s.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], data["dest"])) + triples.append((query.s.value, data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " "RETURN rel.uri as rel, dest.uri as dest " - "LIMIT " + str(v.limit), - src=v.s.value, + "LIMIT " + str(query.limit), + src=query.s.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], data["dest"])) + triples.append((query.s.value, data["rel"], data["dest"])) else: - if v.p is not None: + if query.p is not None: - if v.o is not None: + if query.o is not None: # PO records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " "RETURN src.uri as src " - "LIMIT " + str(v.limit), - uri=v.p.value, value=v.o.value, + "LIMIT " + str(query.limit), + uri=query.p.value, value=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, 
v.o.value)) + triples.append((data["src"], query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) " + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) " "RETURN src.uri as src " - "LIMIT " + str(v.limit), - uri=v.p.value, dest=v.o.value, + "LIMIT " + str(query.limit), + uri=query.p.value, dest=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, v.o.value)) + triples.append((data["src"], query.p.value, query.o.value)) else: @@ -225,56 +206,56 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " "RETURN src.uri as src, dest.value as dest " - "LIMIT " + str(v.limit), - uri=v.p.value, + "LIMIT " + str(query.limit), + uri=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, data["dest"])) + triples.append((data["src"], query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " "RETURN src.uri as src, dest.uri as dest " - "LIMIT " + str(v.limit), - uri=v.p.value, + "LIMIT " + str(query.limit), + uri=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, data["dest"])) + triples.append((data["src"], query.p.value, data["dest"])) else: - if v.o is not None: + if query.o is not None: # O records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " "RETURN src.uri as src, rel.uri as rel " - "LIMIT " + str(v.limit), - value=v.o.value, + "LIMIT " + str(query.limit), + value=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], v.o.value)) + triples.append((data["src"], data["rel"], query.o.value)) records, 
summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " "RETURN src.uri as src, rel.uri as rel " - "LIMIT " + str(v.limit), - uri=v.o.value, + "LIMIT " + str(query.limit), + uri=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], v.o.value)) + triples.append((data["src"], data["rel"], query.o.value)) else: @@ -283,7 +264,7 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel]->(dest:Literal) " "RETURN src.uri as src, rel.uri as rel, dest.value as dest " - "LIMIT " + str(v.limit), + "LIMIT " + str(query.limit), database_=self.db, ) @@ -294,7 +275,7 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel]->(dest:Node) " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest " - "LIMIT " + str(v.limit), + "LIMIT " + str(query.limit), database_=self.db, ) @@ -308,40 +289,22 @@ class Processor(ConsumerProducer): p=self.create_value(t[1]), o=self.create_value(t[2]) ) - for t in triples[:v.limit] + for t in triples[:query.limit] ] - print("Send response...", flush=True) - r = TriplesQueryResponse(triples=triples, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return triples except Exception as e: print(f"Exception: {e}") - print("Send error response...", flush=True) - - r = TriplesQueryResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + print(f"Exception: {e}") + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + TriplesQueryService.add_args(parser) parser.add_argument( '-g', '--graph-host', @@ -369,5 +332,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, 
__doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 591361ce..f65c0f56 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -9,28 +9,19 @@ from neo4j import GraphDatabase from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple -from .... schema import triples_request_queue -from .... schema import triples_response_queue -from .... base import ConsumerProducer +from .... base import TriplesQueryService -module = "triples-query" - -default_input_queue = triples_request_queue -default_output_queue = triples_response_queue -default_subscriber = module +default_ident = "triples-query" default_graph_host = 'bolt://neo4j:7687' default_username = 'neo4j' default_password = 'password' default_database = 'neo4j' -class Processor(ConsumerProducer): +class Processor(TriplesQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) username = params.get("username", default_username) password = params.get("password", default_password) @@ -38,12 +29,9 @@ class Processor(ConsumerProducer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TriplesQueryRequest, - "output_schema": TriplesQueryResponse, "graph_host": graph_host, + "username": username, + "database": database, } ) @@ -58,44 +46,37 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_triples(self, query): try: - v = msg.value() - - # 
Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - triples = [] - if v.s is not None: - if v.p is not None: - if v.o is not None: + if query.s is not None: + if query.p is not None: + if query.o is not None: # SPO records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " "RETURN $src as src", - src=v.s.value, rel=v.p.value, value=v.o.value, + src=query.s.value, rel=query.p.value, value=query.o.value, database_=self.db, ) for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " "RETURN $src as src", - src=v.s.value, rel=v.p.value, uri=v.o.value, + src=query.s.value, rel=query.p.value, uri=query.o.value, database_=self.db, ) for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) else: @@ -104,52 +85,52 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " "RETURN dest.value as dest", - src=v.s.value, rel=v.p.value, + src=query.s.value, rel=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, v.p.value, data["dest"])) + triples.append((query.s.value, query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " "RETURN dest.uri as dest", - src=v.s.value, rel=v.p.value, + src=query.s.value, rel=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, v.p.value, data["dest"])) + triples.append((query.s.value, query.p.value, data["dest"])) else: - if v.o is not None: + if query.o is not 
None: # SO records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " "RETURN rel.uri as rel", - src=v.s.value, value=v.o.value, + src=query.s.value, value=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], v.o.value)) + triples.append((query.s.value, data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " "RETURN rel.uri as rel", - src=v.s.value, uri=v.o.value, + src=query.s.value, uri=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], v.o.value)) + triples.append((query.s.value, data["rel"], query.o.value)) else: @@ -158,55 +139,55 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " "RETURN rel.uri as rel, dest.value as dest", - src=v.s.value, + src=query.s.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], data["dest"])) + triples.append((query.s.value, data["rel"], data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " "RETURN rel.uri as rel, dest.uri as dest", - src=v.s.value, + src=query.s.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((v.s.value, data["rel"], data["dest"])) + triples.append((query.s.value, data["rel"], data["dest"])) else: - if v.p is not None: + if query.p is not None: - if v.o is not None: + if query.o is not None: # PO records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " "RETURN src.uri as src", - uri=v.p.value, value=v.o.value, + uri=query.p.value, value=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - 
triples.append((data["src"], v.p.value, v.o.value)) + triples.append((data["src"], query.p.value, query.o.value)) records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) " + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) " "RETURN src.uri as src", - uri=v.p.value, dest=v.o.value, + uri=query.p.value, dest=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, v.o.value)) + triples.append((data["src"], query.p.value, query.o.value)) else: @@ -215,52 +196,52 @@ class Processor(ConsumerProducer): records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " "RETURN src.uri as src, dest.value as dest", - uri=v.p.value, + uri=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, data["dest"])) + triples.append((data["src"], query.p.value, data["dest"])) records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " "RETURN src.uri as src, dest.uri as dest", - uri=v.p.value, + uri=query.p.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], v.p.value, data["dest"])) + triples.append((data["src"], query.p.value, data["dest"])) else: - if v.o is not None: + if query.o is not None: # O records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " "RETURN src.uri as src, rel.uri as rel", - value=v.o.value, + value=query.o.value, database_=self.db, ) for rec in records: data = rec.data() - triples.append((data["src"], data["rel"], v.o.value)) + triples.append((data["src"], data["rel"], query.o.value)) records, summary, keys = self.io.execute_query( "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " "RETURN src.uri as src, rel.uri as rel", - uri=v.o.value, + uri=query.o.value, database_=self.db, ) for 
rec in records: data = rec.data() - triples.append((data["src"], data["rel"], v.o.value)) + triples.append((data["src"], data["rel"], query.o.value)) else: @@ -295,37 +276,17 @@ class Processor(ConsumerProducer): for t in triples ] - print("Send response...", flush=True) - r = TriplesQueryResponse(triples=triples, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return triples except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = TriplesQueryResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + TriplesQueryService.add_args(parser) parser.add_argument( '-g', '--graph-host', @@ -353,5 +314,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 2949263a..05027d75 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -4,58 +4,41 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ from .... direct.milvus_doc_embeddings import DocVectors +from .... base import DocumentEmbeddingsStoreService -from .... schema import DocumentEmbeddings -from .... schema import document_embeddings_store_queue -from .... log_level import LogLevel -from .... 
base import Consumer - -module = "de-write" - -default_input_queue = document_embeddings_store_queue -default_subscriber = module +default_ident = "de-write" default_store_uri = 'http://localhost:19530' -class Processor(Consumer): +class Processor(DocumentEmbeddingsStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": DocumentEmbeddings, "store_uri": store_uri, } ) self.vecstore = DocVectors(store_uri) - async def handle(self, msg): + async def store_document_embeddings(self, message): - v = msg.value() - - for emb in v.chunks: + for emb in message.chunks: + if emb.chunk is None or emb.chunk == b"": continue + chunk = emb.chunk.decode("utf-8") - if chunk == "" or chunk is None: continue + if chunk == "": continue for vec in emb.vectors: - - if chunk != "" and v.chunk is not None: - for vec in v.vectors: - self.vecstore.insert(vec, chunk) + self.vecstore.insert(vec, chunk) @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + DocumentEmbeddingsStoreService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -65,5 +48,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 128323aa..0d8bac83 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -1,42 +1,32 @@ """ -Accepts entity/vector pairs and writes them to a Qdrant store. 
+Accepts document chunks/vector pairs and writes them to a Pinecone store. """ -from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams +from pinecone import Pinecone, ServerlessSpec +from pinecone.grpc import PineconeGRPC, GRPCClientConfig import time import uuid import os -from .... schema import DocumentEmbeddings -from .... schema import document_embeddings_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... base import DocumentEmbeddingsStoreService -module = "de-write" - -default_input_queue = document_embeddings_store_queue -default_subscriber = module +default_ident = "de-write" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" -class Processor(Consumer): +class Processor(DocumentEmbeddingsStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - self.url = params.get("url", None) self.cloud = params.get("cloud", default_cloud) self.region = params.get("region", default_region) self.api_key = params.get("api_key", default_api_key) - if self.api_key is None: + if self.api_key is None or self.api_key == "not-specified": raise RuntimeError("Pinecone API key must be specified") if self.url: @@ -52,94 +42,96 @@ class Processor(Consumer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": DocumentEmbeddings, "url": self.url, + "cloud": self.cloud, + "region": self.region, + "api_key": self.api_key, } ) self.last_index_name = None - async def handle(self, msg): + def create_index(self, index_name, dim): - v = msg.value() + self.pinecone.create_index( + name = index_name, + dimension = dim, + metric = "cosine", + spec = ServerlessSpec( + cloud = self.cloud, + region = self.region, + ) + ) - for 
emb in v.chunks: + for i in range(0, 1000): + if self.pinecone.describe_index( + index_name + ).status["ready"]: + break + + time.sleep(1) + + if not self.pinecone.describe_index( + index_name + ).status["ready"]: + raise RuntimeError( + "Gave up waiting for index creation" + ) + + async def store_document_embeddings(self, message): + + for emb in message.chunks: + + if emb.chunk is None or emb.chunk == b"": continue + chunk = emb.chunk.decode("utf-8") - if chunk == "" or chunk is None: continue + if chunk == "": continue for vec in emb.vectors: - for vec in v.vectors: + dim = len(vec) + index_name = ( + "d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim) + ) - dim = len(vec) - collection = ( - "d-" + v.metadata.user + "-" + str(dim) - ) + if index_name != self.last_index_name: - if index_name != self.last_index_name: + if not self.pinecone.has_index(index_name): - if not self.pinecone.has_index(index_name): + try: - try: + self.create_index(index_name, dim) - self.pinecone.create_index( - name = index_name, - dimension = dim, - metric = "cosine", - spec = ServerlessSpec( - cloud = self.cloud, - region = self.region, - ) - ) + except Exception as e: + print("Pinecone index creation failed") + raise e - for i in range(0, 1000): + print(f"Index {index_name} created", flush=True) - if self.pinecone.describe_index( - index_name - ).status["ready"]: - break + self.last_index_name = index_name - time.sleep(1) + index = self.pinecone.Index(index_name) - if not self.pinecone.describe_index( - index_name - ).status["ready"]: - raise RuntimeError( - "Gave up waiting for index creation" - ) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - except Exception as e: - print("Pinecone index creation failed") - raise e + records = [ + { + "id": vector_id, + "values": vec, + "metadata": { "doc": chunk }, + } + ] - print(f"Index {index_name} created", flush=True) - - self.last_index_name = index_name - - index = 
self.pinecone.Index(index_name) - - records = [ - { - "id": id, - "values": vec, - "metadata": { "doc": chunk }, - } - ] - - index.upsert( - vectors = records, - namespace = v.metadata.collection, - ) + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + DocumentEmbeddingsStoreService.add_args(parser) parser.add_argument( '-a', '--api-key', @@ -166,5 +158,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 8d8b68b0..f140ab76 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -3,42 +3,29 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ -from .... schema import GraphEmbeddings -from .... schema import graph_embeddings_store_queue -from .... log_level import LogLevel from .... direct.milvus_graph_embeddings import EntityVectors -from .... base import Consumer +from .... 
base import GraphEmbeddingsStoreService -module = "ge-write" - -default_input_queue = graph_embeddings_store_queue -default_subscriber = module +default_ident = "ge-write" default_store_uri = 'http://localhost:19530' -class Processor(Consumer): +class Processor(GraphEmbeddingsStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": GraphEmbeddings, "store_uri": store_uri, } ) self.vecstore = EntityVectors(store_uri) - async def handle(self, msg): + async def store_graph_embeddings(self, message): - v = msg.value() - - for entity in v.entities: + for entity in message.entities: if entity.entity.value != "" and entity.entity.value is not None: for vec in entity.vectors: @@ -47,9 +34,7 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + GraphEmbeddingsStoreService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -59,5 +44,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 400acf26..e575d12a 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -10,32 +10,23 @@ import time import uuid import os -from .... schema import GraphEmbeddings -from .... schema import graph_embeddings_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... 
base import GraphEmbeddingsStoreService -module = "ge-write" - -default_input_queue = graph_embeddings_store_queue -default_subscriber = module +default_ident = "ge-write" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" -class Processor(Consumer): +class Processor(GraphEmbeddingsStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - self.url = params.get("url", None) self.cloud = params.get("cloud", default_cloud) self.region = params.get("region", default_region) self.api_key = params.get("api_key", default_api_key) - if self.api_key is None: + if self.api_key is None or self.api_key == "not-specified": raise RuntimeError("Pinecone API key must be specified") if self.url: @@ -51,10 +42,10 @@ class Processor(Consumer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": GraphEmbeddings, "url": self.url, + "cloud": self.cloud, + "region": self.region, + "api_key": self.api_key, } ) @@ -88,13 +79,9 @@ class Processor(Consumer): "Gave up waiting for index creation" ) - async def handle(self, msg): + async def store_graph_embeddings(self, message): - v = msg.value() - - id = str(uuid.uuid4()) - - for entity in v.entities: + for entity in message.entities: if entity.entity.value == "" or entity.entity.value is None: continue @@ -104,7 +91,7 @@ class Processor(Consumer): dim = len(vec) index_name = ( - "t-" + v.metadata.user + "-" + str(dim) + "t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim) ) if index_name != self.last_index_name: @@ -125,9 +112,12 @@ class Processor(Consumer): index = self.pinecone.Index(index_name) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) + records = [ { - "id": id, + "id": vector_id, "values": vec, "metadata": { "entity": 
entity.entity.value }, } @@ -135,15 +125,12 @@ class Processor(Consumer): index.upsert( vectors = records, - namespace = v.metadata.collection, ) @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + GraphEmbeddingsStoreService.add_args(parser) parser.add_argument( '-a', '--api-key', @@ -170,5 +157,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index b3996b91..defb7d69 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -11,34 +11,24 @@ import time from falkordb import FalkorDB -from .... schema import Triples -from .... schema import triples_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... 
base import TriplesStoreService -module = "triples-write" - -default_input_queue = triples_store_queue -default_subscriber = module +default_ident = "triples-write" default_graph_url = 'falkor://falkordb:6379' default_database = 'falkordb' -class Processor(Consumer): +class Processor(TriplesStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - graph_url = params.get("graph_host", default_graph_url) + graph_url = params.get("graph_url", default_graph_url) database = params.get("database", default_database) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": Triples, "graph_url": graph_url, + "database": database, } ) @@ -118,11 +108,9 @@ class Processor(Consumer): time=res.run_time_ms )) - async def handle(self, msg): + async def store_triples(self, message): - v = msg.value() - - for t in v.triples: + for t in message.triples: self.create_node(t.s.value) @@ -136,14 +124,12 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + TriplesStoreService.add_args(parser) parser.add_argument( - '-g', '--graph_host', + '-g', '--graph-url', default=default_graph_url, - help=f'Graph host (default: {default_graph_url})' + help=f'Graph URL (default: {default_graph_url})' ) parser.add_argument( @@ -154,5 +140,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 8c88ea8f..9079923e 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -11,27 +11,19 @@ import time from neo4j import GraphDatabase -from .... 
schema import Triples -from .... schema import triples_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... base import TriplesStoreService -module = "triples-write" - -default_input_queue = triples_store_queue -default_subscriber = module +default_ident = "triples-write" default_graph_host = 'bolt://memgraph:7687' default_username = 'memgraph' default_password = 'password' default_database = 'memgraph' -class Processor(Consumer): +class Processor(TriplesStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) username = params.get("username", default_username) password = params.get("password", default_password) @@ -39,10 +31,10 @@ class Processor(Consumer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": Triples, "graph_host": graph_host, + "username": username, + "password": password, + "database": database, } ) @@ -205,11 +197,9 @@ class Processor(Consumer): src=t.s.value, dest=t.o.value, uri=t.p.value, ) - async def handle(self, msg): + async def store_triples(self, message): - v = msg.value() - - for t in v.triples: + for t in message.triples: # self.create_node(t.s.value) @@ -226,12 +216,10 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + TriplesStoreService.add_args(parser) parser.add_argument( - '-g', '--graph_host', + '-g', '--graph-host', default=default_graph_host, help=f'Graph host (default: {default_graph_host})' ) @@ -256,5 +244,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py 
index 84a4d923..5293ee1e 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -10,28 +10,21 @@ import argparse import time from neo4j import GraphDatabase +from .... base import TriplesStoreService -from .... schema import Triples -from .... schema import triples_store_queue -from .... log_level import LogLevel -from .... base import Consumer - -module = "triples-write" - -default_input_queue = triples_store_queue -default_subscriber = module +default_ident = "triples-write" default_graph_host = 'bolt://neo4j:7687' default_username = 'neo4j' default_password = 'password' default_database = 'neo4j' -class Processor(Consumer): +class Processor(TriplesStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id", default_ident) + graph_host = params.get("graph_host", default_graph_host) username = params.get("username", default_username) password = params.get("password", default_password) @@ -39,10 +32,9 @@ class Processor(Consumer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": Triples, "graph_host": graph_host, + "username": username, + "database": database, } ) @@ -158,11 +150,9 @@ class Processor(Consumer): time=summary.result_available_after )) - async def handle(self, msg): + async def store_triples(self, message): - v = msg.value() - - for t in v.triples: + for t in message.triples: self.create_node(t.s.value) @@ -176,9 +166,7 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + TriplesStoreService.add_args(parser) parser.add_argument( '-g', '--graph_host', @@ -206,5 +194,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, 
__doc__)