Fix/embeddings integration 2 (#670)

This commit is contained in:
cybermaggedon 2026-03-08 19:42:26 +00:00 committed by GitHub
parent 919b760c05
commit 4fa7cc7d7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 90 additions and 77 deletions

View file

@ -132,9 +132,10 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method to return test vectors
# Mock the embed method to return test vectors in batch format
# New format: [[[vectors_for_text1]]] - returns first text's vector set
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query
query = Query(
@ -148,10 +149,10 @@ class TestQuery:
test_query = "What documents are relevant?"
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors
# Verify result matches expected vectors (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
@ -170,8 +171,9 @@ class TestQuery:
mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock document embeddings returns chunk_ids
test_chunk_ids = ["doc/c1", "doc/c2"]
@ -190,10 +192,10 @@ class TestQuery:
test_query = "Find relevant documents"
result = await query.get_docs(test_query)
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify doc embeddings client was called correctly
# Verify doc embeddings client was called correctly (with extracted vectors)
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
limit=15,
@ -214,11 +216,12 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
test_chunk_ids = ["doc/c3", "doc/c4"]
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = test_vectors
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = test_chunk_ids
mock_prompt_client.document_prompt.return_value = expected_response
@ -239,10 +242,10 @@ class TestQuery:
doc_limit=10
)
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with("test query")
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify doc embeddings client was called
# Verify doc embeddings client was called (with extracted vectors)
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
limit=10,
@ -270,8 +273,8 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
mock_prompt_client.document_prompt.return_value = "Default response"
@ -286,7 +289,7 @@ class TestQuery:
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used
# Verify default parameters were used (vectors extracted from batch)
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2]],
limit=20, # Default doc_limit
@ -311,8 +314,8 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock responses
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
# Initialize Query with verbose=True
@ -327,8 +330,8 @@ class TestQuery:
# Call get_docs
result = await query.get_docs("verbose test")
# Verify calls were made
mock_embeddings_client.embed.assert_called_once_with("verbose test")
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify result contains fetched content
@ -342,8 +345,8 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
@ -359,8 +362,8 @@ class TestQuery:
# Call DocumentRag.query
result = await document_rag.query("verbose query test")
# Verify all clients were called
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
# Verify all clients were called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify prompt client was called with fetched content
@ -385,8 +388,8 @@ class TestQuery:
return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
# Mock responses - empty chunk_id list (batch format)
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
# Initialize Query
@ -400,8 +403,8 @@ class TestQuery:
# Call get_docs
result = await query.get_docs("query with no results")
# Verify calls were made
mock_embeddings_client.embed.assert_called_once_with("query with no results")
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify empty result is returned
@ -415,8 +418,8 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses - no chunk_ids found
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
# Mock responses - no chunk_ids found (batch format)
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
mock_prompt_client.document_prompt.return_value = "No documents found response"
@ -448,9 +451,9 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method
# Mock the embed method (batch format)
expected_vectors = [[0.9, 1.0, 1.1]]
mock_embeddings_client.embed.return_value = expected_vectors
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query with verbose=True
query = Query(
@ -463,10 +466,10 @@ class TestQuery:
# Call get_vector
result = await query.get_vector("verbose vector test")
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
# Verify result
# Verify result (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
@ -477,13 +480,13 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock realistic responses
# Mock realistic responses (batch format)
query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = query_vectors
mock_embeddings_client.embed.return_value = [query_vectors]
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
mock_prompt_client.document_prompt.return_value = final_response
@ -504,8 +507,8 @@ class TestQuery:
doc_limit=25
)
# Verify complete pipeline execution
mock_embeddings_client.embed.assert_called_once_with(query_text)
# Verify complete pipeline execution (now expects list)
mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with(
query_vectors,