diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 1b243113..cacc8448 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True) from trustgraph.query.graph_embeddings.pinecone.service import Processor -from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.schema import Term, IRI, LITERAL, EntityMatch class TestPineconeGraphEmbeddingsQueryProcessor: @@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: def mock_query_message(self): """Create a mock query message for testing""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_single_vector(self, processor): """Test querying graph embeddings with a single vector""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' @@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor: include_metadata=True ) - # Verify results + # Verify results use EntityMatch structure assert len(entities) == 3 - assert entities[0].value == 'http://example.org/entity1' - assert entities[0].type == IRI - assert entities[1].value == 'entity2' - assert entities[1].type == LITERAL - assert entities[2].value == 'http://example.org/entity3' - assert entities[2].type == IRI + assert entities[0].entity.iri == 'http://example.org/entity1' + assert entities[0].entity.type == IRI + assert entities[1].entity.value == 'entity2' + assert entities[1].entity.type == LITERAL + assert entities[2].entity.iri == 'http://example.org/entity3' + assert entities[2].entity.type == IRI @pytest.mark.asyncio - async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message): - """Test querying graph embeddings with multiple vectors""" + async def test_query_graph_embeddings_basic(self, processor, mock_query_message): + """Test basic graph embeddings query""" # Mock index and query results mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # First query results - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query results with distinct entities + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), - MagicMock(metadata={'entity': 'entity2'}) - ] - - # Second query results - mock_results2 = MagicMock() - mock_results2.matches = [ - MagicMock(metadata={'entity': 'entity2'}), # Duplicate + MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity3'}) ] - - mock_index.query.side_effect = [mock_results1, mock_results2] - + + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(mock_query_message) - - # Verify both queries were made - assert mock_index.query.call_count == 2 - - # Verify deduplication occurred - entity_values = [e.value for e in entities] + + # Verify query was made once + assert mock_index.query.call_count == 1 + + # Verify results with EntityMatch structure + entity_values = [e.entity.value for e in entities] assert len(entity_values) == 3 assert 'entity1' in entity_values assert 'entity2' in entity_values @@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_limit_handling(self, processor): """Test that query respects the limit parameter""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_zero_limit(self, processor): """Test querying with zero limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 0 message.user = 'test_user' message.collection = 'test_collection' @@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_negative_limit(self, processor): """Test querying with negative limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = -1 message.user = 'test_user' message.collection = 'test_collection' @@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor: assert entities == [] @pytest.mark.asyncio - async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying with vectors of different dimensions using same index""" + async def test_query_graph_embeddings_2d_vector(self, processor): + """Test querying with a 2D vector""" message = MagicMock() - message.vectors = [ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6] # 4D vector - ] + message.vector = [0.1, 0.2] # 2D vector message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' - # Mock single index that handles all dimensions + # Mock index mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - # Mock results for different vector queries - mock_results_2d = MagicMock() - mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] + # Mock results for 2D vector query + mock_results = MagicMock() + mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})] - mock_results_4d = MagicMock() - mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})] - - mock_index.query.side_effect = [mock_results_2d, mock_results_4d] + mock_index.query.return_value = mock_results entities = await processor.query_graph_embeddings(message) - # Verify different indexes used for different dimensions - assert processor.pinecone.Index.call_count == 2 - index_calls = processor.pinecone.Index.call_args_list - index_names = [call[0][0] for call in index_calls] - assert "t-test_user-test_collection-2" in index_names # 2D vector - assert "t-test_user-test_collection-4" in index_names # 4D vector + # Verify correct index used for 2D vector + processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2") - # Verify both queries were made - assert mock_index.query.call_count == 2 + # Verify query was made + assert mock_index.query.call_count == 1 - # Verify results from both dimensions - entity_values = [e.value for e in entities] + # Verify results with EntityMatch structure + entity_values = [e.entity.value for e in entities] assert 'entity_2d' in entity_values - assert 'entity_4d' in entity_values @pytest.mark.asyncio async def test_query_graph_embeddings_empty_vectors_list(self, processor): """Test querying with empty vectors list""" message = MagicMock() - message.vectors = [] + message.vector = [] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_no_results(self, processor): """Test querying when index returns no results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -349,43 +329,34 @@ class TestPineconeGraphEmbeddingsQueryProcessor: assert entities == [] @pytest.mark.asyncio - async def test_query_graph_embeddings_deduplication_across_vectors(self, processor): - """Test that deduplication works correctly across multiple vector queries""" + async def test_query_graph_embeddings_deduplication_in_results(self, processor): + """Test that deduplication works correctly within query results""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' - + mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # Both queries return overlapping results - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query returns results with some duplicates + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity2'}), + MagicMock(metadata={'entity': 'entity1'}), # Duplicate MagicMock(metadata={'entity': 'entity3'}), - MagicMock(metadata={'entity': 'entity4'}) - ] - - mock_results2 = MagicMock() - mock_results2.matches = [ MagicMock(metadata={'entity': 'entity2'}), # Duplicate - MagicMock(metadata={'entity': 'entity3'}), # Duplicate - MagicMock(metadata={'entity': 'entity5'}) ] - - mock_index.query.side_effect = [mock_results1, mock_results2] - + + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(message) - + # Should get exactly 3 unique entities (respecting limit) assert len(entities) == 3 - entity_values = [e.value for e in entities] + entity_values = [e.entity.value for e in entities] assert len(set(entity_values)) == 3 # All unique @pytest.mark.asyncio @@ -423,7 +394,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_exception_handling(self, processor): """Test that exceptions are properly raised""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection'