mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-03 12:22:37 +02:00
Fixing tests
This commit is contained in:
parent
f261604991
commit
2cb29380fa
1 changed files with 60 additions and 89 deletions
|
|
@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
def mock_query_message(self):
|
||||
"""Create a mock query message for testing"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
include_metadata=True
|
||||
)
|
||||
|
||||
# Verify results
|
||||
# Verify results use EntityMatch structure
|
||||
assert len(entities) == 3
|
||||
assert entities[0].value == 'http://example.org/entity1'
|
||||
assert entities[0].type == IRI
|
||||
assert entities[1].value == 'entity2'
|
||||
assert entities[1].type == LITERAL
|
||||
assert entities[2].value == 'http://example.org/entity3'
|
||||
assert entities[2].type == IRI
|
||||
assert entities[0].entity.iri == 'http://example.org/entity1'
|
||||
assert entities[0].entity.type == IRI
|
||||
assert entities[1].entity.value == 'entity2'
|
||||
assert entities[1].entity.type == LITERAL
|
||||
assert entities[2].entity.iri == 'http://example.org/entity3'
|
||||
assert entities[2].entity.type == IRI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
|
||||
"""Test basic graph embeddings query"""
|
||||
# Mock index and query results
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query results with distinct entities
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'})
|
||||
]
|
||||
|
||||
# Second query results
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
||||
# Verify deduplication occurred
|
||||
entity_values = [e.value for e in entities]
|
||||
|
||||
# Verify query was made once
|
||||
assert mock_index.query.call_count == 1
|
||||
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert len(entity_values) == 3
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
|
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_limit_handling(self, processor):
|
||||
"""Test that query respects the limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying with zero limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 0
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_negative_limit(self, processor):
|
||||
"""Test querying with negative limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = -1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying with vectors of different dimensions using same index"""
|
||||
async def test_query_graph_embeddings_2d_vector(self, processor):
|
||||
"""Test querying with a 2D vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
]
|
||||
message.vector = [0.1, 0.2] # 2D vector
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
# Mock single index that handles all dimensions
|
||||
# Mock index
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock results for different vector queries
|
||||
mock_results_2d = MagicMock()
|
||||
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
# Mock results for 2D vector query
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
|
||||
mock_results_4d = MagicMock()
|
||||
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
|
||||
|
||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
index_names = [call[0][0] for call in index_calls]
|
||||
assert "t-test_user-test_collection-2" in index_names # 2D vector
|
||||
assert "t-test_user-test_collection-4" in index_names # 4D vector
|
||||
# Verify correct index used for 2D vector
|
||||
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
# Verify query was made
|
||||
assert mock_index.query.call_count == 1
|
||||
|
||||
# Verify results from both dimensions
|
||||
entity_values = [e.value for e in entities]
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert 'entity_2d' in entity_values
|
||||
assert 'entity_4d' in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
|
||||
"""Test querying with empty vectors list"""
|
||||
message = MagicMock()
|
||||
message.vectors = []
|
||||
message.vector = []
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_no_results(self, processor):
|
||||
"""Test querying when index returns no results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -349,43 +329,34 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
|
||||
"""Test that deduplication works correctly across multiple vector queries"""
|
||||
async def test_query_graph_embeddings_deduplication_in_results(self, processor):
|
||||
"""Test that deduplication works correctly within query results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Both queries return overlapping results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query returns results with some duplicates
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}),
|
||||
MagicMock(metadata={'entity': 'entity4'})
|
||||
]
|
||||
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity5'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
entity_values = [e.value for e in entities]
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -423,7 +394,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test that exceptions are properly raised"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue