mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-15 18:35:15 +02:00
Fixing tests
This commit is contained in:
parent
2cb29380fa
commit
9782421b1f
1 changed files with 12 additions and 16 deletions
|
|
@ -360,33 +360,29 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
assert len(set(entity_values)) == 3 # All unique
|
assert len(set(entity_values)) == 3 # All unique
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
async def test_query_graph_embeddings_respects_limit(self, processor):
|
||||||
"""Test that querying stops early when limit is reached"""
|
"""Test that query respects limit parameter"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6],
|
|
||||||
[0.7, 0.8, 0.9]
|
|
||||||
]
|
|
||||||
message.limit = 2
|
message.limit = 2
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
|
|
||||||
# First query returns enough results to meet limit
|
# Query returns more results than limit
|
||||||
mock_results1 = MagicMock()
|
mock_results = MagicMock()
|
||||||
mock_results1.matches = [
|
mock_results.matches = [
|
||||||
MagicMock(metadata={'entity': 'entity1'}),
|
MagicMock(metadata={'entity': 'entity1'}),
|
||||||
MagicMock(metadata={'entity': 'entity2'}),
|
MagicMock(metadata={'entity': 'entity2'}),
|
||||||
MagicMock(metadata={'entity': 'entity3'})
|
MagicMock(metadata={'entity': 'entity3'})
|
||||||
]
|
]
|
||||||
mock_index.query.return_value = mock_results1
|
mock_index.query.return_value = mock_results
|
||||||
|
|
||||||
entities = await processor.query_graph_embeddings(message)
|
entities = await processor.query_graph_embeddings(message)
|
||||||
|
|
||||||
# Should only make one query since limit was reached
|
# Should only return 2 entities (respecting limit)
|
||||||
mock_index.query.assert_called_once()
|
mock_index.query.assert_called_once()
|
||||||
assert len(entities) == 2
|
assert len(entities) == 2
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue