Fix/embeddings integration 2 (#670)

This commit is contained in:
cybermaggedon 2026-03-08 19:42:26 +00:00 committed by GitHub
parent 919b760c05
commit 4fa7cc7d7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 90 additions and 77 deletions

View file

@ -127,10 +127,10 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method to return test vectors
# Mock the embed method to return test vectors (batch format)
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query
query = Query(
rag=mock_rag,
@ -138,15 +138,15 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call get_vector
test_query = "What is the capital of France?"
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify result matches expected vectors
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
@ -157,10 +157,10 @@ class TestQuery:
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method
# Mock the embed method (batch format)
expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = expected_vectors
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
@ -168,15 +168,15 @@ class TestQuery:
collection="test_collection",
verbose=True
)
# Call get_vector
test_query = "Test query for embeddings"
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify result matches expected vectors
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
@ -189,17 +189,17 @@ class TestQuery:
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# Mock the embedding and entity query responses
# Mock the embedding and entity query responses (batch format)
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock entity objects that have string representation
mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
# Initialize Query
query = Query(
rag=mock_rag,
@ -208,15 +208,15 @@ class TestQuery:
verbose=False,
entity_limit=25
)
# Call get_entities
test_query = "Find related entities"
result = await query.get_entities(test_query)
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify graph embeddings client was called correctly
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify graph embeddings client was called correctly (with extracted vectors)
mock_graph_embeddings_client.query.assert_called_once_with(
vectors=test_vectors,
limit=25,