Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:33:24 +00:00
parent f261604991
commit 2cb29380fa

View file

@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True) pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Term, IRI, LITERAL from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
class TestPineconeGraphEmbeddingsQueryProcessor: class TestPineconeGraphEmbeddingsQueryProcessor:
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
def mock_query_message(self): def mock_query_message(self):
"""Create a mock query message for testing""" """Create a mock query message for testing"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor): async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector""" """Test querying graph embeddings with a single vector"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 3 message.limit = 3
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
include_metadata=True include_metadata=True
) )
# Verify results # Verify results use EntityMatch structure
assert len(entities) == 3 assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1' assert entities[0].entity.iri == 'http://example.org/entity1'
assert entities[0].type == IRI assert entities[0].entity.type == IRI
assert entities[1].value == 'entity2' assert entities[1].entity.value == 'entity2'
assert entities[1].type == LITERAL assert entities[1].entity.type == LITERAL
assert entities[2].value == 'http://example.org/entity3' assert entities[2].entity.iri == 'http://example.org/entity3'
assert entities[2].type == IRI assert entities[2].entity.type == IRI
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message): async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
"""Test querying graph embeddings with multiple vectors""" """Test basic graph embeddings query"""
# Mock index and query results # Mock index and query results
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# First query results # Query results with distinct entities
mock_results1 = MagicMock() mock_results = MagicMock()
mock_results1.matches = [ mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}) MagicMock(metadata={'entity': 'entity2'}),
]
# Second query results
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}) MagicMock(metadata={'entity': 'entity3'})
] ]
mock_index.query.side_effect = [mock_results1, mock_results2] mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message) entities = await processor.query_graph_embeddings(mock_query_message)
# Verify both queries were made # Verify query was made once
assert mock_index.query.call_count == 2 assert mock_index.query.call_count == 1
# Verify deduplication occurred # Verify results with EntityMatch structure
entity_values = [e.value for e in entities] entity_values = [e.entity.value for e in entities]
assert len(entity_values) == 3 assert len(entity_values) == 3
assert 'entity1' in entity_values assert 'entity1' in entity_values
assert 'entity2' in entity_values assert 'entity2' in entity_values
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_limit_handling(self, processor): async def test_query_graph_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter""" """Test that query respects the limit parameter"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 2 message.limit = 2
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor): async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results""" """Test querying with zero limit returns empty results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 0 message.limit = 0
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_negative_limit(self, processor): async def test_query_graph_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results""" """Test querying with negative limit returns empty results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = -1 message.limit = -1
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == [] assert entities == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor): async def test_query_graph_embeddings_2d_vector(self, processor):
"""Test querying with vectors of different dimensions using same index""" """Test querying with a 2D vector"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2] # 2D vector
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
# Mock single index that handles all dimensions # Mock index
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries # Mock results for 2D vector query
mock_results_2d = MagicMock() mock_results = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_results_4d = MagicMock() mock_index.query.return_value = mock_results
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
entities = await processor.query_graph_embeddings(message) entities = await processor.query_graph_embeddings(message)
# Verify different indexes used for different dimensions # Verify correct index used for 2D vector
assert processor.pinecone.Index.call_count == 2 processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
index_calls = processor.pinecone.Index.call_args_list
index_names = [call[0][0] for call in index_calls]
assert "t-test_user-test_collection-2" in index_names # 2D vector
assert "t-test_user-test_collection-4" in index_names # 4D vector
# Verify both queries were made # Verify query was made
assert mock_index.query.call_count == 2 assert mock_index.query.call_count == 1
# Verify results from both dimensions # Verify results with EntityMatch structure
entity_values = [e.value for e in entities] entity_values = [e.entity.value for e in entities]
assert 'entity_2d' in entity_values assert 'entity_2d' in entity_values
assert 'entity_4d' in entity_values
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors_list(self, processor): async def test_query_graph_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list""" """Test querying with empty vectors list"""
message = MagicMock() message = MagicMock()
message.vectors = [] message.vector = []
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_no_results(self, processor): async def test_query_graph_embeddings_no_results(self, processor):
"""Test querying when index returns no results""" """Test querying when index returns no results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -349,43 +329,34 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == [] assert entities == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor): async def test_query_graph_embeddings_deduplication_in_results(self, processor):
"""Test that deduplication works correctly across multiple vector queries""" """Test that deduplication works correctly within query results"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 3 message.limit = 3
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# Both queries return overlapping results # Query returns results with some duplicates
mock_results1 = MagicMock() mock_results = MagicMock()
mock_results1.matches = [ mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), MagicMock(metadata={'entity': 'entity3'}),
MagicMock(metadata={'entity': 'entity4'})
]
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
MagicMock(metadata={'entity': 'entity5'})
] ]
mock_index.query.side_effect = [mock_results1, mock_results2] mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message) entities = await processor.query_graph_embeddings(message)
# Should get exactly 3 unique entities (respecting limit) # Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3 assert len(entities) == 3
entity_values = [e.value for e in entities] entity_values = [e.entity.value for e in entities]
assert len(set(entity_values)) == 3 # All unique assert len(set(entity_values)) == 3 # All unique
@pytest.mark.asyncio @pytest.mark.asyncio
@ -423,7 +394,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor): async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised""" """Test that exceptions are properly raised"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'