mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Fix/embeddings integration 2 (#670)
This commit is contained in:
parent
919b760c05
commit
4fa7cc7d7c
7 changed files with 90 additions and 77 deletions
|
|
@ -127,10 +127,10 @@ class TestQuery:
|
|||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors
|
||||
# Mock the embed method to return test vectors (batch format)
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -138,15 +138,15 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What is the capital of France?"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -157,10 +157,10 @@ class TestQuery:
|
|||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method
|
||||
# Mock the embed method (batch format)
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -168,15 +168,15 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
# Call get_vector
|
||||
test_query = "Test query for embeddings"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -189,17 +189,17 @@ class TestQuery:
|
|||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# Mock the embedding and entity query responses
|
||||
# Mock the embedding and entity query responses (batch format)
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
|
||||
# Mock entity objects that have string representation
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
|
||||
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -208,15 +208,15 @@ class TestQuery:
|
|||
verbose=False,
|
||||
entity_limit=25
|
||||
)
|
||||
|
||||
|
||||
# Call get_entities
|
||||
test_query = "Find related entities"
|
||||
result = await query.get_entities(test_query)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify graph embeddings client was called correctly
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify graph embeddings client was called correctly (with extracted vectors)
|
||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||
vectors=test_vectors,
|
||||
limit=25,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue