From 0a2ce47a882d02181442dc6ad4b64a36ceab2116 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sun, 8 Mar 2026 18:36:54 +0000 Subject: [PATCH] Batch embeddings (#668) Base Service (trustgraph-base/trustgraph/base/embeddings_service.py): - Changed on_request to use request.texts FastEmbed Processor (trustgraph-flow/trustgraph/embeddings/fastembed/processor.py): - on_embeddings(texts, model=None) now processes full batch efficiently - Returns [[v.tolist()] for v in vecs] - list of vector sets Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py): - on_embeddings(texts, model=None) passes list directly to Ollama - Returns [[embedding] for embedding in embeds.embeddings] EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py): - embed(texts, timeout=300) accepts list of texts Tests Updated: - test_fastembed_dynamic_model.py - 4 tests updated for new interface - test_ollama_dynamic_model.py - 4 tests updated for new interface Updated CLI, SDK and APIs --- .../tech-specs/embeddings-batch-processing.md | 667 ++++++++++++++++++ .../test_fastembed_dynamic_model.py | 16 +- .../test_ollama_dynamic_model.py | 24 +- .../test_row_embeddings_processor.py | 18 +- trustgraph-base/trustgraph/api/flow.py | 17 +- .../trustgraph/api/socket_client.py | 12 +- .../trustgraph/base/embeddings_client.py | 4 +- .../trustgraph/base/embeddings_service.py | 4 +- .../messaging/translators/embeddings.py | 8 +- .../trustgraph/schema/services/llm.py | 4 +- .../trustgraph/cli/invoke_embeddings.py | 19 +- .../document_embeddings/embeddings.py | 6 +- .../embeddings/fastembed/processor.py | 11 +- .../embeddings/graph_embeddings/embeddings.py | 26 +- .../trustgraph/embeddings/ollama/processor.py | 14 +- .../embeddings/row_embeddings/embeddings.py | 14 +- 16 files changed, 785 insertions(+), 79 deletions(-) create mode 100644 docs/tech-specs/embeddings-batch-processing.md diff --git a/docs/tech-specs/embeddings-batch-processing.md b/docs/tech-specs/embeddings-batch-processing.md new file mode 100644 index 00000000..59feb0ff --- /dev/null +++ b/docs/tech-specs/embeddings-batch-processing.md @@ -0,0 +1,667 @@ +# Embeddings Batch Processing Technical Specification + +## Overview + +This specification describes optimizations for the embeddings service to support batch processing of multiple texts in a single request. The current implementation processes one text at a time, missing the significant performance benefits that embedding models provide when processing batches. + +1. **Single-Text Processing Inefficiency**: Current implementation wraps single texts in a list, underutilizing FastEmbed's batch capabilities +2. **Request-Per-Text Overhead**: Each text requires a separate Pulsar message round-trip +3. **Model Inference Inefficiency**: Embedding models have fixed per-batch overhead; small batches waste GPU/CPU resources +4. **Serial Processing in Callers**: Key services loop over items and call embeddings one at a time + +## Goals + +- **Batch API Support**: Enable processing multiple texts in a single request +- **Backward Compatibility**: Maintain support for single-text requests +- **Significant Throughput Improvement**: Target 5-10x throughput improvement for bulk operations +- **Reduced Latency per Text**: Lower amortized latency when embedding multiple texts +- **Memory Efficiency**: Process batches without excessive memory consumption +- **Provider Agnostic**: Support batching across FastEmbed, Ollama, and other providers +- **Caller Migration**: Update all embedding callers to use batch API where beneficial + +## Background + +### Current Implementation - Embeddings Service + +The embeddings implementation in `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py` exhibits a significant performance inefficiency: + +```python +# fastembed/processor.py line 56 +async def on_embeddings(self, text, model=None): + use_model = model or self.default_model + self._load_model(use_model) + + vecs = self.embeddings.embed([text]) # Single text wrapped in list + + return [v.tolist() for v in vecs] +``` + +**Problems:** + +1. **Batch Size 1**: FastEmbed's `embed()` method is optimized for batch processing, but we always call it with `[text]` - a batch of size 1 + +2. **Per-Request Overhead**: Each embedding request incurs: + - Pulsar message serialization/deserialization + - Network round-trip latency + - Model inference startup overhead + - Python async scheduling overhead + +3. **Schema Limitation**: The `EmbeddingsRequest` schema only supports a single text: + ```python + @dataclass + class EmbeddingsRequest: + text: str = "" # Single text only + ``` + +### Current Callers - Serial Processing + +#### 1. API Gateway + +**File:** `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py` + +The gateway accepts single-text embedding requests via HTTP/WebSocket and forwards them to the embeddings service. Currently no batch endpoint exists. + +```python +class EmbeddingsRequestor(ServiceRequestor): + # Handles single EmbeddingsRequest -> EmbeddingsResponse + request_schema=EmbeddingsRequest, # Single text only + response_schema=EmbeddingsResponse, +``` + +**Impact:** External clients (web apps, scripts) must make N HTTP requests to embed N texts. + +#### 2. Document Embeddings Service + +**File:** `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` + +Processes document chunks one at a time: + +```python +async def on_message(self, msg, consumer, flow): + v = msg.value() + + # Single chunk per request + resp = await flow("embeddings-request").request( + EmbeddingsRequest(text=v.chunk) + ) + vectors = resp.vectors +``` + +**Impact:** Each document chunk requires a separate embedding call. A document with 100 chunks = 100 embedding requests. + +#### 3. Graph Embeddings Service + +**File:** `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py` + +Loops over entities and embeds each one serially: + +```python +async def on_message(self, msg, consumer, flow): + for entity in v.entities: + # Serial embedding - one entity at a time + vectors = await flow("embeddings-request").embed( + text=entity.context + ) + entities.append(EntityEmbeddings( + entity=entity.entity, + vectors=vectors, + chunk_id=entity.chunk_id, + )) +``` + +**Impact:** A message with 50 entities = 50 serial embedding requests. This is a major bottleneck during knowledge graph construction. + +#### 4. Row Embeddings Service + +**File:** `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py` + +Loops over unique texts and embeds each one serially: + +```python +async def on_message(self, msg, consumer, flow): + for text, (index_name, index_value) in texts_to_embed.items(): + # Serial embedding - one text at a time + vectors = await flow("embeddings-request").embed(text=text) + + embeddings_list.append(RowIndexEmbedding( + index_name=index_name, + index_value=index_value, + text=text, + vectors=vectors + )) +``` + +**Impact:** Processing a table with 100 unique indexed values = 100 serial embedding requests. + +#### 5. EmbeddingsClient (Base Client) + +**File:** `trustgraph-base/trustgraph/base/embeddings_client.py` + +The client used by all flow processors only supports single-text embedding: + +```python +class EmbeddingsClient(RequestResponse): + async def embed(self, text, timeout=30): + resp = await self.request( + EmbeddingsRequest(text=text), # Single text + timeout=timeout + ) + return resp.vectors +``` + +**Impact:** All callers using this client are limited to single-text operations. + +#### 6. Command-Line Tools + +**File:** `trustgraph-cli/trustgraph/cli/invoke_embeddings.py` + +CLI tool accepts single text argument: + +```python +def query(url, flow_id, text, token=None): + result = flow.embeddings(text=text) # Single text + vectors = result.get("vectors", []) +``` + +**Impact:** Users cannot batch-embed from command line. Processing a file of texts requires N invocations. + +#### 7. Python SDK + +The Python SDK provides two client classes for interacting with TrustGraph services. Both only support single-text embedding. + +**File:** `trustgraph-base/trustgraph/api/flow.py` + +```python +class FlowInstance: + def embeddings(self, text): + """Get embeddings for a single text""" + input = {"text": text} + return self.request("service/embeddings", input)["vectors"] +``` + +**File:** `trustgraph-base/trustgraph/api/socket_client.py` + +```python +class SocketFlowInstance: + def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]: + """Get embeddings for a single text via WebSocket""" + request = {"text": text} + return self.client._send_request_sync( + "embeddings", self.flow_id, request, False + ) +``` + +**Impact:** Python developers using the SDK must loop over texts and make N separate API calls. No batch embedding support exists for SDK users. + +### Performance Impact + +For typical document ingestion (1000 text chunks): +- **Current**: 1000 separate requests, 1000 model inference calls +- **Batched (batch_size=32)**: 32 requests, 32 model inference calls (96.8% reduction) + +For graph embedding (message with 50 entities): +- **Current**: 50 serial await calls, ~5-10 seconds +- **Batched**: 1-2 batch calls, ~0.5-1 second (5-10x improvement) + +FastEmbed and similar libraries achieve near-linear throughput scaling with batch size up to hardware limits (typically 32-128 texts per batch). + +## Technical Design + +### Architecture + +The embeddings batch processing optimization requires changes to the following components: + +#### 1. **Schema Enhancement** + - Extend `EmbeddingsRequest` to support multiple texts + - Extend `EmbeddingsResponse` to return multiple vector sets + - Maintain backward compatibility with single-text requests + + Module: `trustgraph-base/trustgraph/schema/services/llm.py` + +#### 2. **Base Service Enhancement** + - Update `EmbeddingsService` to handle batch requests + - Add batch size configuration + - Implement batch-aware request handling + + Module: `trustgraph-base/trustgraph/base/embeddings_service.py` + +#### 3. **Provider Processor Updates** + - Update FastEmbed processor to pass full batch to `embed()` + - Update Ollama processor to handle batches (if supported) + - Add fallback sequential processing for providers without batch support + + Modules: + - `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py` + - `trustgraph-flow/trustgraph/embeddings/ollama/processor.py` + +#### 4. **Client Enhancement** + - Add batch embedding method to `EmbeddingsClient` + - Support both single and batch APIs + - Add automatic batching for large inputs + + Module: `trustgraph-base/trustgraph/base/embeddings_client.py` + +#### 5. **Caller Updates - Flow Processors** + - Update `graph_embeddings` to batch entity contexts + - Update `row_embeddings` to batch index texts + - Update `document_embeddings` if message batching is feasible + + Modules: + - `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py` + - `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py` + - `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` + +#### 6. **API Gateway Enhancement** + - Add batch embedding endpoint + - Support array of texts in request body + + Module: `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py` + +#### 7. **CLI Tool Enhancement** + - Add support for multiple texts or file input + - Add batch size parameter + + Module: `trustgraph-cli/trustgraph/cli/invoke_embeddings.py` + +#### 8. **Python SDK Enhancement** + - Add `embeddings_batch()` method to `FlowInstance` + - Add `embeddings_batch()` method to `SocketFlowInstance` + - Support both single and batch APIs for SDK users + + Modules: + - `trustgraph-base/trustgraph/api/flow.py` + - `trustgraph-base/trustgraph/api/socket_client.py` + +### Data Models + +#### EmbeddingsRequest + +```python +@dataclass +class EmbeddingsRequest: + texts: list[str] = field(default_factory=list) +``` + +Usage: +- Single text: `EmbeddingsRequest(texts=["hello world"])` +- Batch: `EmbeddingsRequest(texts=["text1", "text2", "text3"])` + +#### EmbeddingsResponse + +```python +@dataclass +class EmbeddingsResponse: + error: Error | None = None + vectors: list[list[list[float]]] = field(default_factory=list) +``` + +Response structure: +- `vectors[i]` contains the vector set for `texts[i]` +- Each vector set is `list[list[float]]` (models may return multiple vectors per text) +- Example: 3 texts → `vectors` has 3 entries, each containing that text's embeddings + +### APIs + +#### EmbeddingsClient + +```python +class EmbeddingsClient(RequestResponse): + async def embed( + self, + texts: list[str], + timeout: float = 300, + ) -> list[list[list[float]]]: + """ + Embed one or more texts in a single request. + + Args: + texts: List of texts to embed + timeout: Timeout for the operation + + Returns: + List of vector sets, one per input text + """ + resp = await self.request( + EmbeddingsRequest(texts=texts), + timeout=timeout + ) + if resp.error: + raise RuntimeError(resp.error.message) + return resp.vectors +``` + +#### API Gateway Embeddings Endpoint + +Updated endpoint supporting single or batch embedding: + +``` +POST /api/v1/embeddings +Content-Type: application/json + +{ + "texts": ["text1", "text2", "text3"], + "flow_id": "default" +} + +Response: +{ + "vectors": [ + [[0.1, 0.2, ...]], + [[0.3, 0.4, ...]], + [[0.5, 0.6, ...]] + ] +} +``` + +### Implementation Details + +#### Phase 1: Schema Changes + +**EmbeddingsRequest:** +```python +@dataclass +class EmbeddingsRequest: + texts: list[str] = field(default_factory=list) +``` + +**EmbeddingsResponse:** +```python +@dataclass +class EmbeddingsResponse: + error: Error | None = None + vectors: list[list[list[float]]] = field(default_factory=list) +``` + +**Updated EmbeddingsService.on_request:** +```python +async def on_request(self, msg, consumer, flow): + request = msg.value() + id = msg.properties()["id"] + model = flow("model") + + vectors = await self.on_embeddings(request.texts, model=model) + response = EmbeddingsResponse(error=None, vectors=vectors) + + await flow("response").send(response, properties={"id": id}) +``` + +#### Phase 2: FastEmbed Processor Update + +**Current (Inefficient):** +```python +async def on_embeddings(self, text, model=None): + use_model = model or self.default_model + self._load_model(use_model) + vecs = self.embeddings.embed([text]) # Batch of 1 + return [v.tolist() for v in vecs] +``` + +**Updated:** +```python +async def on_embeddings(self, texts: list[str], model=None): + """Embed texts - processes all texts in single model call""" + if not texts: + return [] + + use_model = model or self.default_model + self._load_model(use_model) + + # FastEmbed handles the full batch efficiently + all_vecs = list(self.embeddings.embed(texts)) + + # Return list of vector sets, one per input text + return [[v.tolist()] for v in all_vecs] +``` + +#### Phase 3: Graph Embeddings Service Update + +**Current (Serial):** +```python +async def on_message(self, msg, consumer, flow): + entities = [] + for entity in v.entities: + vectors = await flow("embeddings-request").embed(text=entity.context) + entities.append(EntityEmbeddings(...)) +``` + +**Updated (Batch):** +```python +async def on_message(self, msg, consumer, flow): + # Collect all contexts + contexts = [entity.context for entity in v.entities] + + # Single batch embedding call + all_vectors = await flow("embeddings-request").embed(texts=contexts) + + # Pair results with entities + entities = [ + EntityEmbeddings( + entity=entity.entity, + vectors=vectors[0], # First vector from the set + chunk_id=entity.chunk_id, + ) + for entity, vectors in zip(v.entities, all_vectors) + ] +``` + +#### Phase 4: Row Embeddings Service Update + +**Current (Serial):** +```python +for text, (index_name, index_value) in texts_to_embed.items(): + vectors = await flow("embeddings-request").embed(text=text) + embeddings_list.append(RowIndexEmbedding(...)) +``` + +**Updated (Batch):** +```python +# Collect texts and metadata +texts = list(texts_to_embed.keys()) +metadata = list(texts_to_embed.values()) + +# Single batch embedding call +all_vectors = await flow("embeddings-request").embed(texts=texts) + +# Pair results +embeddings_list = [ + RowIndexEmbedding( + index_name=meta[0], + index_value=meta[1], + text=text, + vectors=vectors[0] # First vector from the set + ) + for text, meta, vectors in zip(texts, metadata, all_vectors) +] +``` + +#### Phase 5: CLI Tool Enhancement + +**Updated CLI:** +```python +def main(): + parser = argparse.ArgumentParser(...) + + parser.add_argument( + 'text', + nargs='*', # Zero or more texts + help='Text(s) to convert to embedding vectors', + ) + + parser.add_argument( + '-f', '--file', + help='File containing texts (one per line)', + ) + + parser.add_argument( + '--batch-size', + type=int, + default=32, + help='Batch size for processing (default: 32)', + ) +``` + +Usage: +```bash +# Single text (existing) +tg-invoke-embeddings "hello world" + +# Multiple texts +tg-invoke-embeddings "text one" "text two" "text three" + +# From file +tg-invoke-embeddings -f texts.txt --batch-size 64 +``` + +#### Phase 6: Python SDK Enhancement + +**FlowInstance (HTTP client):** + +```python +class FlowInstance: + def embeddings(self, texts: list[str]) -> list[list[list[float]]]: + """ + Get embeddings for one or more texts. + + Args: + texts: List of texts to embed + + Returns: + List of vector sets, one per input text + """ + input = {"texts": texts} + return self.request("service/embeddings", input)["vectors"] +``` + +**SocketFlowInstance (WebSocket client):** + +```python +class SocketFlowInstance: + def embeddings(self, texts: list[str], **kwargs: Any) -> list[list[list[float]]]: + """ + Get embeddings for one or more texts via WebSocket. + + Args: + texts: List of texts to embed + + Returns: + List of vector sets, one per input text + """ + request = {"texts": texts} + response = self.client._send_request_sync( + "embeddings", self.flow_id, request, False + ) + return response["vectors"] +``` + +**SDK Usage Examples:** + +```python +# Single text +vectors = flow.embeddings(["hello world"]) +print(f"Dimensions: {len(vectors[0][0])}") + +# Batch embedding +texts = ["text one", "text two", "text three"] +all_vectors = flow.embeddings(texts) + +# Process results +for text, vecs in zip(texts, all_vectors): + print(f"{text}: {len(vecs[0])} dimensions") +``` + +## Security Considerations + +- **Request Size Limits**: Enforce maximum batch size to prevent resource exhaustion +- **Timeout Handling**: Scale timeouts appropriately for batch size +- **Memory Limits**: Monitor memory usage for large batches +- **Input Validation**: Validate all texts in batch before processing + +## Performance Considerations + +### Expected Improvements + +**Throughput:** +- Single-text: ~10-50 texts/second (depending on model) +- Batch (size 32): ~200-500 texts/second (5-10x improvement) + +**Latency per Text:** +- Single-text: 50-200ms per text +- Batch (size 32): 5-20ms per text (amortized) + +**Service-Specific Improvements:** + +| Service | Current | Batched | Improvement | +|---------|---------|---------|-------------| +| Graph Embeddings (50 entities) | 5-10s | 0.5-1s | 5-10x | +| Row Embeddings (100 texts) | 10-20s | 1-2s | 5-10x | +| Document Ingestion (1000 chunks) | 100-200s | 10-30s | 5-10x | + +### Configuration Parameters + +```python +# Recommended defaults +DEFAULT_BATCH_SIZE = 32 +MAX_BATCH_SIZE = 128 +BATCH_TIMEOUT_MULTIPLIER = 2.0 +``` + +## Testing Strategy + +### Unit Testing +- Single text embedding (backward compatibility) +- Empty batch handling +- Maximum batch size enforcement +- Error handling for partial batch failures + +### Integration Testing +- End-to-end batch embedding through Pulsar +- Graph embeddings service batch processing +- Row embeddings service batch processing +- API gateway batch endpoint + +### Performance Testing +- Benchmark single vs batch throughput +- Memory usage under various batch sizes +- Latency distribution analysis + +## Migration Plan + +This is a breaking change release. All phases are implemented together. + +### Phase 1: Schema Changes +- Replace `text: str` with `texts: list[str]` in EmbeddingsRequest +- Change `vectors` type to `list[list[list[float]]]` in EmbeddingsResponse + +### Phase 2: Processor Updates +- Update `on_embeddings` signature in FastEmbed and Ollama processors +- Process full batch in single model call + +### Phase 3: Client Updates +- Update `EmbeddingsClient.embed()` to accept `texts: list[str]` + +### Phase 4: Caller Updates +- Update graph_embeddings to batch entity contexts +- Update row_embeddings to batch index texts +- Update document_embeddings to use new schema +- Update CLI tool + +### Phase 5: API Gateway +- Update embeddings endpoint for new schema + +### Phase 6: Python SDK +- Update `FlowInstance.embeddings()` signature +- Update `SocketFlowInstance.embeddings()` signature + +## Open Questions + +- **Streaming Large Batches**: Should we support streaming results for very large batches (>100 texts)? +- **Provider-Specific Limits**: How should we handle providers with different maximum batch sizes? +- **Partial Failure Handling**: If one text in a batch fails, should we fail the entire batch or return partial results? +- **Document Embeddings Batching**: Should we batch across multiple Chunk messages or keep per-message processing? + +## References + +- [FastEmbed Documentation](https://github.com/qdrant/fastembed) +- [Ollama Embeddings API](https://github.com/ollama/ollama) +- [EmbeddingsService Implementation](trustgraph-base/trustgraph/base/embeddings_service.py) +- [GraphRAG Performance Optimization](graphrag-performance-optimization.md) diff --git a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py index 1c1fb883..ca43bf83 100644 --- a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py +++ b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py @@ -103,12 +103,12 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): mock_text_embedding_class.reset_mock() # Act - result = await processor.on_embeddings("test text") + result = await processor.on_embeddings(["test text"]) # Assert mock_fastembed_instance.embed.assert_called_once_with(["test text"]) assert processor.cached_model_name == "test-model" # Still using default - assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): mock_text_embedding_class.reset_mock() # Act - result = await processor.on_embeddings("test text", model="custom-model") + result = await processor.on_embeddings(["test text"], model="custom-model") # Assert mock_text_embedding_class.assert_called_once_with(model_name="custom-model") @@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): initial_call_count = mock_text_embedding_class.call_count # Act - switch between models - await processor.on_embeddings("text1", model="model-a") + await processor.on_embeddings(["text1"], model="model-a") call_count_after_a = mock_text_embedding_class.call_count - await processor.on_embeddings("text2", model="model-a") # Same, no reload + await processor.on_embeddings(["text2"], model="model-a") # Same, no reload call_count_after_a_repeat = mock_text_embedding_class.call_count - await processor.on_embeddings("text3", model="model-b") # Different, reload + await processor.on_embeddings(["text3"], model="model-b") # Different, reload call_count_after_b = mock_text_embedding_class.call_count - await processor.on_embeddings("text4", model="model-a") # Back to A, reload + await processor.on_embeddings(["text4"], model="model-a") # Back to A, reload call_count_after_a_again = mock_text_embedding_class.call_count # Assert @@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): initial_count = mock_text_embedding_class.call_count # Act - result = await processor.on_embeddings("test text", model=None) + result = await processor.on_embeddings(["test text"], model=None) # Assert # No reload, using cached default diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py index ca0f44bf..80e1de4e 100644 --- a/tests/unit/test_embeddings/test_ollama_dynamic_model.py +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -53,14 +53,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - result = await processor.on_embeddings("test text") + result = await processor.on_embeddings(["test text"]) # Assert mock_ollama_client.embed.assert_called_once_with( model="test-model", - input="test text" + input=["test text"] ) - assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] @patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -79,14 +79,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - result = await processor.on_embeddings("test text", model="custom-model") + result = await processor.on_embeddings(["test text"], model="custom-model") # Assert mock_ollama_client.embed.assert_called_once_with( model="custom-model", - input="test text" + input=["test text"] ) - assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] + assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] @patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - switch between different models - await processor.on_embeddings("text1", model="model-a") - await processor.on_embeddings("text2", model="model-b") - await processor.on_embeddings("text3", model="model-a") - await processor.on_embeddings("text4") # Use default + await processor.on_embeddings(["text1"], model="model-a") + await processor.on_embeddings(["text2"], model="model-b") + await processor.on_embeddings(["text3"], model="model-a") + await processor.on_embeddings(["text4"]) # Use default # Assert calls = mock_ollama_client.embed.call_args_list @@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock()) # Act - result = await processor.on_embeddings("test text", model=None) + result = await processor.on_embeddings(["test text"], model=None) # Assert mock_ollama_client.embed.assert_called_once_with( model="test-model", - input="test text" + input=["test text"] ) @patch('trustgraph.embeddings.ollama.processor.Client') diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py index 47405431..45a22e48 100644 --- a/tests/unit/test_embeddings/test_row_embeddings_processor.py +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): # Mock the flow mock_embeddings_request = AsyncMock() - mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]] + # Return batch of vector sets (one per text) + # 4 unique texts: CUST001, John Doe, CUST002, Jane Smith + mock_embeddings_request.embed.return_value = [ + [[0.1, 0.2, 0.3]], # vectors for text 1 + [[0.2, 0.3, 0.4]], # vectors for text 2 + [[0.3, 0.4, 0.5]], # vectors for text 3 + [[0.4, 0.5, 0.6]], # vectors for text 4 + ] mock_output = AsyncMock() @@ -368,9 +375,12 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): await processor.on_message(mock_msg, MagicMock(), mock_flow) - # Should have called embed for each unique text - # 4 values: CUST001, John Doe, CUST002, Jane Smith - assert mock_embeddings_request.embed.call_count == 4 + # Should have called embed once with all texts in a batch + assert mock_embeddings_request.embed.call_count == 1 + # Verify it was called with a list of texts + call_args = mock_embeddings_request.embed.call_args + assert 'texts' in call_args.kwargs + assert len(call_args.kwargs['texts']) == 4 # Should have sent output mock_output.send.assert_called() diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index c50bf9c4..5142aac4 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -544,30 +544,29 @@ class FlowInstance: input )["response"] - def embeddings(self, text): + def embeddings(self, texts): """ - Generate vector embeddings for text. + Generate vector embeddings for one or more texts. - Converts text into dense vector representations suitable for semantic + Converts texts into dense vector representations suitable for semantic search and similarity comparison. Args: - text: Input text to embed + texts: List of input texts to embed Returns: - list[float]: Vector embedding + list[list[list[float]]]: Vector embeddings, one set per input text Example: ```python flow = api.flow().id("default") - vectors = flow.embeddings("quantum computing") - print(f"Embedding dimension: {len(vectors)}") + vectors = flow.embeddings(["quantum computing"]) + print(f"Embedding dimension: {len(vectors[0][0])}") ``` """ - # The input consists of a text block input = { - "text": text + "texts": texts } return self.request( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index b471b535..e5d5a356 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -712,27 +712,27 @@ class SocketFlowInstance: return self.client._send_request_sync("document-embeddings", self.flow_id, request, False) - def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]: + def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]: """ - Generate vector embeddings for text. + Generate vector embeddings for one or more texts. Args: - text: Input text to embed + texts: List of input texts to embed **kwargs: Additional parameters passed to the service Returns: - dict: Response containing vectors + dict: Response containing vectors (one set per input text) Example: ```python socket = api.socket() flow = socket.flow("default") - result = flow.embeddings("quantum computing") + result = flow.embeddings(["quantum computing"]) vectors = result.get("vectors", []) ``` """ - request = {"text": text} + request = {"texts": texts} request.update(kwargs) return self.client._send_request_sync("embeddings", self.flow_id, request, False) diff --git a/trustgraph-base/trustgraph/base/embeddings_client.py b/trustgraph-base/trustgraph/base/embeddings_client.py index ceb08eb2..faaa192d 100644 --- a/trustgraph-base/trustgraph/base/embeddings_client.py +++ b/trustgraph-base/trustgraph/base/embeddings_client.py @@ -3,11 +3,11 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import EmbeddingsRequest, EmbeddingsResponse class EmbeddingsClient(RequestResponse): - async def embed(self, text, timeout=30): + async def embed(self, texts, timeout=300): resp = await self.request( EmbeddingsRequest( - text = text + texts = texts ), timeout=timeout ) diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py index a1442d41..7ae63521 100644 --- a/trustgraph-base/trustgraph/base/embeddings_service.py +++ b/trustgraph-base/trustgraph/base/embeddings_service.py @@ -65,7 +65,7 @@ class EmbeddingsService(FlowProcessor): # Pass model from request if specified (non-empty), otherwise use default model = flow("model") - vectors = await self.on_embeddings(request.text, model=model) + vectors = await self.on_embeddings(request.texts, model=model) await flow("response").send( EmbeddingsResponse( @@ -94,7 +94,7 @@ class EmbeddingsService(FlowProcessor): type = "embeddings-error", message = str(e), ), - vectors=None, + vectors=[], ), properties={"id": id} ) diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings.py b/trustgraph-base/trustgraph/messaging/translators/embeddings.py index 7e6eff83..454ce733 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings.py @@ -5,15 +5,15 @@ from .base import MessageTranslator class EmbeddingsRequestTranslator(MessageTranslator): """Translator for EmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest: return EmbeddingsRequest( - text=data["text"] + texts=data["texts"] ) - + def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]: return { - "text": obj.text + "texts": obj.texts } diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index 1261158e..a9d19e51 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -29,12 +29,12 @@ class TextCompletionResponse: @dataclass class EmbeddingsRequest: - text: str = "" + texts: list[str] = field(default_factory=list) @dataclass class EmbeddingsResponse: error: Error | None = None - vectors: list[list[float]] = field(default_factory=list) + vectors: list[list[list[float]]] = field(default_factory=list) ############################################################################ diff --git a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py index 71a88bd7..699a85cf 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_embeddings.py +++ b/trustgraph-cli/trustgraph/cli/invoke_embeddings.py @@ -10,7 +10,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def query(url, flow_id, text, token=None): +def query(url, flow_id, texts, token=None): # Create API client api = Api(url=url, token=token) @@ -19,9 +19,14 @@ def query(url, flow_id, text, token=None): try: # Call embeddings service - result = flow.embeddings(text=text) + result = flow.embeddings(texts=texts) vectors = result.get("vectors", []) - print(vectors) + # Print each text's vectors + for i, vecs in enumerate(vectors): + if len(texts) > 1: + print(f"Text {i + 1}: {vecs}") + else: + print(vecs) finally: # Clean up socket connection @@ -53,9 +58,9 @@ def main(): ) parser.add_argument( - 'text', - nargs=1, - help='Text to convert to embedding vector', + 'texts', + nargs='+', + help='Text(s) to convert to embedding vectors', ) args = parser.parse_args() @@ -65,7 +70,7 @@ def main(): query( url=args.url, flow_id=args.flow_id, - text=args.text[0], + texts=args.texts, token=args.token, ) diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index 032e15c4..eb21d418 100755 --- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -62,11 +62,13 @@ class Processor(FlowProcessor): resp = await flow("embeddings-request").request( EmbeddingsRequest( - text = v.chunk + texts=[v.chunk] ) ) - vectors = resp.vectors + # vectors[0] is the vector set for the first (only) text + # vectors[0][0] is the first vector in that set + vectors = resp.vectors[0][0] if resp.vectors else [] embeds = [ ChunkEmbeddings( diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index d1ce93ca..ac2c6f49 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -46,17 +46,22 @@ class Processor(EmbeddingsService): else: logger.debug(f"Using cached model: {model_name}") - async def on_embeddings(self, text, model=None): + async def on_embeddings(self, texts, model=None): + + if not texts: + return [] use_model = model or self.default_model # Reload model if it has changed self._load_model(use_model) - vecs = self.embeddings.embed([text]) + # FastEmbed processes the full batch efficiently + vecs = list(self.embeddings.embed(texts)) + # Return list of vector sets, one per input text return [ - v.tolist() + [v.tolist()] for v in vecs ] diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index fec528bd..c54e719d 100755 --- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -58,23 +58,25 @@ class Processor(FlowProcessor): v = msg.value() logger.info(f"Indexing {v.metadata.id}...") - entities = [] - try: - for entity in v.entities: + # Collect all contexts for batch embedding + contexts = [entity.context for entity in v.entities] - vectors = await flow("embeddings-request").embed( - text = entity.context - ) + # Single batch embedding call + all_vectors = await flow("embeddings-request").embed( + texts=contexts + ) - entities.append( - EntityEmbeddings( - entity=entity.entity, - vectors=vectors, - chunk_id=entity.chunk_id, # Provenance: source chunk - ) + # Pair results with entities + entities = [ + EntityEmbeddings( + entity=entity.entity, + vectors=vectors[0], # First vector from the set + chunk_id=entity.chunk_id, # Provenance: source chunk ) + for entity, vectors in zip(v.entities, all_vectors) + ] # Send in batches to avoid oversized messages for i in range(0, len(entities), self.batch_size): diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index c951252e..c95850e2 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -30,16 +30,24 @@ class Processor(EmbeddingsService): self.client = Client(host=ollama) self.default_model = model - async def on_embeddings(self, text, model=None): + async def on_embeddings(self, texts, model=None): + + if not texts: + return [] use_model = model or self.default_model + # Ollama handles batch input efficiently embeds = self.client.embed( model = use_model, - input = text + input = texts ) - return embeds.embeddings + # Return list of vector sets, one per input text + return [ + [embedding] + for embedding in embeds.embeddings + ] @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index 84c41ff3..c1d04302 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor): embeddings_list = [] try: - for text, (index_name, index_value) in texts_to_embed.items(): - vectors = await flow("embeddings-request").embed(text=text) + # Collect texts and metadata for batch embedding + texts = list(texts_to_embed.keys()) + metadata = list(texts_to_embed.values()) + # Single batch embedding call + all_vectors = await flow("embeddings-request").embed(texts=texts) + + # Pair results with metadata + for text, (index_name, index_value), vectors in zip( + texts, metadata, all_vectors + ): embeddings_list.append( RowIndexEmbedding( index_name=index_name, index_value=index_value, text=text, - vectors=vectors + vectors=vectors[0] # First vector from the set ) )