Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:46:03 +00:00
parent d8a4a8c57f
commit 2670c5f80e
7 changed files with 177 additions and 206 deletions

View file

@@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
class TestMilvusDocEmbeddingsQueryProcessor:
@@ -90,11 +90,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
)
# Verify results are document chunks
# Verify results are ChunkMatch objects
assert len(result) == 3
assert result[0] == "First document chunk"
assert result[1] == "Second document chunk"
assert result[2] == "Third document chunk"
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == "First document chunk"
assert result[1].chunk_id == "Second document chunk"
assert result[2].chunk_id == "Third document chunk"
@pytest.mark.asyncio
async def test_query_document_embeddings_longer_vector(self, processor):
@@ -121,11 +122,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
)
# Verify results
# Verify results are ChunkMatch objects
assert len(result) == 3
assert "First document" in result
assert "Second document" in result
assert "Third document" in result
chunk_ids = [r.chunk_id for r in result]
assert "First document" in chunk_ids
assert "Second document" in chunk_ids
assert "Third document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor):
@@ -217,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify Unicode content is preserved
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with Unicode: éñ中文🚀" in result
assert "Regular ASCII document" in result
assert "Document with émojis: 😀🎉" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with Unicode: éñ中文🚀" in chunk_ids
assert "Regular ASCII document" in chunk_ids
assert "Document with émojis: 😀🎉" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_large_documents(self, processor):
@@ -243,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify large content is preserved
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
assert large_doc in result
assert "Small document" in result
chunk_ids = [r.chunk_id for r in result]
assert large_doc in chunk_ids
assert "Small document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_special_characters(self, processor):
@@ -268,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify special characters are preserved
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with \"quotes\" and 'apostrophes'" in result
assert "Document with\nnewlines\tand\ttabs" in result
assert "Document with special chars: @#$%^&*()" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
assert "Document with special chars: @#$%^&*()" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
@@ -349,10 +354,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
# Verify results
# Verify results are ChunkMatch objects
assert len(result) == 2
assert "Document 1" in result
assert "Document 2" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document 1" in chunk_ids
assert "Document 2" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_results(self, processor):
@@ -374,11 +380,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify results
# Verify results are ChunkMatch objects
assert len(result) == 3
assert "Document A" in result
assert "Document B" in result
assert "Document C" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document A" in chunk_ids
assert "Document B" in chunk_ids
assert "Document C" in chunk_ids
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.doc_embeddings.qdrant.service import Processor
from trustgraph.schema import ChunkMatch
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
# Verify result contains expected documents
# Verify result contains expected ChunkMatch objects
assert len(result) == 2
# Results should be strings (document chunks)
assert isinstance(result[0], str)
assert isinstance(result[1], str)
# Results should be ChunkMatch objects
assert isinstance(result[0], ChunkMatch)
assert isinstance(result[1], ChunkMatch)
# Verify content
assert result[0] == 'first document chunk'
assert result[1] == 'second document chunk'
assert result[0].chunk_id == 'first document chunk'
assert result[1].chunk_id == 'second document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with multiple vectors"""
async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings returns multiple results"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses for different vectors
# Mock query response with multiple results
mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from vector 1'}
mock_point1.payload = {'chunk_id': 'document chunk 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from vector 2'}
mock_point2.payload = {'chunk_id': 'document chunk 2'}
mock_point3 = MagicMock()
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2, mock_point3]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
mock_point3.payload = {'chunk_id': 'document chunk 3'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2, mock_point3]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with multiple vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors)
# Verify collection was queried correctly
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify results from both vectors are combined
# Verify results are ChunkMatch objects
assert len(result) == 3
assert 'document from vector 1' in result
assert 'document from vector 2' in result
assert 'another document from vector 2' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document chunk 1' in chunk_ids
assert 'document chunk 2' in chunk_ids
assert 'document chunk 3' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with different vector dimensions"""
"""Test querying document embeddings with a higher dimension vector"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses
# Mock query response
mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
mock_point1.payload = {'chunk_id': 'document from 5D vector'}
mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
mock_point2.payload = {'chunk_id': 'another 5D document'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with different dimension vectors
# Create mock message with 5D vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once with correct collection
assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Call should use 5D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
# Verify results are ChunkMatch objects
assert len(result) == 2
assert 'document from 2D vector' in result
assert 'document from 3D vector' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document from 5D vector' in chunk_ids
assert 'another 5D document' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 2
# Verify UTF-8 content works correctly
assert 'Document with UTF-8: café, naïve, résumé' in result
assert 'Chinese text: 你好世界' in result
# Verify UTF-8 content works correctly in ChunkMatch objects
chunk_ids = [r.chunk_id for r in result]
assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
assert 'Chinese text: 你好世界' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0
# Result should contain all returned documents
# Result should contain all returned documents as ChunkMatch objects
assert len(result) == 1
assert result[0] == 'document chunk'
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == 'document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 1000
# Result should contain all available documents
# Result should contain all available documents as ChunkMatch objects
assert len(result) == 2
assert 'document 1' in result
assert 'document 2' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document 1' in chunk_ids
assert 'document 2' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')

View file

@@ -151,40 +151,31 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert result[2].entity.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
"""Test querying graph embeddings with multiple vectors"""
async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
limit=5
)
# Mock search results - different results for each vector
mock_results_1 = [
# Mock search results with multiple entities
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results are deduplicated and limited
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
)
# Verify results are EntityMatch objects
assert len(result) == 3
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
@@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication(self, processor):
"""Test that duplicate entities are properly deduplicated"""
async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results with duplicates
mock_results_1 = [
# Mock search results in specific order
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}}, # New
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify duplicates are removed
# Verify results are in the same order as returned by the store
assert len(result) == 3
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
assert result[0].entity.iri == "http://example.com/entity1"
assert result[1].entity.iri == "http://example.com/entity2"
assert result[2].entity.iri == "http://example.com/entity3"
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
)
# Mock search results - first vector returns enough results
mock_results_1 = [
# Mock search results - returns more results than limit
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results_1
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify only first vector was searched (limit reached)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
)
# Verify results are limited
# Verify results are limited to requested amount
assert len(result) == 2
@pytest.mark.asyncio