diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 7a15edb8..25a572c0 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -131,7 +131,7 @@ class TestGraphRagIntegration: # 2. Should query graph embeddings to find relevant entities mock_graph_embeddings_client.query.assert_called_once() call_args = mock_graph_embeddings_client.query.call_args - assert call_args.kwargs['vector'] == [0.1, 0.2, 0.3, 0.4, 0.5] + assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] assert call_args.kwargs['limit'] == entity_limit assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index 2458f583..ce758f66 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["test_chunk"] client.call = MagicMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3]] - + + vector = [0.1, 0.2, 0.3] + # Act - result = client.request(vectors=vectors) - + result = client.request(vector=vector) + # Assert assert result == ["test_chunk"] client.call.assert_called_once_with( user="trustgraph", collection="default", - vectors=vectors, + vector=vector, limit=10, timeout=300 ) @@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = [] client.call = MagicMock(return_value=mock_response) - + # Act - result = client.request(vectors=[[0.1, 0.2, 0.3]]) - + result = client.request(vector=[0.1, 0.2, 0.3]) + # Assert assert result == [] @@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = None client.call = MagicMock(return_value=mock_response) - + # Act - result = client.request(vectors=[[0.1, 0.2, 0.3]]) - + result = client.request(vector=[0.1, 0.2, 0.3]) + # Assert assert result is None @@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["chunk1"] client.call = MagicMock(return_value=mock_response) - + # Act client.request( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], timeout=600 ) - + # Assert assert client.call.call_args[1]["timeout"] == 600 \ No newline at end of file diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index b4c954d8..01efa146 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 ) return query @@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -97,43 +97,35 @@ class TestMilvusDocEmbeddingsQueryProcessor: assert result[2] == "Third document chunk" @pytest.mark.asyncio - async def test_query_document_embeddings_multiple_vectors(self, processor): - """Test querying document embeddings with multiple vectors""" + async def test_query_document_embeddings_longer_vector(self, processor): + """Test querying document embeddings with a longer vector""" query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=3 ) - - # Mock search results - different results for each vector - mock_results_1 = [ - {"entity": {"chunk_id": "Document from first vector"}}, - {"entity": {"chunk_id": "Another doc from first vector"}}, + + # Mock search results + mock_results = [ + {"entity": {"chunk_id": "First document"}}, + {"entity": {"chunk_id": "Second document"}}, + {"entity": {"chunk_id": "Third document"}}, ] - mock_results_2 = [ - {"entity": {"chunk_id": "Document from second vector"}}, - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Verify search was called twice with correct parameters including user/collection - expected_calls = [ - (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}), - (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}), - ] - assert processor.vecstore.search.call_count == 2 - for i, (expected_args, expected_kwargs) in enumerate(expected_calls): - actual_call = processor.vecstore.search.call_args_list[i] - assert actual_call[0] == expected_args - assert actual_call[1] == expected_kwargs - - # Verify results from all vectors are combined + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3 + ) + + # Verify results assert len(result) == 3 - assert "Document from first vector" in result - assert "Another doc from first vector" in result - assert "Document from second vector" in result + assert "First document" in result + assert "Second document" in result + assert "Third document" in result @pytest.mark.asyncio async def test_query_document_embeddings_with_limit(self, processor): @@ -141,7 +133,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=2 ) @@ -170,7 +162,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[], + vector=[], limit=5 ) @@ -188,7 +180,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -211,7 +203,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -237,7 +229,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -262,7 +254,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -288,7 +280,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=0 ) @@ -306,7 +298,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=-1 ) @@ -324,7 +316,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -341,59 +333,51 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ], + vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector limit=5 ) - - # Mock search results for each vector - mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}] - mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}] - mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] - + + # Mock search results + mock_results = [ + {"entity": {"chunk_id": "Document 1"}}, + {"entity": {"chunk_id": "Document 2"}}, + ] + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Verify all vectors were searched - assert processor.vecstore.search.call_count == 3 - - # Verify results from all dimensions - assert len(result) == 3 - assert "Document from 2D vector" in result - assert "Document from 4D vector" in result - assert "Document from 3D vector" in result + + # Verify search was called with the vector + processor.vecstore.search.assert_called_once() + + # Verify results + assert len(result) == 2 + assert "Document 1" in result + assert "Document 2" in result @pytest.mark.asyncio - async def test_query_document_embeddings_duplicate_documents(self, processor): - """Test querying document embeddings with duplicate documents in results""" + async def test_query_document_embeddings_multiple_results(self, processor): + """Test querying document embeddings with multiple results""" query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) - - # Mock search results with duplicates across vectors - mock_results_1 = [ + + # Mock search results with multiple documents + mock_results = [ {"entity": {"chunk_id": "Document A"}}, {"entity": {"chunk_id": "Document B"}}, - ] - mock_results_2 = [ - {"entity": {"chunk_id": "Document B"}}, # Duplicate {"entity": {"chunk_id": "Document C"}}, ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Note: Unlike graph embeddings, doc embeddings don't deduplicate - # This preserves ranking and allows multiple occurrences - assert len(result) == 4 - assert result.count("Document B") == 2 # Should appear twice + + # Verify results + assert len(result) == 3 assert "Document A" in result + assert "Document B" in result assert "Document C" in result def test_add_args_method(self): diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index 21b6e1bf..3c40f95e 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 ) return query @@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -156,7 +156,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=3 ) @@ -197,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=2 ) @@ -226,7 +226,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) @@ -258,7 +258,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=2 ) @@ -286,7 +286,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[], + vector=[], limit=5 ) @@ -304,7 +304,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -327,7 +327,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -365,7 +365,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -447,7 +447,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=0 ) @@ -460,33 +460,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 0 @pytest.mark.asyncio - async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying graph embeddings with different vector dimensions""" + async def test_query_graph_embeddings_longer_vector(self, processor): + """Test querying graph embeddings with a longer vector""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], limit=5 ) - - # Mock search results for each vector - mock_results_1 = [{"entity": {"entity": "entity_2d"}}] - mock_results_2 = [{"entity": {"entity": "entity_4d"}}] - mock_results_3 = [{"entity": {"entity": "entity_3d"}}] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] - + + # Mock search results + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + ] + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify all vectors were searched - assert processor.vecstore.search.call_count == 3 - - # Verify results from all dimensions - assert len(result) == 3 + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once() + + # Verify results + assert len(result) == 2 entity_values = [r.iri if r.type == IRI else r.value for r in result] - assert "entity_2d" in entity_values - assert "entity_4d" in entity_values - assert "entity_3d" in entity_values \ No newline at end of file + assert "http://example.com/entity1" in entity_values + assert "http://example.com/entity2" in entity_values \ No newline at end of file