From 4fa7cc7d7c14af1d4183d3926c2f3e19b5e5b2fe Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sun, 8 Mar 2026 19:42:26 +0000 Subject: [PATCH] Fix/embeddings integration 2 (#670) --- .../test_document_rag_integration.py | 10 ++- ...test_document_rag_streaming_integration.py | 3 +- .../integration/test_graph_rag_integration.py | 10 ++- .../test_graph_rag_streaming_integration.py | 3 +- .../test_rag_streaming_protocol.py | 4 +- .../unit/test_retrieval/test_document_rag.py | 81 ++++++++++--------- tests/unit/test_retrieval/test_graph_rag.py | 56 ++++++------- 7 files changed, 90 insertions(+), 77 deletions(-) diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 75f148f3..99e25ed5 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -27,9 +27,13 @@ class TestDocumentRagIntegration: def mock_embeddings_client(self): """Mock embeddings client that returns realistic vector embeddings""" client = AsyncMock() + # New batch format: [[[vectors_for_text1], ...]] + # One text input returns one vector set containing two vectors client.embed.return_value = [ - [0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding - [0.6, 0.7, 0.8, 0.9, 1.0] # Second embedding for testing + [ + [0.1, 0.2, 0.3, 0.4, 0.5], # First vector for text + [0.6, 0.7, 0.8, 0.9, 1.0] # Second vector for text + ] ] return client @@ -90,7 +94,7 @@ class TestDocumentRagIntegration: ) # Assert - Verify service coordination - mock_embeddings_client.embed.assert_called_once_with(query) + mock_embeddings_client.embed.assert_called_once_with([query]) mock_doc_embeddings_client.query.assert_called_once_with( [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index 52b86caf..db79ebac 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -30,7 +30,8 @@ class TestDocumentRagStreaming: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + # New batch format: [[[vectors_for_text1]]] + client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]] return client @pytest.fixture diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index a0608819..94e8cf08 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -21,8 +21,12 @@ class TestGraphRagIntegration: def mock_embeddings_client(self): """Mock embeddings client that returns realistic vector embeddings""" client = AsyncMock() + # New batch format: [[[vectors_for_text1], ...]] + # One text input returns one vector set containing one vector client.embed.return_value = [ - [0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding + [ + [0.1, 0.2, 0.3, 0.4, 0.5], # Vector for text + ] ] return client @@ -120,8 +124,8 @@ class TestGraphRagIntegration: # Assert - Verify service coordination - # 1. Should compute embeddings for query - mock_embeddings_client.embed.assert_called_once_with(query) + # 1. Should compute embeddings for query (now expects list of texts) + mock_embeddings_client.embed.assert_called_once_with([query]) # 2. Should query graph embeddings to find relevant entities mock_graph_embeddings_client.query.assert_called_once() diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 47dd84b6..f4d8ce8b 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -24,7 +24,8 @@ class TestGraphRagStreaming: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]] + # New batch format: [[[vectors_for_text1]]] + client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]] return client @pytest.fixture diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index ba492687..19f2cf35 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -18,7 +18,7 @@ class TestGraphRagStreamingProtocol: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3]] + client.embed.return_value = [[[0.1, 0.2, 0.3]]] return client @pytest.fixture @@ -197,7 +197,7 @@ class TestDocumentRagStreamingProtocol: def mock_embeddings_client(self): """Mock embeddings client""" client = AsyncMock() - client.embed.return_value = [[0.1, 0.2, 0.3]] + client.embed.return_value = [[[0.1, 0.2, 0.3]]] return client @pytest.fixture diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index c9f5e8e1..51e41524 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -132,9 +132,10 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - # Mock the embed method to return test vectors + # Mock the embed method to return test vectors in batch format + # New format: [[[vectors_for_text1]]] - returns first text's vector set expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - mock_embeddings_client.embed.return_value = expected_vectors + mock_embeddings_client.embed.return_value = [expected_vectors] # Initialize Query query = Query( @@ -148,10 +149,10 @@ class TestQuery: test_query = "What documents are relevant?" result = await query.get_vector(test_query) - # Verify embeddings client was called correctly - mock_embeddings_client.embed.assert_called_once_with(test_query) + # Verify embeddings client was called correctly (now expects list) + mock_embeddings_client.embed.assert_called_once_with([test_query]) - # Verify result matches expected vectors + # Verify result matches expected vectors (extracted from batch) assert result == expected_vectors @pytest.mark.asyncio @@ -170,8 +171,9 @@ class TestQuery: mock_rag.fetch_chunk = mock_fetch # Mock the embedding and document query responses + # New batch format: [[[vectors]]] - get_vector extracts [0] test_vectors = [[0.1, 0.2, 0.3]] - mock_embeddings_client.embed.return_value = test_vectors + mock_embeddings_client.embed.return_value = [test_vectors] # Mock document embeddings returns chunk_ids test_chunk_ids = ["doc/c1", "doc/c2"] @@ -190,10 +192,10 @@ class TestQuery: test_query = "Find relevant documents" result = await query.get_docs(test_query) - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with(test_query) + # Verify embeddings client was called (now expects list) + mock_embeddings_client.embed.assert_called_once_with([test_query]) - # Verify doc embeddings client was called correctly + # Verify doc embeddings client was called correctly (with extracted vectors) mock_doc_embeddings_client.query.assert_called_once_with( test_vectors, limit=15, @@ -214,11 +216,12 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock embeddings and document embeddings responses + # New batch format: [[[vectors]]] - get_vector extracts [0] test_vectors = [[0.1, 0.2, 0.3]] test_chunk_ids = ["doc/c3", "doc/c4"] expected_response = "This is the document RAG response" - mock_embeddings_client.embed.return_value = test_vectors + mock_embeddings_client.embed.return_value = [test_vectors] mock_doc_embeddings_client.query.return_value = test_chunk_ids mock_prompt_client.document_prompt.return_value = expected_response @@ -239,10 +242,10 @@ class TestQuery: doc_limit=10 ) - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with("test query") + # Verify embeddings client was called (now expects list) + mock_embeddings_client.embed.assert_called_once_with(["test query"]) - # Verify doc embeddings client was called + # Verify doc embeddings client was called (with extracted vectors) mock_doc_embeddings_client.query.assert_called_once_with( test_vectors, limit=10, @@ -270,8 +273,8 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock responses - mock_embeddings_client.embed.return_value = [[0.1, 0.2]] + # Mock responses (batch format) + mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_doc_embeddings_client.query.return_value = ["doc/c5"] mock_prompt_client.document_prompt.return_value = "Default response" @@ -286,7 +289,7 @@ class TestQuery: # Call DocumentRag.query with minimal parameters result = await document_rag.query("simple query") - # Verify default parameters were used + # Verify default parameters were used (vectors extracted from batch) mock_doc_embeddings_client.query.assert_called_once_with( [[0.1, 0.2]], limit=20, # Default doc_limit @@ -311,8 +314,8 @@ class TestQuery: return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") mock_rag.fetch_chunk = mock_fetch - # Mock responses - mock_embeddings_client.embed.return_value = [[0.7, 0.8]] + # Mock responses (batch format) + mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] mock_doc_embeddings_client.query.return_value = ["doc/c6"] # Initialize Query with verbose=True @@ -327,8 +330,8 @@ class TestQuery: # Call get_docs result = await query.get_docs("verbose test") - # Verify calls were made - mock_embeddings_client.embed.assert_called_once_with("verbose test") + # Verify calls were made (now expects list) + mock_embeddings_client.embed.assert_called_once_with(["verbose test"]) mock_doc_embeddings_client.query.assert_called_once() # Verify result contains fetched content @@ -342,8 +345,8 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock responses - mock_embeddings_client.embed.return_value = [[0.3, 0.4]] + # Mock responses (batch format) + mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] mock_doc_embeddings_client.query.return_value = ["doc/c7"] mock_prompt_client.document_prompt.return_value = "Verbose RAG response" @@ -359,8 +362,8 @@ class TestQuery: # Call DocumentRag.query result = await document_rag.query("verbose query test") - # Verify all clients were called - mock_embeddings_client.embed.assert_called_once_with("verbose query test") + # Verify all clients were called (now expects list) + mock_embeddings_client.embed.assert_called_once_with(["verbose query test"]) mock_doc_embeddings_client.query.assert_called_once() # Verify prompt client was called with fetched content @@ -385,8 +388,8 @@ class TestQuery: return f"Content for {chunk_id}" mock_rag.fetch_chunk = mock_fetch - # Mock responses - empty chunk_id list - mock_embeddings_client.embed.return_value = [[0.1, 0.2]] + # Mock responses - empty chunk_id list (batch format) + mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found # Initialize Query @@ -400,8 +403,8 @@ class TestQuery: # Call get_docs result = await query.get_docs("query with no results") - # Verify calls were made - mock_embeddings_client.embed.assert_called_once_with("query with no results") + # Verify calls were made (now expects list) + mock_embeddings_client.embed.assert_called_once_with(["query with no results"]) mock_doc_embeddings_client.query.assert_called_once() # Verify empty result is returned @@ -415,8 +418,8 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock responses - no chunk_ids found - mock_embeddings_client.embed.return_value = [[0.5, 0.6]] + # Mock responses - no chunk_ids found (batch format) + mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list mock_prompt_client.document_prompt.return_value = "No documents found response" @@ -448,9 +451,9 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - # Mock the embed method + # Mock the embed method (batch format) expected_vectors = [[0.9, 1.0, 1.1]] - mock_embeddings_client.embed.return_value = expected_vectors + mock_embeddings_client.embed.return_value = [expected_vectors] # Initialize Query with verbose=True query = Query( @@ -463,10 +466,10 @@ class TestQuery: # Call get_vector result = await query.get_vector("verbose vector test") - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with("verbose vector test") + # Verify embeddings client was called (now expects list) + mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"]) - # Verify result + # Verify result (extracted from batch) assert result == expected_vectors @pytest.mark.asyncio @@ -477,13 +480,13 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock realistic responses + # Mock realistic responses (batch format) query_text = "What is machine learning?" query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"] final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." - mock_embeddings_client.embed.return_value = query_vectors + mock_embeddings_client.embed.return_value = [query_vectors] mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids mock_prompt_client.document_prompt.return_value = final_response @@ -504,8 +507,8 @@ class TestQuery: doc_limit=25 ) - # Verify complete pipeline execution - mock_embeddings_client.embed.assert_called_once_with(query_text) + # Verify complete pipeline execution (now expects list) + mock_embeddings_client.embed.assert_called_once_with([query_text]) mock_doc_embeddings_client.query.assert_called_once_with( query_vectors, diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 5f54e28a..15b0c82d 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -127,10 +127,10 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - # Mock the embed method to return test vectors + # Mock the embed method to return test vectors (batch format) expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - mock_embeddings_client.embed.return_value = expected_vectors - + mock_embeddings_client.embed.return_value = [expected_vectors] + # Initialize Query query = Query( rag=mock_rag, @@ -138,15 +138,15 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Call get_vector test_query = "What is the capital of France?" result = await query.get_vector(test_query) - - # Verify embeddings client was called correctly - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify result matches expected vectors + + # Verify embeddings client was called correctly (now expects list) + mock_embeddings_client.embed.assert_called_once_with([test_query]) + + # Verify result matches expected vectors (extracted from batch) assert result == expected_vectors @pytest.mark.asyncio @@ -157,10 +157,10 @@ class TestQuery: mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - # Mock the embed method + # Mock the embed method (batch format) expected_vectors = [[0.7, 0.8, 0.9]] - mock_embeddings_client.embed.return_value = expected_vectors - + mock_embeddings_client.embed.return_value = [expected_vectors] + # Initialize Query with verbose=True query = Query( rag=mock_rag, @@ -168,15 +168,15 @@ class TestQuery: collection="test_collection", verbose=True ) - + # Call get_vector test_query = "Test query for embeddings" result = await query.get_vector(test_query) - - # Verify embeddings client was called correctly - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify result matches expected vectors + + # Verify embeddings client was called correctly (now expects list) + mock_embeddings_client.embed.assert_called_once_with([test_query]) + + # Verify result matches expected vectors (extracted from batch) assert result == expected_vectors @pytest.mark.asyncio @@ -189,17 +189,17 @@ class TestQuery: mock_rag.embeddings_client = mock_embeddings_client mock_rag.graph_embeddings_client = mock_graph_embeddings_client - # Mock the embedding and entity query responses + # Mock the embedding and entity query responses (batch format) test_vectors = [[0.1, 0.2, 0.3]] - mock_embeddings_client.embed.return_value = test_vectors - + mock_embeddings_client.embed.return_value = [test_vectors] + # Mock entity objects that have string representation mock_entity1 = MagicMock() mock_entity1.__str__ = MagicMock(return_value="entity1") mock_entity2 = MagicMock() mock_entity2.__str__ = MagicMock(return_value="entity2") mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2] - + # Initialize Query query = Query( rag=mock_rag, @@ -208,15 +208,15 @@ class TestQuery: verbose=False, entity_limit=25 ) - + # Call get_entities test_query = "Find related entities" result = await query.get_entities(test_query) - - # Verify embeddings client was called - mock_embeddings_client.embed.assert_called_once_with(test_query) - - # Verify graph embeddings client was called correctly + + # Verify embeddings client was called (now expects list) + mock_embeddings_client.embed.assert_called_once_with([test_query]) + + # Verify graph embeddings client was called correctly (with extracted vectors) mock_graph_embeddings_client.query.assert_called_once_with( vectors=test_vectors, limit=25,