diff --git a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py index ca43bf83..f4e456cb 100644 --- a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py +++ b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py @@ -108,7 +108,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): # Assert mock_fastembed_instance.embed.assert_called_once_with(["test text"]) assert processor.cached_model_name == "test-model" # Still using default - assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py index 80e1de4e..d52a58c6 100644 --- a/tests/unit/test_embeddings/test_ollama_dynamic_model.py +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -60,7 +60,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): model="test-model", input=["test text"] ) - assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -86,7 +86,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): model="custom-model", input=["test text"] ) - assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 01efa146..1cddce97 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.doc_embeddings.milvus.service import Processor -from trustgraph.schema import DocumentEmbeddingsRequest +from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch class TestMilvusDocEmbeddingsQueryProcessor: @@ -90,11 +90,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5 ) - # Verify results are document chunks + # Verify results are ChunkMatch objects assert len(result) == 3 - assert result[0] == "First document chunk" - assert result[1] == "Second document chunk" - assert result[2] == "Third document chunk" + assert isinstance(result[0], ChunkMatch) + assert result[0].chunk_id == "First document chunk" + assert result[1].chunk_id == "Second document chunk" + assert result[2].chunk_id == "Third document chunk" @pytest.mark.asyncio async def test_query_document_embeddings_longer_vector(self, processor): @@ -121,11 +122,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3 ) - # Verify results + # Verify results are ChunkMatch objects assert len(result) == 3 - assert "First document" in result - assert "Second document" in result - assert "Third document" in result + chunk_ids = [r.chunk_id for r in result] + assert "First document" in chunk_ids + assert "Second document" in chunk_ids + assert "Third document" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_with_limit(self, processor): @@ -217,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify Unicode content is preserved + # Verify Unicode content is preserved in ChunkMatch objects assert len(result) == 3 - assert "Document with Unicode: éñ中文🚀" in result - assert "Regular ASCII document" in result - assert "Document with émojis: 😀🎉" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document with Unicode: éñ中文🚀" in chunk_ids + assert "Regular ASCII document" in chunk_ids + assert "Document with émojis: 😀🎉" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_large_documents(self, processor): @@ -243,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify large content is preserved + # Verify large content is preserved in ChunkMatch objects assert len(result) == 2 - assert large_doc in result - assert "Small document" in result + chunk_ids = [r.chunk_id for r in result] + assert large_doc in chunk_ids + assert "Small document" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_special_characters(self, processor): @@ -268,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify special characters are preserved + # Verify special characters are preserved in ChunkMatch objects assert len(result) == 3 - assert "Document with \"quotes\" and 'apostrophes'" in result - assert "Document with\nnewlines\tand\ttabs" in result - assert "Document with special chars: @#$%^&*()" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids + assert "Document with\nnewlines\tand\ttabs" in chunk_ids + assert "Document with special chars: @#$%^&*()" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_zero_limit(self, processor): @@ -349,10 +354,11 @@ class TestMilvusDocEmbeddingsQueryProcessor: # Verify search was called with the vector processor.vecstore.search.assert_called_once() - # Verify results + # Verify results are ChunkMatch objects assert len(result) == 2 - assert "Document 1" in result - assert "Document 2" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document 1" in chunk_ids + assert "Document 2" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_multiple_results(self, processor): @@ -374,11 +380,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify results + # Verify results are ChunkMatch objects assert len(result) == 3 - assert "Document A" in result - assert "Document B" in result - assert "Document C" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document A" in chunk_ids + assert "Document B" in chunk_ids + assert "Document C" in chunk_ids def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index e6670c10..1d2f0e6d 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.doc_embeddings.qdrant.service import Processor +from trustgraph.schema import ChunkMatch class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): @@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): with_payload=True ) - # Verify result contains expected documents + # Verify result contains expected ChunkMatch objects assert len(result) == 2 - # Results should be strings (document chunks) - assert isinstance(result[0], str) - assert isinstance(result[1], str) + # Results should be ChunkMatch objects + assert isinstance(result[0], ChunkMatch) + assert isinstance(result[1], ChunkMatch) # Verify content - assert result[0] == 'first document chunk' - assert result[1] == 'second document chunk' + assert result[0].chunk_id == 'first document chunk' + assert result[1].chunk_id == 'second document chunk' @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') - async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client): - """Test querying document embeddings with multiple vectors""" + async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client): + """Test querying document embeddings returns multiple results""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - - # Mock query responses for different vectors + + # Mock query response with multiple results mock_point1 = MagicMock() - mock_point1.payload = {'chunk_id': 'document from vector 1'} + mock_point1.payload = {'chunk_id': 'document chunk 1'} mock_point2 = MagicMock() - mock_point2.payload = {'chunk_id': 'document from vector 2'} + mock_point2.payload = {'chunk_id': 'document chunk 2'} mock_point3 = MagicMock() - mock_point3.payload = {'chunk_id': 'another document from vector 2'} - - mock_response1 = MagicMock() - mock_response1.points = [mock_point1] - mock_response2 = MagicMock() - mock_response2.points = [mock_point2, mock_point3] - mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2] - + mock_point3.payload = {'chunk_id': 'document chunk 3'} + + mock_response = MagicMock() + mock_response.points = [mock_point1, mock_point2, mock_point3] + mock_qdrant_instance.query_points.return_value = mock_response + config = { 'taskgroup': AsyncMock(), 'id': 'test-processor' } processor = Processor(**config) - - # Create mock message with multiple vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' - + # Act result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called twice - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 - # Verify both collections were queried (both 2-dimensional vectors) + # Verify collection was queried correctly expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection - assert calls[1][1]['collection_name'] == expected_collection assert calls[0][1]['query'] == [0.1, 0.2] - assert calls[1][1]['query'] == [0.3, 0.4] - - # Verify results from both vectors are combined + + # Verify results are ChunkMatch objects assert len(result) == 3 - assert 'document from vector 1' in result - assert 'document from vector 2' in result - assert 'another document from vector 2' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document chunk 1' in chunk_ids + assert 'document chunk 2' in chunk_ids + assert 'document chunk 3' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client): - """Test querying document embeddings with different vector dimensions""" + """Test querying document embeddings with a higher dimension vector""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - - # Mock query responses + + # Mock query response mock_point1 = MagicMock() - mock_point1.payload = {'chunk_id': 'document from 2D vector'} + mock_point1.payload = {'chunk_id': 'document from 5D vector'} mock_point2 = MagicMock() - mock_point2.payload = {'chunk_id': 'document from 3D vector'} - - mock_response1 = MagicMock() - mock_response1.points = [mock_point1] - mock_response2 = MagicMock() - mock_response2.points = [mock_point2] - mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2] - + mock_point2.payload = {'chunk_id': 'another 5D document'} + + mock_response = MagicMock() + mock_response.points = [mock_point1, mock_point2] + mock_qdrant_instance.query_points.return_value = mock_response + config = { 'taskgroup': AsyncMock(), 'id': 'test-processor' } processor = Processor(**config) - - # Create mock message with different dimension vectors + + # Create mock message with 5D vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' - + # Act result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called twice with different collections - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once with correct collection + assert mock_qdrant_instance.query_points.call_count == 1 calls = mock_qdrant_instance.query_points.call_args_list - # First call should use 2D collection - assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions - assert calls[0][1]['query'] == [0.1, 0.2] + # Call should use 5D collection + assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions + assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5] - # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions - assert calls[1][1]['query'] == [0.3, 0.4, 0.5] - - # Verify results + # Verify results are ChunkMatch objects assert len(result) == 2 - assert 'document from 2D vector' in result - assert 'document from 3D vector' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document from 5D vector' in chunk_ids + assert 'another 5D document' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert len(result) == 2 - - # Verify UTF-8 content works correctly - assert 'Document with UTF-8: café, naïve, résumé' in result - assert 'Chinese text: 你好世界' in result + + # Verify UTF-8 content works correctly in ChunkMatch objects + chunk_ids = [r.chunk_id for r in result] + assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids + assert 'Chinese text: 你好世界' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance.query_points.assert_called_once() call_args = mock_qdrant_instance.query_points.call_args assert call_args[1]['limit'] == 0 - - # Result should contain all returned documents + + # Result should contain all returned documents as ChunkMatch objects assert len(result) == 1 - assert result[0] == 'document chunk' + assert isinstance(result[0], ChunkMatch) + assert result[0].chunk_id == 'document chunk' @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance.query_points.assert_called_once() call_args = mock_qdrant_instance.query_points.call_args assert call_args[1]['limit'] == 1000 - - # Result should contain all available documents + + # Result should contain all available documents as ChunkMatch objects assert len(result) == 2 - assert 'document 1' in result - assert 'document 2' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document 1' in chunk_ids + assert 'document 2' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index 458d613d..f2b8be7e 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -151,40 +151,31 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert result[2].entity.type == LITERAL @pytest.mark.asyncio - async def test_query_graph_embeddings_multiple_vectors(self, processor): - """Test querying graph embeddings with multiple vectors""" + async def test_query_graph_embeddings_multiple_results(self, processor): + """Test querying graph embeddings returns multiple results""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - limit=3 + limit=5 ) - - # Mock search results - different results for each vector - mock_results_1 = [ + + # Mock search results with multiple entities + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, - ] - mock_results_2 = [ - {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate {"entity": {"entity": "http://example.com/entity3"}}, ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify search was called twice with correct parameters including user/collection - expected_calls = [ - (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}), - (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}), - ] - assert processor.vecstore.search.call_count == 2 - for i, (expected_args, expected_kwargs) in enumerate(expected_calls): - actual_call = processor.vecstore.search.call_args_list[i] - assert actual_call[0] == expected_args - assert actual_call[1] == expected_kwargs - - # Verify results are deduplicated and limited + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10 + ) + + # Verify results are EntityMatch objects assert len(result) == 3 entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] assert "http://example.com/entity1" in entity_values @@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 2 @pytest.mark.asyncio - async def test_query_graph_embeddings_deduplication(self, processor): - """Test that duplicate entities are properly deduplicated""" + async def test_query_graph_embeddings_preserves_order(self, processor): + """Test that query results preserve order from the vector store""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) - - # Mock search results with duplicates - mock_results_1 = [ + + # Mock search results in specific order + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, + {"entity": {"entity": "http://example.com/entity3"}}, ] - mock_results_2 = [ - {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate - {"entity": {"entity": "http://example.com/entity1"}}, # Duplicate - {"entity": {"entity": "http://example.com/entity3"}}, # New - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify duplicates are removed + + # Verify results are in the same order as returned by the store assert len(result) == 3 - entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] - assert len(set(entity_values)) == 3 # All unique - assert "http://example.com/entity1" in entity_values - assert "http://example.com/entity2" in entity_values - assert "http://example.com/entity3" in entity_values + assert result[0].entity.iri == "http://example.com/entity1" + assert result[1].entity.iri == "http://example.com/entity2" + assert result[2].entity.iri == "http://example.com/entity3" @pytest.mark.asyncio - async def test_query_graph_embeddings_early_termination_on_limit(self, processor): - """Test that querying stops early when limit is reached""" + async def test_query_graph_embeddings_results_limited(self, processor): + """Test that results are properly limited when store returns more than requested""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=2 ) - - # Mock search results - first vector returns enough results - mock_results_1 = [ + + # Mock search results - returns more results than limit + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, {"entity": {"entity": "http://example.com/entity3"}}, ] - processor.vecstore.search.return_value = mock_results_1 - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify only first vector was searched (limit reached) + + # Verify search was called with the full vector processor.vecstore.search.assert_called_once_with( - [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4 + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4 ) - - # Verify results are limited + + # Verify results are limited to requested amount assert len(result) == 2 @pytest.mark.asyncio diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index 4229e6a0..f9d60541 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( chunk_id="Test document content", vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify insert was called for each vector with user/collection parameters - expected_calls = [ - ([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'), - ] - - assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): - actual_call = processor.vecstore.insert.call_args_list[i] - assert actual_call[0][0] == expected_vec - assert actual_call[0][1] == expected_doc - assert actual_call[0][2] == expected_user - assert actual_call[0][3] == expected_collection + + # Verify insert was called once for the single chunk with its vector + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection' + ) @pytest.mark.asyncio async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): """Test storing document embeddings for multiple chunks""" await processor.store_document_embeddings(mock_message) - - # Verify insert was called for each vector of each chunk with user/collection parameters + + # Verify insert was called once per chunk with user/collection parameters expected_calls = [ - # Chunk 1 vectors - ([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), - # Chunk 2 vectors + # Chunk 1 - single vector + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), + # Chunk 2 - single vector ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), ] - - assert processor.vecstore.insert.call_count == 3 + + assert processor.vecstore.insert.call_count == 2 for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index 925cdca3..e4d60adf 100644 --- a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity'), vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.entities = [entity] - + await processor.store_graph_embeddings(message) - - # Verify insert was called for each vector with user/collection parameters - expected_calls = [ - ([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'), - ] - - assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): - actual_call = processor.vecstore.insert.call_args_list[i] - assert actual_call[0][0] == expected_vec - assert actual_call[0][1] == expected_entity - assert actual_call[0][2] == expected_user - assert actual_call[0][3] == expected_collection + + # Verify insert was called once with the full vector + processor.vecstore.insert.assert_called_once() + actual_call = processor.vecstore.insert.call_args_list[0] + assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + assert actual_call[0][1] == 'http://example.com/entity' + assert actual_call[0][2] == 'test_user' + assert actual_call[0][3] == 'test_collection' @pytest.mark.asyncio async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): """Test storing graph embeddings for multiple entities""" await processor.store_graph_embeddings(mock_message) - - # Verify insert was called for each vector of each entity with user/collection parameters + + # Verify insert was called once per entity with user/collection parameters expected_calls = [ - # Entity 1 vectors - ([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), - # Entity 2 vectors + # Entity 1 - single vector + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), + # Entity 2 - single vector ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), ] - - assert processor.vecstore.insert.call_count == 3 + + assert processor.vecstore.insert.call_count == 2 for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec