mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
Batch embeddings (#668)
Base Service (trustgraph-base/trustgraph/base/embeddings_service.py): - Changed on_request to use request.texts FastEmbed Processor (trustgraph-flow/trustgraph/embeddings/fastembed/processor.py): - on_embeddings(texts, model=None) now processes full batch efficiently - Returns [[v.tolist()] for v in vecs] - list of vector sets Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py): - on_embeddings(texts, model=None) passes list directly to Ollama - Returns [[embedding] for embedding in embeds.embeddings] EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py): - embed(texts, timeout=300) accepts list of texts Tests Updated: - test_fastembed_dynamic_model.py - 4 tests updated for new interface - test_ollama_dynamic_model.py - 4 tests updated for new interface Updated CLI, SDK and APIs
This commit is contained in:
parent
3bf8a65409
commit
0a2ce47a88
16 changed files with 785 additions and 79 deletions
|
|
@ -103,12 +103,12 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
mock_text_embedding_class.reset_mock()
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text")
|
||||
result = await processor.on_embeddings(["test text"])
|
||||
|
||||
# Assert
|
||||
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
|
||||
assert processor.cached_model_name == "test-model" # Still using default
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||
|
||||
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
mock_text_embedding_class.reset_mock()
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model="custom-model")
|
||||
result = await processor.on_embeddings(["test text"], model="custom-model")
|
||||
|
||||
# Assert
|
||||
mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
|
||||
|
|
@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
initial_call_count = mock_text_embedding_class.call_count
|
||||
|
||||
# Act - switch between models
|
||||
await processor.on_embeddings("text1", model="model-a")
|
||||
await processor.on_embeddings(["text1"], model="model-a")
|
||||
call_count_after_a = mock_text_embedding_class.call_count
|
||||
|
||||
await processor.on_embeddings("text2", model="model-a") # Same, no reload
|
||||
await processor.on_embeddings(["text2"], model="model-a") # Same, no reload
|
||||
call_count_after_a_repeat = mock_text_embedding_class.call_count
|
||||
|
||||
await processor.on_embeddings("text3", model="model-b") # Different, reload
|
||||
await processor.on_embeddings(["text3"], model="model-b") # Different, reload
|
||||
call_count_after_b = mock_text_embedding_class.call_count
|
||||
|
||||
await processor.on_embeddings("text4", model="model-a") # Back to A, reload
|
||||
await processor.on_embeddings(["text4"], model="model-a") # Back to A, reload
|
||||
call_count_after_a_again = mock_text_embedding_class.call_count
|
||||
|
||||
# Assert
|
||||
|
|
@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
initial_count = mock_text_embedding_class.call_count
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model=None)
|
||||
result = await processor.on_embeddings(["test text"], model=None)
|
||||
|
||||
# Assert
|
||||
# No reload, using cached default
|
||||
|
|
|
|||
|
|
@ -53,14 +53,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text")
|
||||
result = await processor.on_embeddings(["test text"])
|
||||
|
||||
# Assert
|
||||
mock_ollama_client.embed.assert_called_once_with(
|
||||
model="test-model",
|
||||
input="test text"
|
||||
input=["test text"]
|
||||
)
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -79,14 +79,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model="custom-model")
|
||||
result = await processor.on_embeddings(["test text"], model="custom-model")
|
||||
|
||||
# Assert
|
||||
mock_ollama_client.embed.assert_called_once_with(
|
||||
model="custom-model",
|
||||
input="test text"
|
||||
input=["test text"]
|
||||
)
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act - switch between different models
|
||||
await processor.on_embeddings("text1", model="model-a")
|
||||
await processor.on_embeddings("text2", model="model-b")
|
||||
await processor.on_embeddings("text3", model="model-a")
|
||||
await processor.on_embeddings("text4") # Use default
|
||||
await processor.on_embeddings(["text1"], model="model-a")
|
||||
await processor.on_embeddings(["text2"], model="model-b")
|
||||
await processor.on_embeddings(["text3"], model="model-a")
|
||||
await processor.on_embeddings(["text4"]) # Use default
|
||||
|
||||
# Assert
|
||||
calls = mock_ollama_client.embed.call_args_list
|
||||
|
|
@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model=None)
|
||||
result = await processor.on_embeddings(["test text"], model=None)
|
||||
|
||||
# Assert
|
||||
mock_ollama_client.embed.assert_called_once_with(
|
||||
model="test-model",
|
||||
input="test text"
|
||||
input=["test text"]
|
||||
)
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
|
|
|
|||
|
|
@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock the flow
|
||||
mock_embeddings_request = AsyncMock()
|
||||
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
# Return batch of vector sets (one per text)
|
||||
# 4 unique texts: CUST001, John Doe, CUST002, Jane Smith
|
||||
mock_embeddings_request.embed.return_value = [
|
||||
[[0.1, 0.2, 0.3]], # vectors for text 1
|
||||
[[0.2, 0.3, 0.4]], # vectors for text 2
|
||||
[[0.3, 0.4, 0.5]], # vectors for text 3
|
||||
[[0.4, 0.5, 0.6]], # vectors for text 4
|
||||
]
|
||||
|
||||
mock_output = AsyncMock()
|
||||
|
||||
|
|
@ -368,9 +375,12 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Should have called embed for each unique text
|
||||
# 4 values: CUST001, John Doe, CUST002, Jane Smith
|
||||
assert mock_embeddings_request.embed.call_count == 4
|
||||
# Should have called embed once with all texts in a batch
|
||||
assert mock_embeddings_request.embed.call_count == 1
|
||||
# Verify it was called with a list of texts
|
||||
call_args = mock_embeddings_request.embed.call_args
|
||||
assert 'texts' in call_args.kwargs
|
||||
assert len(call_args.kwargs['texts']) == 4
|
||||
|
||||
# Should have sent output
|
||||
mock_output.send.assert_called()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue