Batch embeddings (#668)

Base Service (trustgraph-base/trustgraph/base/embeddings_service.py):
- Changed on_request to use request.texts

FastEmbed Processor
(trustgraph-flow/trustgraph/embeddings/fastembed/processor.py):
- on_embeddings(texts, model=None) now processes full batch efficiently
- Returns [[v.tolist()] for v in vecs] - list of vector sets

Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py):
- on_embeddings(texts, model=None) passes list directly to Ollama
- Returns [[embedding] for embedding in embeds.embeddings]

EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py):
- embed(texts, timeout=300) accepts list of texts

Tests Updated:
- test_fastembed_dynamic_model.py - 4 tests updated for new interface
- test_ollama_dynamic_model.py - 4 tests updated for new interface

Updated CLI, SDK and APIs
cybermaggedon 2026-03-08 18:36:54 +00:00 committed by GitHub
parent 3bf8a65409
commit 0a2ce47a88
16 changed files with 785 additions and 79 deletions

@@ -0,0 +1,667 @@
# Embeddings Batch Processing Technical Specification
## Overview
This specification describes optimizations to the embeddings service to support batch processing of multiple texts in a single request. The current implementation processes one text at a time, forgoing the significant performance benefits that embedding models provide when processing batches. The existing design has four main inefficiencies:
1. **Single-Text Processing Inefficiency**: Current implementation wraps single texts in a list, underutilizing FastEmbed's batch capabilities
2. **Request-Per-Text Overhead**: Each text requires a separate Pulsar message round-trip
3. **Model Inference Inefficiency**: Embedding models have fixed per-batch overhead; small batches waste GPU/CPU resources
4. **Serial Processing in Callers**: Key services loop over items and call embeddings one at a time
## Goals
- **Batch API Support**: Enable processing multiple texts in a single request
- **Backward Compatibility**: Maintain support for single-text requests
- **Significant Throughput Improvement**: Target a 5-10x throughput gain for bulk operations
- **Reduced Latency per Text**: Lower amortized latency when embedding multiple texts
- **Memory Efficiency**: Process batches without excessive memory consumption
- **Provider Agnostic**: Support batching across FastEmbed, Ollama, and other providers
- **Caller Migration**: Update all embedding callers to use batch API where beneficial
## Background
### Current Implementation - Embeddings Service
The embeddings implementation in `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py` exhibits a significant performance inefficiency:
```python
# fastembed/processor.py line 56
async def on_embeddings(self, text, model=None):
use_model = model or self.default_model
self._load_model(use_model)
vecs = self.embeddings.embed([text]) # Single text wrapped in list
return [v.tolist() for v in vecs]
```
**Problems:**
1. **Batch Size 1**: FastEmbed's `embed()` method is optimized for batch processing, but we always call it with `[text]` - a batch of size 1
2. **Per-Request Overhead**: Each embedding request incurs:
- Pulsar message serialization/deserialization
- Network round-trip latency
- Model inference startup overhead
- Python async scheduling overhead
3. **Schema Limitation**: The `EmbeddingsRequest` schema only supports a single text:
```python
@dataclass
class EmbeddingsRequest:
text: str = "" # Single text only
```
### Current Callers - Serial Processing
#### 1. API Gateway
**File:** `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`
The gateway accepts single-text embedding requests via HTTP/WebSocket and forwards them to the embeddings service. Currently no batch endpoint exists.
```python
class EmbeddingsRequestor(ServiceRequestor):
# Handles single EmbeddingsRequest -> EmbeddingsResponse
request_schema=EmbeddingsRequest, # Single text only
response_schema=EmbeddingsResponse,
```
**Impact:** External clients (web apps, scripts) must make N HTTP requests to embed N texts.
#### 2. Document Embeddings Service
**File:** `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`
Processes document chunks one at a time:
```python
async def on_message(self, msg, consumer, flow):
v = msg.value()
# Single chunk per request
resp = await flow("embeddings-request").request(
EmbeddingsRequest(text=v.chunk)
)
vectors = resp.vectors
```
**Impact:** Each document chunk requires a separate embedding call. A document with 100 chunks = 100 embedding requests.
#### 3. Graph Embeddings Service
**File:** `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
Loops over entities and embeds each one serially:
```python
async def on_message(self, msg, consumer, flow):
for entity in v.entities:
# Serial embedding - one entity at a time
vectors = await flow("embeddings-request").embed(
text=entity.context
)
entities.append(EntityEmbeddings(
entity=entity.entity,
vectors=vectors,
chunk_id=entity.chunk_id,
))
```
**Impact:** A message with 50 entities = 50 serial embedding requests. This is a major bottleneck during knowledge graph construction.
#### 4. Row Embeddings Service
**File:** `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
Loops over unique texts and embeds each one serially:
```python
async def on_message(self, msg, consumer, flow):
for text, (index_name, index_value) in texts_to_embed.items():
# Serial embedding - one text at a time
vectors = await flow("embeddings-request").embed(text=text)
embeddings_list.append(RowIndexEmbedding(
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors
))
```
**Impact:** Processing a table with 100 unique indexed values = 100 serial embedding requests.
#### 5. EmbeddingsClient (Base Client)
**File:** `trustgraph-base/trustgraph/base/embeddings_client.py`
The client used by all flow processors only supports single-text embedding:
```python
class EmbeddingsClient(RequestResponse):
async def embed(self, text, timeout=30):
resp = await self.request(
EmbeddingsRequest(text=text), # Single text
timeout=timeout
)
return resp.vectors
```
**Impact:** All callers using this client are limited to single-text operations.
#### 6. Command-Line Tools
**File:** `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`
CLI tool accepts single text argument:
```python
def query(url, flow_id, text, token=None):
result = flow.embeddings(text=text) # Single text
vectors = result.get("vectors", [])
```
**Impact:** Users cannot batch-embed from command line. Processing a file of texts requires N invocations.
#### 7. Python SDK
The Python SDK provides two client classes for interacting with TrustGraph services. Both only support single-text embedding.
**File:** `trustgraph-base/trustgraph/api/flow.py`
```python
class FlowInstance:
def embeddings(self, text):
"""Get embeddings for a single text"""
input = {"text": text}
return self.request("service/embeddings", input)["vectors"]
```
**File:** `trustgraph-base/trustgraph/api/socket_client.py`
```python
class SocketFlowInstance:
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
"""Get embeddings for a single text via WebSocket"""
request = {"text": text}
return self.client._send_request_sync(
"embeddings", self.flow_id, request, False
)
```
**Impact:** Python developers using the SDK must loop over texts and make N separate API calls. No batch embedding support exists for SDK users.
### Performance Impact
For typical document ingestion (1000 text chunks):
- **Current**: 1000 separate requests, 1000 model inference calls
- **Batched (batch_size=32)**: 32 requests, 32 model inference calls (96.8% reduction)
For graph embedding (message with 50 entities):
- **Current**: 50 serial await calls, ~5-10 seconds
- **Batched**: 1-2 batch calls, ~0.5-1 second (5-10x improvement)
FastEmbed and similar libraries achieve near-linear throughput scaling with batch size up to hardware limits (typically 32-128 texts per batch).
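The request-count arithmetic behind these figures is easy to verify; a minimal sketch (the `plan_batches` helper is illustrative, not part of the codebase):

```python
import math

def plan_batches(n_texts: int, batch_size: int = 32) -> int:
    """Requests needed once texts are grouped into batches (illustrative)."""
    return math.ceil(n_texts / batch_size)

# 1000 chunks at batch size 32: 32 requests instead of 1000
n = plan_batches(1000, 32)
print(n, f"{1 - n / 1000:.1%} fewer requests")  # 32, 96.8% fewer requests
```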
## Technical Design
### Architecture
The embeddings batch processing optimization requires changes to the following components:
#### 1. **Schema Enhancement**
- Extend `EmbeddingsRequest` to support multiple texts
- Extend `EmbeddingsResponse` to return multiple vector sets
- Maintain backward compatibility with single-text requests
Module: `trustgraph-base/trustgraph/schema/services/llm.py`
#### 2. **Base Service Enhancement**
- Update `EmbeddingsService` to handle batch requests
- Add batch size configuration
- Implement batch-aware request handling
Module: `trustgraph-base/trustgraph/base/embeddings_service.py`
#### 3. **Provider Processor Updates**
- Update FastEmbed processor to pass full batch to `embed()`
- Update Ollama processor to handle batches (if supported)
- Add fallback sequential processing for providers without batch support
Modules:
- `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py`
- `trustgraph-flow/trustgraph/embeddings/ollama/processor.py`
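For providers whose SDKs accept only one text per call, the sequential fallback can loop server-side while preserving the batch-shaped response. A minimal sketch; `embed_one` is a stand-in for a provider call, not a real API:

```python
import asyncio

async def embed_one(text: str, model: str) -> list[float]:
    # Stand-in for a provider SDK call that embeds a single text
    return [float(len(text)), 0.0]

async def on_embeddings(texts: list[str],
                        model: str = "default") -> list[list[list[float]]]:
    # Sequential fallback: loop over the batch server-side so callers
    # still receive one vector set per input text, in order
    return [[await embed_one(t, model)] for t in texts]

print(asyncio.run(on_embeddings(["a", "bb"])))  # [[[1.0, 0.0]], [[2.0, 0.0]]]
```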
#### 4. **Client Enhancement**
- Add batch embedding method to `EmbeddingsClient`
- Support both single and batch APIs
- Add automatic batching for large inputs
Module: `trustgraph-base/trustgraph/base/embeddings_client.py`
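Automatic batching in the client amounts to splitting oversized inputs into sub-requests and concatenating results in input order. A sketch under the assumption that the embed callable takes a batch and returns one vector set per text (`embed_all` and `fake_embed` are illustrative names):

```python
import asyncio

def chunks(texts: list[str], size: int) -> list[list[str]]:
    # Split an oversized input into sub-batches, preserving order
    return [texts[i:i + size] for i in range(0, len(texts), size)]

async def embed_all(embed, texts: list[str], max_batch: int = 128):
    # Issue one request per sub-batch, concatenate per-text vector sets
    results: list[list[list[float]]] = []
    for batch in chunks(texts, max_batch):
        results.extend(await embed(batch))
    return results

# Demo with a stub standing in for EmbeddingsClient.embed
async def fake_embed(batch: list[str]) -> list[list[list[float]]]:
    return [[[float(len(t))]] for t in batch]

vectors = asyncio.run(embed_all(fake_embed, ["a", "bb", "ccc"], max_batch=2))
print(vectors)  # [[[1.0]], [[2.0]], [[3.0]]]
```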
#### 5. **Caller Updates - Flow Processors**
- Update `graph_embeddings` to batch entity contexts
- Update `row_embeddings` to batch index texts
- Update `document_embeddings` if message batching is feasible
Modules:
- `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
- `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`
#### 6. **API Gateway Enhancement**
- Add batch embedding endpoint
- Support array of texts in request body
Module: `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`
#### 7. **CLI Tool Enhancement**
- Add support for multiple texts or file input
- Add batch size parameter
Module: `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`
#### 8. **Python SDK Enhancement**
- Add `embeddings_batch()` method to `FlowInstance`
- Add `embeddings_batch()` method to `SocketFlowInstance`
- Support both single and batch APIs for SDK users
Modules:
- `trustgraph-base/trustgraph/api/flow.py`
- `trustgraph-base/trustgraph/api/socket_client.py`
### Data Models
#### EmbeddingsRequest
```python
@dataclass
class EmbeddingsRequest:
texts: list[str] = field(default_factory=list)
```
Usage:
- Single text: `EmbeddingsRequest(texts=["hello world"])`
- Batch: `EmbeddingsRequest(texts=["text1", "text2", "text3"])`
#### EmbeddingsResponse
```python
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
```
Response structure:
- `vectors[i]` contains the vector set for `texts[i]`
- Each vector set is `list[list[float]]` (models may return multiple vectors per text)
- Example: 3 texts → `vectors` has 3 entries, each containing that text's embeddings
### APIs
#### EmbeddingsClient
```python
class EmbeddingsClient(RequestResponse):
async def embed(
self,
texts: list[str],
timeout: float = 300,
) -> list[list[list[float]]]:
"""
Embed one or more texts in a single request.
Args:
texts: List of texts to embed
timeout: Timeout for the operation
Returns:
List of vector sets, one per input text
"""
resp = await self.request(
EmbeddingsRequest(texts=texts),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.vectors
```
#### API Gateway Embeddings Endpoint
Updated endpoint supporting single or batch embedding:
```
POST /api/v1/embeddings
Content-Type: application/json
{
"texts": ["text1", "text2", "text3"],
"flow_id": "default"
}
Response:
{
"vectors": [
[[0.1, 0.2, ...]],
[[0.3, 0.4, ...]],
[[0.5, 0.6, ...]]
]
}
```
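A client call against this endpoint might look as follows. The URL path and payload shape are taken from this spec, so confirm them against the deployed gateway; `build_embeddings_request` is an illustrative helper:

```python
import json
import urllib.request

def build_embeddings_request(base_url: str, texts: list[str],
                             flow_id: str = "default") -> urllib.request.Request:
    # Endpoint path and payload shape follow this spec; verify against
    # the deployed gateway before relying on this sketch
    return urllib.request.Request(
        base_url.rstrip("/") + "/api/v1/embeddings",
        data=json.dumps({"texts": texts, "flow_id": flow_id}).encode(),
        headers={"Content-Type": "application/json"},
    )

req = build_embeddings_request("http://localhost:8088", ["text1", "text2"])
# Send with urllib.request.urlopen(req); the JSON response carries
# "vectors", one vector set per input text.
```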
### Implementation Details
#### Phase 1: Schema Changes
**EmbeddingsRequest:**
```python
@dataclass
class EmbeddingsRequest:
texts: list[str] = field(default_factory=list)
```
**EmbeddingsResponse:**
```python
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
```
**Updated EmbeddingsService.on_request:**
```python
async def on_request(self, msg, consumer, flow):
request = msg.value()
id = msg.properties()["id"]
model = flow("model")
vectors = await self.on_embeddings(request.texts, model=model)
response = EmbeddingsResponse(error=None, vectors=vectors)
await flow("response").send(response, properties={"id": id})
```
#### Phase 2: FastEmbed Processor Update
**Current (Inefficient):**
```python
async def on_embeddings(self, text, model=None):
use_model = model or self.default_model
self._load_model(use_model)
vecs = self.embeddings.embed([text]) # Batch of 1
return [v.tolist() for v in vecs]
```
**Updated:**
```python
async def on_embeddings(self, texts: list[str], model=None):
"""Embed texts - processes all texts in single model call"""
if not texts:
return []
use_model = model or self.default_model
self._load_model(use_model)
# FastEmbed handles the full batch efficiently
all_vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text
return [[v.tolist()] for v in all_vecs]
```
#### Phase 3: Graph Embeddings Service Update
**Current (Serial):**
```python
async def on_message(self, msg, consumer, flow):
entities = []
for entity in v.entities:
vectors = await flow("embeddings-request").embed(text=entity.context)
entities.append(EntityEmbeddings(...))
```
**Updated (Batch):**
```python
async def on_message(self, msg, consumer, flow):
# Collect all contexts
contexts = [entity.context for entity in v.entities]
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(texts=contexts)
# Pair results with entities
entities = [
EntityEmbeddings(
entity=entity.entity,
vectors=vectors[0], # First vector from the set
chunk_id=entity.chunk_id,
)
for entity, vectors in zip(v.entities, all_vectors)
]
```
#### Phase 4: Row Embeddings Service Update
**Current (Serial):**
```python
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
embeddings_list.append(RowIndexEmbedding(...))
```
**Updated (Batch):**
```python
# Collect texts and metadata
texts = list(texts_to_embed.keys())
metadata = list(texts_to_embed.values())
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results
embeddings_list = [
RowIndexEmbedding(
index_name=meta[0],
index_value=meta[1],
text=text,
vectors=vectors[0] # First vector from the set
)
for text, meta, vectors in zip(texts, metadata, all_vectors)
]
```
#### Phase 5: CLI Tool Enhancement
**Updated CLI:**
```python
def main():
parser = argparse.ArgumentParser(...)
parser.add_argument(
'text',
nargs='*', # Zero or more texts
help='Text(s) to convert to embedding vectors',
)
parser.add_argument(
'-f', '--file',
help='File containing texts (one per line)',
)
parser.add_argument(
'--batch-size',
type=int,
default=32,
help='Batch size for processing (default: 32)',
)
```
Usage:
```bash
# Single text (existing)
tg-invoke-embeddings "hello world"
# Multiple texts
tg-invoke-embeddings "text one" "text two" "text three"
# From file
tg-invoke-embeddings -f texts.txt --batch-size 64
```
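The file-input path then reduces to reading one text per line and grouping by `--batch-size`; a sketch of that logic (helper names and the blank-line-skipping semantics are assumptions, not implemented behavior):

```python
import io

def parse_texts(stream) -> list[str]:
    # One text per line; blank lines skipped (assumed -f semantics)
    return [line.strip() for line in stream if line.strip()]

def batches(texts: list[str], size: int) -> list[list[str]]:
    # Group texts by --batch-size, preserving input order
    return [texts[i:i + size] for i in range(0, len(texts), size)]

texts = parse_texts(io.StringIO("text one\n\ntext two\ntext three\n"))
print(batches(texts, 2))  # [['text one', 'text two'], ['text three']]
```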
#### Phase 6: Python SDK Enhancement
**FlowInstance (HTTP client):**
```python
class FlowInstance:
def embeddings(self, texts: list[str]) -> list[list[list[float]]]:
"""
Get embeddings for one or more texts.
Args:
texts: List of texts to embed
Returns:
List of vector sets, one per input text
"""
input = {"texts": texts}
return self.request("service/embeddings", input)["vectors"]
```
**SocketFlowInstance (WebSocket client):**
```python
class SocketFlowInstance:
def embeddings(self, texts: list[str], **kwargs: Any) -> list[list[list[float]]]:
"""
Get embeddings for one or more texts via WebSocket.
Args:
texts: List of texts to embed
Returns:
List of vector sets, one per input text
"""
request = {"texts": texts}
response = self.client._send_request_sync(
"embeddings", self.flow_id, request, False
)
return response["vectors"]
```
**SDK Usage Examples:**
```python
# Single text
vectors = flow.embeddings(["hello world"])
print(f"Dimensions: {len(vectors[0][0])}")
# Batch embedding
texts = ["text one", "text two", "text three"]
all_vectors = flow.embeddings(texts)
# Process results
for text, vecs in zip(texts, all_vectors):
print(f"{text}: {len(vecs[0])} dimensions")
```
## Security Considerations
- **Request Size Limits**: Enforce maximum batch size to prevent resource exhaustion
- **Timeout Handling**: Scale timeouts appropriately for batch size
- **Memory Limits**: Monitor memory usage for large batches
- **Input Validation**: Validate all texts in batch before processing
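A sketch of the input-validation step, assuming the 128-text cap recommended elsewhere in this spec; the exact checks and limits are deployment policy, not part of the wire protocol:

```python
MAX_BATCH_SIZE = 128  # matches the recommended default in this spec

def validate_batch(texts) -> None:
    # Illustrative pre-flight checks performed before embedding
    if not isinstance(texts, list):
        raise ValueError("texts must be a list of strings")
    if not texts:
        raise ValueError("texts must not be empty")
    if len(texts) > MAX_BATCH_SIZE:
        raise ValueError(f"batch too large: {len(texts)} > {MAX_BATCH_SIZE}")
    for i, t in enumerate(texts):
        if not isinstance(t, str) or not t.strip():
            raise ValueError(f"texts[{i}] must be a non-empty string")
```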
## Performance Considerations
### Expected Improvements
**Throughput:**
- Single-text: ~10-50 texts/second (depending on model)
- Batch (size 32): ~200-500 texts/second (5-10x improvement)
**Latency per Text:**
- Single-text: 50-200ms per text
- Batch (size 32): 5-20ms per text (amortized)
**Service-Specific Improvements:**
| Service | Current | Batched | Improvement |
|---------|---------|---------|-------------|
| Graph Embeddings (50 entities) | 5-10s | 0.5-1s | 5-10x |
| Row Embeddings (100 texts) | 10-20s | 1-2s | 5-10x |
| Document Ingestion (1000 chunks) | 100-200s | 10-30s | 5-10x |
### Configuration Parameters
```python
# Recommended defaults
DEFAULT_BATCH_SIZE = 32
MAX_BATCH_SIZE = 128
BATCH_TIMEOUT_MULTIPLIER = 2.0
```
## Testing Strategy
### Unit Testing
- Single text embedding (backward compatibility)
- Empty batch handling
- Maximum batch size enforcement
- Error handling for partial batch failures
### Integration Testing
- End-to-end batch embedding through Pulsar
- Graph embeddings service batch processing
- Row embeddings service batch processing
- API gateway batch endpoint
### Performance Testing
- Benchmark single vs batch throughput
- Memory usage under various batch sizes
- Latency distribution analysis
## Migration Plan
This is a breaking-change release; all phases are implemented and shipped together.
### Phase 1: Schema Changes
- Replace `text: str` with `texts: list[str]` in EmbeddingsRequest
- Change `vectors` type to `list[list[list[float]]]` in EmbeddingsResponse
### Phase 2: Processor Updates
- Update `on_embeddings` signature in FastEmbed and Ollama processors
- Process full batch in single model call
### Phase 3: Client Updates
- Update `EmbeddingsClient.embed()` to accept `texts: list[str]`
### Phase 4: Caller Updates
- Update graph_embeddings to batch entity contexts
- Update row_embeddings to batch index texts
- Update document_embeddings to use new schema
- Update CLI tool
### Phase 5: API Gateway
- Update embeddings endpoint for new schema
### Phase 6: Python SDK
- Update `FlowInstance.embeddings()` signature
- Update `SocketFlowInstance.embeddings()` signature
## Open Questions
- **Streaming Large Batches**: Should we support streaming results for very large batches (>100 texts)?
- **Provider-Specific Limits**: How should we handle providers with different maximum batch sizes?
- **Partial Failure Handling**: If one text in a batch fails, should we fail the entire batch or return partial results?
- **Document Embeddings Batching**: Should we batch across multiple Chunk messages or keep per-message processing?
## References
- [FastEmbed Documentation](https://github.com/qdrant/fastembed)
- [Ollama Embeddings API](https://github.com/ollama/ollama)
- [EmbeddingsService Implementation](trustgraph-base/trustgraph/base/embeddings_service.py)
- [GraphRAG Performance Optimization](graphrag-performance-optimization.md)

**File:** `test_fastembed_dynamic_model.py`

```diff
@@ -103,12 +103,12 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         mock_text_embedding_class.reset_mock()

         # Act
-        result = await processor.on_embeddings("test text")
+        result = await processor.on_embeddings(["test text"])

         # Assert
         mock_fastembed_instance.embed.assert_called_once_with(["test text"])
         assert processor.cached_model_name == "test-model"  # Still using default
-        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]

     @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         mock_text_embedding_class.reset_mock()

         # Act
-        result = await processor.on_embeddings("test text", model="custom-model")
+        result = await processor.on_embeddings(["test text"], model="custom-model")

         # Assert
         mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
@@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         initial_call_count = mock_text_embedding_class.call_count

         # Act - switch between models
-        await processor.on_embeddings("text1", model="model-a")
+        await processor.on_embeddings(["text1"], model="model-a")
         call_count_after_a = mock_text_embedding_class.call_count
-        await processor.on_embeddings("text2", model="model-a")  # Same, no reload
+        await processor.on_embeddings(["text2"], model="model-a")  # Same, no reload
         call_count_after_a_repeat = mock_text_embedding_class.call_count
-        await processor.on_embeddings("text3", model="model-b")  # Different, reload
+        await processor.on_embeddings(["text3"], model="model-b")  # Different, reload
         call_count_after_b = mock_text_embedding_class.call_count
-        await processor.on_embeddings("text4", model="model-a")  # Back to A, reload
+        await processor.on_embeddings(["text4"], model="model-a")  # Back to A, reload
         call_count_after_a_again = mock_text_embedding_class.call_count

         # Assert
@@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         initial_count = mock_text_embedding_class.call_count

         # Act
-        result = await processor.on_embeddings("test text", model=None)
+        result = await processor.on_embeddings(["test text"], model=None)

         # Assert
         # No reload, using cached default
```

**File:** `test_ollama_dynamic_model.py`

```diff
@@ -53,14 +53,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act
-        result = await processor.on_embeddings("test text")
+        result = await processor.on_embeddings(["test text"])

         # Assert
         mock_ollama_client.embed.assert_called_once_with(
             model="test-model",
-            input="test text"
+            input=["test text"]
         )
-        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]

     @patch('trustgraph.embeddings.ollama.processor.Client')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -79,14 +79,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act
-        result = await processor.on_embeddings("test text", model="custom-model")
+        result = await processor.on_embeddings(["test text"], model="custom-model")

         # Assert
         mock_ollama_client.embed.assert_called_once_with(
             model="custom-model",
-            input="test text"
+            input=["test text"]
         )
-        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]

     @patch('trustgraph.embeddings.ollama.processor.Client')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act - switch between different models
-        await processor.on_embeddings("text1", model="model-a")
-        await processor.on_embeddings("text2", model="model-b")
-        await processor.on_embeddings("text3", model="model-a")
-        await processor.on_embeddings("text4")  # Use default
+        await processor.on_embeddings(["text1"], model="model-a")
+        await processor.on_embeddings(["text2"], model="model-b")
+        await processor.on_embeddings(["text3"], model="model-a")
+        await processor.on_embeddings(["text4"])  # Use default

         # Assert
         calls = mock_ollama_client.embed.call_args_list
@@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act
-        result = await processor.on_embeddings("test text", model=None)
+        result = await processor.on_embeddings(["test text"], model=None)

         # Assert
         mock_ollama_client.embed.assert_called_once_with(
             model="test-model",
-            input="test text"
+            input=["test text"]
         )

     @patch('trustgraph.embeddings.ollama.processor.Client')
```

```diff
@@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
         # Mock the flow
         mock_embeddings_request = AsyncMock()
-        mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
+        # Return batch of vector sets (one per text)
+        # 4 unique texts: CUST001, John Doe, CUST002, Jane Smith
+        mock_embeddings_request.embed.return_value = [
+            [[0.1, 0.2, 0.3]],  # vectors for text 1
+            [[0.2, 0.3, 0.4]],  # vectors for text 2
+            [[0.3, 0.4, 0.5]],  # vectors for text 3
+            [[0.4, 0.5, 0.6]],  # vectors for text 4
+        ]

         mock_output = AsyncMock()
@@ -368,9 +375,12 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
         await processor.on_message(mock_msg, MagicMock(), mock_flow)

-        # Should have called embed for each unique text
-        # 4 values: CUST001, John Doe, CUST002, Jane Smith
-        assert mock_embeddings_request.embed.call_count == 4
+        # Should have called embed once with all texts in a batch
+        assert mock_embeddings_request.embed.call_count == 1
+        # Verify it was called with a list of texts
+        call_args = mock_embeddings_request.embed.call_args
+        assert 'texts' in call_args.kwargs
+        assert len(call_args.kwargs['texts']) == 4

         # Should have sent output
         mock_output.send.assert_called()
```

**File:** `trustgraph-base/trustgraph/api/flow.py`

````diff
@@ -544,30 +544,29 @@ class FlowInstance:
             input
         )["response"]

-    def embeddings(self, text):
+    def embeddings(self, texts):
         """
-        Generate vector embeddings for text.
+        Generate vector embeddings for one or more texts.

-        Converts text into dense vector representations suitable for semantic
+        Converts texts into dense vector representations suitable for semantic
         search and similarity comparison.

         Args:
-            text: Input text to embed
+            texts: List of input texts to embed

         Returns:
-            list[float]: Vector embedding
+            list[list[list[float]]]: Vector embeddings, one set per input text

         Example:
            ```python
            flow = api.flow().id("default")
-           vectors = flow.embeddings("quantum computing")
-           print(f"Embedding dimension: {len(vectors)}")
+           vectors = flow.embeddings(["quantum computing"])
+           print(f"Embedding dimension: {len(vectors[0][0])}")
            ```
         """
-        # The input consists of a text block
         input = {
-            "text": text
+            "texts": texts
         }
         return self.request(
````

**File:** `trustgraph-base/trustgraph/api/socket_client.py`

````diff
@@ -712,27 +712,27 @@ class SocketFlowInstance:
         return self.client._send_request_sync("document-embeddings", self.flow_id, request, False)

-    def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
+    def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]:
         """
-        Generate vector embeddings for text.
+        Generate vector embeddings for one or more texts.

         Args:
-            text: Input text to embed
+            texts: List of input texts to embed
             **kwargs: Additional parameters passed to the service

         Returns:
-            dict: Response containing vectors
+            dict: Response containing vectors (one set per input text)

         Example:
            ```python
            socket = api.socket()
            flow = socket.flow("default")
-           result = flow.embeddings("quantum computing")
+           result = flow.embeddings(["quantum computing"])
            vectors = result.get("vectors", [])
            ```
         """
-        request = {"text": text}
+        request = {"texts": texts}
         request.update(kwargs)
         return self.client._send_request_sync("embeddings", self.flow_id, request, False)
````

**File:** `trustgraph-base/trustgraph/base/embeddings_client.py`

```diff
@@ -3,11 +3,11 @@
 from . request_response_spec import RequestResponse, RequestResponseSpec
 from .. schema import EmbeddingsRequest, EmbeddingsResponse

 class EmbeddingsClient(RequestResponse):

-    async def embed(self, text, timeout=30):
+    async def embed(self, texts, timeout=300):
         resp = await self.request(
             EmbeddingsRequest(
-                text = text
+                texts = texts
             ),
             timeout=timeout
         )
```

**File:** `trustgraph-base/trustgraph/base/embeddings_service.py`

```diff
@@ -65,7 +65,7 @@ class EmbeddingsService(FlowProcessor):
         # Pass model from request if specified (non-empty), otherwise use default
         model = flow("model")

-        vectors = await self.on_embeddings(request.text, model=model)
+        vectors = await self.on_embeddings(request.texts, model=model)

         await flow("response").send(
             EmbeddingsResponse(
@@ -94,7 +94,7 @@ class EmbeddingsService(FlowProcessor):
                     type = "embeddings-error",
                     message = str(e),
                 ),
-                vectors=None,
+                vectors=[],
             ),
             properties={"id": id}
         )
```

View file

@@ -8,12 +8,12 @@ class EmbeddingsRequestTranslator(MessageTranslator):
     def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest:
         return EmbeddingsRequest(
-            text=data["text"]
+            texts=data["texts"]
         )

     def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]:
         return {
-            "text": obj.text
+            "texts": obj.texts
         }
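The translator converts between JSON-style dicts and the Pulsar schema object in both directions; renaming the field on both sides keeps the round trip lossless. A sketch with a plain dataclass standing in for the real `EmbeddingsRequest` schema class:

```python
from dataclasses import dataclass, field

@dataclass
class EmbeddingsRequest:
    # Minimal stand-in for the Pulsar schema class
    texts: list[str] = field(default_factory=list)

def to_pulsar(data):
    # Dict -> schema object, as in EmbeddingsRequestTranslator
    return EmbeddingsRequest(texts=data["texts"])

def from_pulsar(obj):
    # Schema object -> dict
    return {"texts": obj.texts}

payload = {"texts": ["alpha", "beta"]}
# Round-trip through the translator preserves the batch
roundtrip = from_pulsar(to_pulsar(payload))
assert roundtrip == payload
```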


@@ -29,12 +29,12 @@ class TextCompletionResponse:
 @dataclass
 class EmbeddingsRequest:
-    text: str = ""
+    texts: list[str] = field(default_factory=list)

 @dataclass
 class EmbeddingsResponse:
     error: Error | None = None
-    vectors: list[list[float]] = field(default_factory=list)
+    vectors: list[list[list[float]]] = field(default_factory=list)

 ############################################################################
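The response type gains a dimension: `vectors[i]` is now the vector set for input text `i`, and `vectors[i][0]` the first (for current models, only) vector in that set. A sketch of the indexing, with a stand-in dataclass mirroring the updated schema:

```python
from dataclasses import dataclass, field

@dataclass
class EmbeddingsResponse:
    # Stand-in mirroring the updated schema: list of vector sets,
    # one set per input text
    vectors: list[list[list[float]]] = field(default_factory=list)

# Two input texts, each producing a set containing one 3-dim vector
resp = EmbeddingsResponse(vectors=[
    [[0.1, 0.2, 0.3]],
    [[0.4, 0.5, 0.6]],
])

# vectors[i] is the vector set for text i; vectors[i][0] is its first vector
second_text_vector = resp.vectors[1][0]
assert second_text_vector == [0.4, 0.5, 0.6]
```

This keeps room for models that return multiple vectors per text, while current processors always emit single-vector sets.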


@@ -10,7 +10,7 @@
 from trustgraph.api import Api

 default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
 default_token = os.getenv("TRUSTGRAPH_TOKEN", None)

-def query(url, flow_id, text, token=None):
+def query(url, flow_id, texts, token=None):

     # Create API client
     api = Api(url=url, token=token)
@@ -19,9 +19,14 @@ def query(url, flow_id, texts, token=None):
     try:
         # Call embeddings service
-        result = flow.embeddings(text=text)
+        result = flow.embeddings(texts=texts)
         vectors = result.get("vectors", [])
-        print(vectors)
+        # Print each text's vectors
+        for i, vecs in enumerate(vectors):
+            if len(texts) > 1:
+                print(f"Text {i + 1}: {vecs}")
+            else:
+                print(vecs)
     finally:
         # Clean up socket connection
@@ -53,9 +58,9 @@ def main():
     )
     parser.add_argument(
-        'text',
-        nargs=1,
-        help='Text to convert to embedding vector',
+        'texts',
+        nargs='+',
+        help='Text(s) to convert to embedding vectors',
     )
     args = parser.parse_args()
@@ -65,7 +70,7 @@ def main():
     query(
         url=args.url,
         flow_id=args.flow_id,
-        text=args.text[0],
+        texts=args.texts,
         token=args.token,
     )
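The CLI's output rule is: label each result when more than one text was supplied, otherwise print the single vector set bare, preserving the old single-text output format. A sketch of that logic, refactored into a pure function (a hypothetical `format_vectors` helper, not part of the CLI) so the behaviour is easy to check:

```python
def format_vectors(texts, vectors):
    # Mirrors the CLI's printing rule: label per-text output only for batches
    lines = []
    for i, vecs in enumerate(vectors):
        if len(texts) > 1:
            lines.append(f"Text {i + 1}: {vecs}")
        else:
            lines.append(str(vecs))
    return lines

# Single text: vector set printed unlabelled, as before the change
single = format_vectors(["one"], [[[0.1]]])
assert single == ["[[0.1]]"]

# Batch: each text's vector set gets a 1-based label
batch = format_vectors(["a", "b"], [[[0.1]], [[0.2]]])
assert batch == ["Text 1: [[0.1]]", "Text 2: [[0.2]]"]
```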


@@ -62,11 +62,13 @@ class Processor(FlowProcessor):
         resp = await flow("embeddings-request").request(
             EmbeddingsRequest(
-                text = v.chunk
+                texts=[v.chunk]
             )
         )

-        vectors = resp.vectors
+        # vectors[0] is the vector set for the first (only) text
+        # vectors[0][0] is the first vector in that set
+        vectors = resp.vectors[0][0] if resp.vectors else []

         embeds = [
             ChunkEmbeddings(
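Callers that embed a single chunk now wrap it in a one-element list and unwrap the first vector set from the response. A sketch of that unwrapping (the `unwrap` helper is hypothetical, for illustration):

```python
def unwrap(resp_vectors):
    # First vector set (one per text), then its first vector;
    # an empty response degrades to an empty vector
    return resp_vectors[0][0] if resp_vectors else []

# Single-text request: one set containing one 2-dim vector
vec = unwrap([[[0.1, 0.2]]])
assert vec == [0.1, 0.2]

# No vectors returned (e.g. error path sent vectors=[])
empty = unwrap([])
assert empty == []
```

The guard matters because the service's error path now returns `vectors=[]` rather than `None`.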


@@ -46,17 +46,22 @@ class Processor(EmbeddingsService):
         else:
             logger.debug(f"Using cached model: {model_name}")

-    async def on_embeddings(self, text, model=None):
+    async def on_embeddings(self, texts, model=None):
+
+        if not texts:
+            return []

         use_model = model or self.default_model

         # Reload model if it has changed
         self._load_model(use_model)

-        vecs = self.embeddings.embed([text])
+        # FastEmbed processes the full batch efficiently
+        vecs = list(self.embeddings.embed(texts))

+        # Return list of vector sets, one per input text
         return [
-            v.tolist()
+            [v.tolist()]
             for v in vecs
         ]
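FastEmbed's `embed()` yields one numpy array per input text; the processor materializes the generator and wraps each array's list form in its own vector set. A sketch of that shape conversion with plain-Python stand-ins for the numpy arrays and the FastEmbed model (both hypothetical; the real arrays come from `self.embeddings.embed(texts)`):

```python
class FakeArray:
    # Stand-in for the numpy arrays FastEmbed yields
    def __init__(self, vals):
        self.vals = vals
    def tolist(self):
        return list(self.vals)

def fake_embed(texts):
    # Stand-in for the FastEmbed model's batch generator
    for t in texts:
        yield FakeArray([float(len(t))])

def on_embeddings(texts):
    if not texts:
        return []
    # Materialize the generator, then wrap each vector in its own set
    vecs = list(fake_embed(texts))
    return [[v.tolist()] for v in vecs]

assert on_embeddings([]) == []
out = on_embeddings(["ab", "c"])
assert out == [[[2.0]], [[1.0]]]
```

The empty-input guard avoids loading a model for a no-op request.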


@@ -58,23 +58,25 @@ class Processor(FlowProcessor):
         v = msg.value()

         logger.info(f"Indexing {v.metadata.id}...")

-        entities = []
-
         try:
-            for entity in v.entities:
-
-                vectors = await flow("embeddings-request").embed(
-                    text = entity.context
-                )
-
-                entities.append(
-                    EntityEmbeddings(
-                        entity=entity.entity,
-                        vectors=vectors,
-                        chunk_id=entity.chunk_id,  # Provenance: source chunk
-                    )
-                )
+            # Collect all contexts for batch embedding
+            contexts = [entity.context for entity in v.entities]
+
+            # Single batch embedding call
+            all_vectors = await flow("embeddings-request").embed(
+                texts=contexts
+            )
+
+            # Pair results with entities
+            entities = [
+                EntityEmbeddings(
+                    entity=entity.entity,
+                    vectors=vectors[0],  # First vector from the set
+                    chunk_id=entity.chunk_id,  # Provenance: source chunk
+                )
+                for entity, vectors in zip(v.entities, all_vectors)
+            ]

             # Send in batches to avoid oversized messages
             for i in range(0, len(entities), self.batch_size):

@@ -30,16 +30,24 @@ class Processor(EmbeddingsService):
         self.client = Client(host=ollama)
         self.default_model = model

-    async def on_embeddings(self, text, model=None):
+    async def on_embeddings(self, texts, model=None):
+
+        if not texts:
+            return []

         use_model = model or self.default_model

+        # Ollama handles batch input efficiently
         embeds = self.client.embed(
             model = use_model,
-            input = text
+            input = texts
         )

-        return embeds.embeddings
+        # Return list of vector sets, one per input text
+        return [
+            [embedding]
+            for embedding in embeds.embeddings
+        ]

     @staticmethod
     def add_args(parser):


@@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
         embeddings_list = []

         try:
-            for text, (index_name, index_value) in texts_to_embed.items():
-                vectors = await flow("embeddings-request").embed(text=text)
+            # Collect texts and metadata for batch embedding
+            texts = list(texts_to_embed.keys())
+            metadata = list(texts_to_embed.values())

+            # Single batch embedding call
+            all_vectors = await flow("embeddings-request").embed(texts=texts)

+            # Pair results with metadata
+            for text, (index_name, index_value), vectors in zip(
+                texts, metadata, all_vectors
+            ):
                 embeddings_list.append(
                     RowIndexEmbedding(
                         index_name=index_name,
                         index_value=index_value,
                         text=text,
-                        vectors=vectors[0]  # First vector from the set
+                        vectors=vectors[0]  # First vector from the set
                     )
                 )
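This three-way `zip` is safe because Python dicts preserve insertion order, so `keys()` and `values()` stay aligned with each other and, by the ordering guarantee of the batch API, with the returned vector sets. A sketch of the pattern with plain dicts standing in for `RowIndexEmbedding`:

```python
# Texts mapped to their (index_name, index_value) metadata
texts_to_embed = {
    "alice": ("name", "row-1"),
    "bob": ("name", "row-2"),
}

texts = list(texts_to_embed.keys())
metadata = list(texts_to_embed.values())

# Stand-in batch result: one vector set per text, in input order
all_vectors = [[[0.1]], [[0.2]]]

# Three-way positional pairing, as in the processor
rows = [
    {"index_name": n, "index_value": v, "text": t, "vectors": vectors[0]}
    for t, (n, v), vectors in zip(texts, metadata, all_vectors)
]

assert rows[0]["text"] == "alice"
assert rows[1] == {
    "index_name": "name", "index_value": "row-2",
    "text": "bob", "vectors": [0.2],
}
```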