mirror of https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00

Batch embeddings (#668)

Base Service (trustgraph-base/trustgraph/base/embeddings_service.py):
- Changed on_request to use request.texts

FastEmbed Processor (trustgraph-flow/trustgraph/embeddings/fastembed/processor.py):
- on_embeddings(texts, model=None) now processes the full batch efficiently
- Returns [[v.tolist()] for v in vecs] - a list of vector sets

Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py):
- on_embeddings(texts, model=None) passes the list directly to Ollama
- Returns [[embedding] for embedding in embeds.embeddings]

EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py):
- embed(texts, timeout=300) accepts a list of texts

Tests updated:
- test_fastembed_dynamic_model.py - 4 tests updated for the new interface
- test_ollama_dynamic_model.py - 4 tests updated for the new interface

Updated CLI, SDK and APIs

This commit is contained in:
parent 3bf8a65409
commit 0a2ce47a88

16 changed files with 785 additions and 79 deletions
667 docs/tech-specs/embeddings-batch-processing.md Normal file
@@ -0,0 +1,667 @@
# Embeddings Batch Processing Technical Specification

## Overview

This specification describes optimizations for the embeddings service to support batch processing of multiple texts in a single request. The current implementation processes one text at a time, missing the significant performance benefits that embedding models provide when processing batches.

Key problems with the current design:

1. **Single-Text Processing Inefficiency**: The current implementation wraps single texts in a list, underutilizing FastEmbed's batch capabilities
2. **Request-Per-Text Overhead**: Each text requires a separate Pulsar message round-trip
3. **Model Inference Inefficiency**: Embedding models have fixed per-batch overhead; small batches waste GPU/CPU resources
4. **Serial Processing in Callers**: Key services loop over items and call embeddings one at a time

## Goals

- **Batch API Support**: Enable processing multiple texts in a single request
- **Backward Compatibility**: Maintain support for single-text requests
- **Significant Throughput Improvement**: Target a 5-10x throughput improvement for bulk operations
- **Reduced Latency per Text**: Lower amortized latency when embedding multiple texts
- **Memory Efficiency**: Process batches without excessive memory consumption
- **Provider Agnostic**: Support batching across FastEmbed, Ollama, and other providers
- **Caller Migration**: Update all embedding callers to use the batch API where beneficial
## Background

### Current Implementation - Embeddings Service

The embeddings implementation in `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py` exhibits a significant performance inefficiency:

```python
# fastembed/processor.py line 56
async def on_embeddings(self, text, model=None):
    use_model = model or self.default_model
    self._load_model(use_model)

    vecs = self.embeddings.embed([text])  # Single text wrapped in list

    return [v.tolist() for v in vecs]
```

**Problems:**

1. **Batch Size 1**: FastEmbed's `embed()` method is optimized for batch processing, but we always call it with `[text]` - a batch of size 1

2. **Per-Request Overhead**: Each embedding request incurs:
   - Pulsar message serialization/deserialization
   - Network round-trip latency
   - Model inference startup overhead
   - Python async scheduling overhead

3. **Schema Limitation**: The `EmbeddingsRequest` schema only supports a single text:

   ```python
   @dataclass
   class EmbeddingsRequest:
       text: str = ""  # Single text only
   ```
### Current Callers - Serial Processing

#### 1. API Gateway

**File:** `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`

The gateway accepts single-text embedding requests via HTTP/WebSocket and forwards them to the embeddings service. Currently no batch endpoint exists.

```python
class EmbeddingsRequestor(ServiceRequestor):
    # Handles single EmbeddingsRequest -> EmbeddingsResponse
    request_schema=EmbeddingsRequest,   # Single text only
    response_schema=EmbeddingsResponse,
```

**Impact:** External clients (web apps, scripts) must make N HTTP requests to embed N texts.

#### 2. Document Embeddings Service

**File:** `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`

Processes document chunks one at a time:

```python
async def on_message(self, msg, consumer, flow):
    v = msg.value()

    # Single chunk per request
    resp = await flow("embeddings-request").request(
        EmbeddingsRequest(text=v.chunk)
    )
    vectors = resp.vectors
```

**Impact:** Each document chunk requires a separate embedding call. A document with 100 chunks = 100 embedding requests.

#### 3. Graph Embeddings Service

**File:** `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`

Loops over entities and embeds each one serially:

```python
async def on_message(self, msg, consumer, flow):
    for entity in v.entities:
        # Serial embedding - one entity at a time
        vectors = await flow("embeddings-request").embed(
            text=entity.context
        )
        entities.append(EntityEmbeddings(
            entity=entity.entity,
            vectors=vectors,
            chunk_id=entity.chunk_id,
        ))
```

**Impact:** A message with 50 entities = 50 serial embedding requests. This is a major bottleneck during knowledge graph construction.

#### 4. Row Embeddings Service

**File:** `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`

Loops over unique texts and embeds each one serially:

```python
async def on_message(self, msg, consumer, flow):
    for text, (index_name, index_value) in texts_to_embed.items():
        # Serial embedding - one text at a time
        vectors = await flow("embeddings-request").embed(text=text)

        embeddings_list.append(RowIndexEmbedding(
            index_name=index_name,
            index_value=index_value,
            text=text,
            vectors=vectors
        ))
```

**Impact:** Processing a table with 100 unique indexed values = 100 serial embedding requests.

#### 5. EmbeddingsClient (Base Client)

**File:** `trustgraph-base/trustgraph/base/embeddings_client.py`

The client used by all flow processors only supports single-text embedding:

```python
class EmbeddingsClient(RequestResponse):
    async def embed(self, text, timeout=30):
        resp = await self.request(
            EmbeddingsRequest(text=text),  # Single text
            timeout=timeout
        )
        return resp.vectors
```

**Impact:** All callers using this client are limited to single-text operations.
#### 6. Command-Line Tools

**File:** `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`

The CLI tool accepts a single text argument:

```python
def query(url, flow_id, text, token=None):
    result = flow.embeddings(text=text)  # Single text
    vectors = result.get("vectors", [])
```

**Impact:** Users cannot batch-embed from the command line. Processing a file of texts requires N invocations.

#### 7. Python SDK

The Python SDK provides two client classes for interacting with TrustGraph services. Both only support single-text embedding.

**File:** `trustgraph-base/trustgraph/api/flow.py`

```python
class FlowInstance:
    def embeddings(self, text):
        """Get embeddings for a single text"""
        input = {"text": text}
        return self.request("service/embeddings", input)["vectors"]
```

**File:** `trustgraph-base/trustgraph/api/socket_client.py`

```python
class SocketFlowInstance:
    def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
        """Get embeddings for a single text via WebSocket"""
        request = {"text": text}
        return self.client._send_request_sync(
            "embeddings", self.flow_id, request, False
        )
```

**Impact:** Python developers using the SDK must loop over texts and make N separate API calls. No batch embedding support exists for SDK users.

### Performance Impact

For typical document ingestion (1000 text chunks):

- **Current**: 1000 separate requests, 1000 model inference calls
- **Batched (batch_size=32)**: 32 requests, 32 model inference calls (96.8% reduction)

For graph embedding (a message with 50 entities):

- **Current**: 50 serial await calls, ~5-10 seconds
- **Batched**: 1-2 batch calls, ~0.5-1 second (5-10x improvement)

FastEmbed and similar libraries achieve near-linear throughput scaling with batch size up to hardware limits (typically 32-128 texts per batch).
## Technical Design

### Architecture

The embeddings batch processing optimization requires changes to the following components:

#### 1. **Schema Enhancement**

- Extend `EmbeddingsRequest` to support multiple texts
- Extend `EmbeddingsResponse` to return multiple vector sets
- Maintain backward compatibility with single-text requests

Module: `trustgraph-base/trustgraph/schema/services/llm.py`

#### 2. **Base Service Enhancement**

- Update `EmbeddingsService` to handle batch requests
- Add batch size configuration
- Implement batch-aware request handling

Module: `trustgraph-base/trustgraph/base/embeddings_service.py`

#### 3. **Provider Processor Updates**

- Update the FastEmbed processor to pass the full batch to `embed()`
- Update the Ollama processor to handle batches (if supported)
- Add fallback sequential processing for providers without batch support

Modules:
- `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py`
- `trustgraph-flow/trustgraph/embeddings/ollama/processor.py`

#### 4. **Client Enhancement**

- Add a batch embedding method to `EmbeddingsClient`
- Support both single and batch APIs
- Add automatic batching for large inputs

Module: `trustgraph-base/trustgraph/base/embeddings_client.py`
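The "automatic batching for large inputs" item above is not specified in detail. One minimal client-side sketch is to split an oversized input into provider-sized batches before issuing requests; the `chunked` helper and the reuse of `MAX_BATCH_SIZE` here are illustrative assumptions, not part of the current codebase:

```python
# Illustrative sketch: split a large list of texts into batches no
# larger than the service limit, preserving input order.
MAX_BATCH_SIZE = 128  # mirrors the recommended default later in this spec

def chunked(texts: list[str], batch_size: int = MAX_BATCH_SIZE) -> list[list[str]]:
    """Split texts into batches of at most batch_size items, in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

# A client could then issue one request per batch and concatenate the
# per-text vector sets, keeping the one-result-per-input contract:
#
#   vectors = []
#   for batch in chunked(texts):
#       vectors.extend(await self.embed(texts=batch))
```

Because each batch's results are extended in order, the caller still sees one vector set per input text regardless of how many underlying requests were made.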
#### 5. **Caller Updates - Flow Processors**

- Update `graph_embeddings` to batch entity contexts
- Update `row_embeddings` to batch index texts
- Update `document_embeddings` if message batching is feasible

Modules:
- `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
- `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`

#### 6. **API Gateway Enhancement**

- Add a batch embedding endpoint
- Support an array of texts in the request body

Module: `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`

#### 7. **CLI Tool Enhancement**

- Add support for multiple texts or file input
- Add a batch size parameter

Module: `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`

#### 8. **Python SDK Enhancement**

- Add `embeddings_batch()` method to `FlowInstance`
- Add `embeddings_batch()` method to `SocketFlowInstance`
- Support both single and batch APIs for SDK users

Modules:
- `trustgraph-base/trustgraph/api/flow.py`
- `trustgraph-base/trustgraph/api/socket_client.py`
### Data Models

#### EmbeddingsRequest

```python
@dataclass
class EmbeddingsRequest:
    texts: list[str] = field(default_factory=list)
```

Usage:

- Single text: `EmbeddingsRequest(texts=["hello world"])`
- Batch: `EmbeddingsRequest(texts=["text1", "text2", "text3"])`

#### EmbeddingsResponse

```python
@dataclass
class EmbeddingsResponse:
    error: Error | None = None
    vectors: list[list[list[float]]] = field(default_factory=list)
```

Response structure:

- `vectors[i]` contains the vector set for `texts[i]`
- Each vector set is `list[list[float]]` (models may return multiple vectors per text)
- Example: 3 texts → `vectors` has 3 entries, each containing that text's embeddings
### APIs

#### EmbeddingsClient

```python
class EmbeddingsClient(RequestResponse):

    async def embed(
        self,
        texts: list[str],
        timeout: float = 300,
    ) -> list[list[list[float]]]:
        """
        Embed one or more texts in a single request.

        Args:
            texts: List of texts to embed
            timeout: Timeout for the operation

        Returns:
            List of vector sets, one per input text
        """
        resp = await self.request(
            EmbeddingsRequest(texts=texts),
            timeout=timeout
        )
        if resp.error:
            raise RuntimeError(resp.error.message)
        return resp.vectors
```
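The shape contract of `embed()` is easy to get wrong at call sites, so here is a small runnable illustration. The helper and the literal vector values are illustrative only; the nesting matches the `EmbeddingsResponse` definition above:

```python
# Demonstrates the response-shape contract of EmbeddingsClient.embed():
# one vector set per input text, each set a list of vectors.
def primary_vectors(vector_sets: list[list[list[float]]]) -> list[list[float]]:
    """Extract the first (primary) vector of each text's vector set."""
    return [vs[0] for vs in vector_sets]

# What embed(texts=["alpha", "beta"]) might return (values illustrative):
vector_sets = [
    [[0.1, 0.2, 0.3]],  # vector set for "alpha"
    [[0.4, 0.5, 0.6]],  # vector set for "beta"
]

vecs = primary_vectors(vector_sets)
assert len(vecs) == len(vector_sets)   # one result per input text
assert vecs[0] == [0.1, 0.2, 0.3]
```

Callers that only need one embedding per text (the common case) index `vector_sets[i][0]`, as the flow-processor updates later in this spec do.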
#### API Gateway Embeddings Endpoint

Updated endpoint supporting single or batch embedding:

```
POST /api/v1/embeddings
Content-Type: application/json

{
    "texts": ["text1", "text2", "text3"],
    "flow_id": "default"
}

Response:
{
    "vectors": [
        [[0.1, 0.2, ...]],
        [[0.3, 0.4, ...]],
        [[0.5, 0.6, ...]]
    ]
}
```
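A client talking to this endpoint only needs to serialize and parse the bodies shown above. A minimal sketch (the helper names are ours; the JSON field names come from the endpoint definition above, and the actual HTTP transport is left to the caller):

```python
import json

def build_embeddings_payload(texts: list[str], flow_id: str = "default") -> str:
    """Serialize a batch embeddings request body for the gateway endpoint."""
    return json.dumps({"texts": texts, "flow_id": flow_id})

def parse_embeddings_response(body: str) -> list[list[list[float]]]:
    """Extract the per-text vector sets from a gateway response body."""
    return json.loads(body)["vectors"]

# Transport is up to the caller, e.g. (hostname is hypothetical):
#   resp = http_post("http://gateway:8088/api/v1/embeddings",
#                    body=build_embeddings_payload(["text1", "text2"]))
#   vector_sets = parse_embeddings_response(resp)
```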
### Implementation Details

#### Phase 1: Schema Changes

**EmbeddingsRequest:**
```python
@dataclass
class EmbeddingsRequest:
    texts: list[str] = field(default_factory=list)
```

**EmbeddingsResponse:**
```python
@dataclass
class EmbeddingsResponse:
    error: Error | None = None
    vectors: list[list[list[float]]] = field(default_factory=list)
```

**Updated EmbeddingsService.on_request:**
```python
async def on_request(self, msg, consumer, flow):
    request = msg.value()
    id = msg.properties()["id"]
    model = flow("model")

    vectors = await self.on_embeddings(request.texts, model=model)
    response = EmbeddingsResponse(error=None, vectors=vectors)

    await flow("response").send(response, properties={"id": id})
```

#### Phase 2: FastEmbed Processor Update

**Current (inefficient):**
```python
async def on_embeddings(self, text, model=None):
    use_model = model or self.default_model
    self._load_model(use_model)
    vecs = self.embeddings.embed([text])  # Batch of 1
    return [v.tolist() for v in vecs]
```

**Updated:**
```python
async def on_embeddings(self, texts: list[str], model=None):
    """Embed texts - processes all texts in a single model call"""
    if not texts:
        return []

    use_model = model or self.default_model
    self._load_model(use_model)

    # FastEmbed handles the full batch efficiently
    all_vecs = list(self.embeddings.embed(texts))

    # Return a list of vector sets, one per input text
    return [[v.tolist()] for v in all_vecs]
```

#### Phase 3: Graph Embeddings Service Update

**Current (serial):**
```python
async def on_message(self, msg, consumer, flow):
    entities = []
    for entity in v.entities:
        vectors = await flow("embeddings-request").embed(text=entity.context)
        entities.append(EntityEmbeddings(...))
```

**Updated (batch):**
```python
async def on_message(self, msg, consumer, flow):
    # Collect all contexts
    contexts = [entity.context for entity in v.entities]

    # Single batch embedding call
    all_vectors = await flow("embeddings-request").embed(texts=contexts)

    # Pair results with entities
    entities = [
        EntityEmbeddings(
            entity=entity.entity,
            vectors=vectors[0],  # First vector from the set
            chunk_id=entity.chunk_id,
        )
        for entity, vectors in zip(v.entities, all_vectors)
    ]
```

#### Phase 4: Row Embeddings Service Update

**Current (serial):**
```python
for text, (index_name, index_value) in texts_to_embed.items():
    vectors = await flow("embeddings-request").embed(text=text)
    embeddings_list.append(RowIndexEmbedding(...))
```

**Updated (batch):**
```python
# Collect texts and metadata
texts = list(texts_to_embed.keys())
metadata = list(texts_to_embed.values())

# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(texts=texts)

# Pair results
embeddings_list = [
    RowIndexEmbedding(
        index_name=meta[0],
        index_value=meta[1],
        text=text,
        vectors=vectors[0]  # First vector from the set
    )
    for text, meta, vectors in zip(texts, metadata, all_vectors)
]
```

#### Phase 5: CLI Tool Enhancement

**Updated CLI:**
```python
def main():
    parser = argparse.ArgumentParser(...)

    parser.add_argument(
        'text',
        nargs='*',  # Zero or more texts
        help='Text(s) to convert to embedding vectors',
    )

    parser.add_argument(
        '-f', '--file',
        help='File containing texts (one per line)',
    )

    parser.add_argument(
        '--batch-size',
        type=int,
        default=32,
        help='Batch size for processing (default: 32)',
    )
```

Usage:
```bash
# Single text (existing)
tg-invoke-embeddings "hello world"

# Multiple texts
tg-invoke-embeddings "text one" "text two" "text three"

# From file
tg-invoke-embeddings -f texts.txt --batch-size 64
```
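The argument parser above accepts both positional texts and a `--file` input, which the tool then has to merge into one list before batching. A sketch of that merge step (the `gather_texts` helper is illustrative; the real CLI may structure this differently):

```python
# Illustrative sketch: merge positional texts and file lines into one
# list, ready to be split into --batch-size batches.
def gather_texts(arg_texts, file_path=None):
    """Collect texts from positional args and/or a file (one text per line)."""
    texts = list(arg_texts)
    if file_path is not None:
        with open(file_path) as f:
            # Skip blank lines so padding in the file is ignored
            texts.extend(line.strip() for line in f if line.strip())
    return texts
```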
#### Phase 6: Python SDK Enhancement

**FlowInstance (HTTP client):**

```python
class FlowInstance:

    def embeddings(self, texts: list[str]) -> list[list[list[float]]]:
        """
        Get embeddings for one or more texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of vector sets, one per input text
        """
        input = {"texts": texts}
        return self.request("service/embeddings", input)["vectors"]
```

**SocketFlowInstance (WebSocket client):**

```python
class SocketFlowInstance:

    def embeddings(self, texts: list[str], **kwargs: Any) -> list[list[list[float]]]:
        """
        Get embeddings for one or more texts via WebSocket.

        Args:
            texts: List of texts to embed

        Returns:
            List of vector sets, one per input text
        """
        request = {"texts": texts}
        response = self.client._send_request_sync(
            "embeddings", self.flow_id, request, False
        )
        return response["vectors"]
```

**SDK usage examples:**

```python
# Single text
vectors = flow.embeddings(["hello world"])
print(f"Dimensions: {len(vectors[0][0])}")

# Batch embedding
texts = ["text one", "text two", "text three"]
all_vectors = flow.embeddings(texts)

# Process results
for text, vecs in zip(texts, all_vectors):
    print(f"{text}: {len(vecs[0])} dimensions")
```
## Security Considerations

- **Request Size Limits**: Enforce a maximum batch size to prevent resource exhaustion
- **Timeout Handling**: Scale timeouts appropriately for batch size
- **Memory Limits**: Monitor memory usage for large batches
- **Input Validation**: Validate all texts in a batch before processing

## Performance Considerations

### Expected Improvements

**Throughput:**
- Single-text: ~10-50 texts/second (depending on the model)
- Batch (size 32): ~200-500 texts/second (5-10x improvement)

**Latency per text:**
- Single-text: 50-200ms per text
- Batch (size 32): 5-20ms per text (amortized)

**Service-specific improvements:**

| Service | Current | Batched | Improvement |
|---------|---------|---------|-------------|
| Graph Embeddings (50 entities) | 5-10s | 0.5-1s | 5-10x |
| Row Embeddings (100 texts) | 10-20s | 1-2s | 5-10x |
| Document Ingestion (1000 chunks) | 100-200s | 10-30s | 5-10x |

### Configuration Parameters

```python
# Recommended defaults
DEFAULT_BATCH_SIZE = 32
MAX_BATCH_SIZE = 128
BATCH_TIMEOUT_MULTIPLIER = 2.0
```
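How these three parameters combine is not pinned down by this spec; one plausible policy (a sketch only, with both helper names ours, not an implemented API) is to clamp caller-requested batch sizes and grow the timeout once per extra batch the input spans:

```python
import math

DEFAULT_BATCH_SIZE = 32
MAX_BATCH_SIZE = 128
BATCH_TIMEOUT_MULTIPLIER = 2.0

def clamp_batch_size(requested: int) -> int:
    """Keep a caller-requested batch size within service limits."""
    return max(1, min(requested, MAX_BATCH_SIZE))

def scaled_timeout(num_texts: int, base_timeout: float = 30.0) -> float:
    """One plausible policy: add (multiplier - 1) * base per extra batch."""
    batches = max(1, math.ceil(num_texts / DEFAULT_BATCH_SIZE))
    return base_timeout * (1 + (batches - 1) * (BATCH_TIMEOUT_MULTIPLIER - 1))
```

Under this rule a single batch keeps the base timeout, and a request spanning two default-sized batches doubles it, matching the intent of `BATCH_TIMEOUT_MULTIPLIER`.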
## Testing Strategy

### Unit Testing

- Single text embedding (backward compatibility)
- Empty batch handling
- Maximum batch size enforcement
- Error handling for partial batch failures
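The empty-batch case above could be exercised with a test along these lines. The `EmbeddingsStub` here stands in for the real processor so the check runs without model weights; the contract it encodes (empty input returns `[]`, otherwise one vector set per text) matches the updated FastEmbed processor earlier in this spec:

```python
# Sketch of an empty-batch unit test against the batch contract.
import asyncio

class EmbeddingsStub:
    """Stand-in for a batch-aware processor; vector values are dummies."""
    async def on_embeddings(self, texts, model=None):
        if not texts:
            return []                         # contract: empty in, empty out
        return [[[0.0] * 3] for _ in texts]   # one vector set per text

async def check_empty_batch():
    proc = EmbeddingsStub()
    assert await proc.on_embeddings([]) == []
    result = await proc.on_embeddings(["a", "b"])
    assert len(result) == 2                   # one entry per input text

asyncio.run(check_empty_batch())
```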
### Integration Testing

- End-to-end batch embedding through Pulsar
- Graph embeddings service batch processing
- Row embeddings service batch processing
- API gateway batch endpoint

### Performance Testing

- Benchmark single vs batch throughput
- Memory usage under various batch sizes
- Latency distribution analysis
## Migration Plan

This is a breaking-change release. All phases are implemented together.

### Phase 1: Schema Changes
- Replace `text: str` with `texts: list[str]` in `EmbeddingsRequest`
- Change the `vectors` type to `list[list[list[float]]]` in `EmbeddingsResponse`

### Phase 2: Processor Updates
- Update the `on_embeddings` signature in the FastEmbed and Ollama processors
- Process the full batch in a single model call

### Phase 3: Client Updates
- Update `EmbeddingsClient.embed()` to accept `texts: list[str]`

### Phase 4: Caller Updates
- Update `graph_embeddings` to batch entity contexts
- Update `row_embeddings` to batch index texts
- Update `document_embeddings` to use the new schema
- Update the CLI tool

### Phase 5: API Gateway
- Update the embeddings endpoint for the new schema

### Phase 6: Python SDK
- Update the `FlowInstance.embeddings()` signature
- Update the `SocketFlowInstance.embeddings()` signature

## Open Questions

- **Streaming Large Batches**: Should we support streaming results for very large batches (>100 texts)?
- **Provider-Specific Limits**: How should we handle providers with different maximum batch sizes?
- **Partial Failure Handling**: If one text in a batch fails, should we fail the entire batch or return partial results?
- **Document Embeddings Batching**: Should we batch across multiple Chunk messages or keep per-message processing?

## References

- [FastEmbed Documentation](https://github.com/qdrant/fastembed)
- [Ollama Embeddings API](https://github.com/ollama/ollama)
- [EmbeddingsService Implementation](trustgraph-base/trustgraph/base/embeddings_service.py)
- [GraphRAG Performance Optimization](graphrag-performance-optimization.md)
@@ -103,12 +103,12 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         mock_text_embedding_class.reset_mock()

         # Act
-        result = await processor.on_embeddings("test text")
+        result = await processor.on_embeddings(["test text"])

         # Assert
         mock_fastembed_instance.embed.assert_called_once_with(["test text"])
         assert processor.cached_model_name == "test-model"  # Still using default
-        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]

     @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         mock_text_embedding_class.reset_mock()

         # Act
-        result = await processor.on_embeddings("test text", model="custom-model")
+        result = await processor.on_embeddings(["test text"], model="custom-model")

         # Assert
         mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
@@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         initial_call_count = mock_text_embedding_class.call_count

         # Act - switch between models
-        await processor.on_embeddings("text1", model="model-a")
+        await processor.on_embeddings(["text1"], model="model-a")
         call_count_after_a = mock_text_embedding_class.call_count

-        await processor.on_embeddings("text2", model="model-a")  # Same, no reload
+        await processor.on_embeddings(["text2"], model="model-a")  # Same, no reload
         call_count_after_a_repeat = mock_text_embedding_class.call_count

-        await processor.on_embeddings("text3", model="model-b")  # Different, reload
+        await processor.on_embeddings(["text3"], model="model-b")  # Different, reload
         call_count_after_b = mock_text_embedding_class.call_count

-        await processor.on_embeddings("text4", model="model-a")  # Back to A, reload
+        await processor.on_embeddings(["text4"], model="model-a")  # Back to A, reload
         call_count_after_a_again = mock_text_embedding_class.call_count

         # Assert
@@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
         initial_count = mock_text_embedding_class.call_count

         # Act
-        result = await processor.on_embeddings("test text", model=None)
+        result = await processor.on_embeddings(["test text"], model=None)

         # Assert
         # No reload, using cached default
@@ -53,14 +53,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act
-        result = await processor.on_embeddings("test text")
+        result = await processor.on_embeddings(["test text"])

         # Assert
         mock_ollama_client.embed.assert_called_once_with(
             model="test-model",
-            input="test text"
+            input=["test text"]
         )
-        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]

     @patch('trustgraph.embeddings.ollama.processor.Client')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -79,14 +79,14 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act
-        result = await processor.on_embeddings("test text", model="custom-model")
+        result = await processor.on_embeddings(["test text"], model="custom-model")

         # Assert
         mock_ollama_client.embed.assert_called_once_with(
             model="custom-model",
-            input="test text"
+            input=["test text"]
         )
-        assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]

     @patch('trustgraph.embeddings.ollama.processor.Client')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act - switch between different models
-        await processor.on_embeddings("text1", model="model-a")
-        await processor.on_embeddings("text2", model="model-b")
-        await processor.on_embeddings("text3", model="model-a")
-        await processor.on_embeddings("text4")  # Use default
+        await processor.on_embeddings(["text1"], model="model-a")
+        await processor.on_embeddings(["text2"], model="model-b")
+        await processor.on_embeddings(["text3"], model="model-a")
+        await processor.on_embeddings(["text4"])  # Use default

         # Assert
         calls = mock_ollama_client.embed.call_args_list
@@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
         processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())

         # Act
-        result = await processor.on_embeddings("test text", model=None)
+        result = await processor.on_embeddings(["test text"], model=None)

         # Assert
         mock_ollama_client.embed.assert_called_once_with(
             model="test-model",
-            input="test text"
+            input=["test text"]
         )

     @patch('trustgraph.embeddings.ollama.processor.Client')
@@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):

         # Mock the flow
         mock_embeddings_request = AsyncMock()
-        mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
+        # Return batch of vector sets (one per text)
+        # 4 unique texts: CUST001, John Doe, CUST002, Jane Smith
+        mock_embeddings_request.embed.return_value = [
+            [[0.1, 0.2, 0.3]],  # vectors for text 1
+            [[0.2, 0.3, 0.4]],  # vectors for text 2
+            [[0.3, 0.4, 0.5]],  # vectors for text 3
+            [[0.4, 0.5, 0.6]],  # vectors for text 4
+        ]

         mock_output = AsyncMock()
@@ -368,9 +375,12 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):

         await processor.on_message(mock_msg, MagicMock(), mock_flow)

-        # Should have called embed for each unique text
-        # 4 values: CUST001, John Doe, CUST002, Jane Smith
-        assert mock_embeddings_request.embed.call_count == 4
+        # Should have called embed once with all texts in a batch
+        assert mock_embeddings_request.embed.call_count == 1
+        # Verify it was called with a list of texts
+        call_args = mock_embeddings_request.embed.call_args
+        assert 'texts' in call_args.kwargs
+        assert len(call_args.kwargs['texts']) == 4

         # Should have sent output
         mock_output.send.assert_called()
@@ -544,30 +544,29 @@ class FlowInstance:
             input
         )["response"]

-    def embeddings(self, text):
+    def embeddings(self, texts):
         """
-        Generate vector embeddings for text.
+        Generate vector embeddings for one or more texts.

-        Converts text into dense vector representations suitable for semantic
+        Converts texts into dense vector representations suitable for semantic
         search and similarity comparison.

         Args:
-            text: Input text to embed
+            texts: List of input texts to embed

         Returns:
-            list[float]: Vector embedding
+            list[list[list[float]]]: Vector embeddings, one set per input text

         Example:
             ```python
             flow = api.flow().id("default")
-            vectors = flow.embeddings("quantum computing")
-            print(f"Embedding dimension: {len(vectors)}")
+            vectors = flow.embeddings(["quantum computing"])
+            print(f"Embedding dimension: {len(vectors[0][0])}")
             ```
         """

         # The input consists of a text block
         input = {
-            "text": text
+            "texts": texts
         }

         return self.request(
@@ -712,27 +712,27 @@ class SocketFlowInstance:

        return self.client._send_request_sync("document-embeddings", self.flow_id, request, False)

-    def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
+    def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]:
        """
-        Generate vector embeddings for text.
+        Generate vector embeddings for one or more texts.

        Args:
-            text: Input text to embed
+            texts: List of input texts to embed
            **kwargs: Additional parameters passed to the service

        Returns:
-            dict: Response containing vectors
+            dict: Response containing vectors (one set per input text)

        Example:
            ```python
            socket = api.socket()
            flow = socket.flow("default")

-            result = flow.embeddings("quantum computing")
+            result = flow.embeddings(["quantum computing"])
            vectors = result.get("vectors", [])
            ```
        """
-        request = {"text": text}
+        request = {"texts": texts}
        request.update(kwargs)

        return self.client._send_request_sync("embeddings", self.flow_id, request, False)

@@ -3,11 +3,11 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
 from .. schema import EmbeddingsRequest, EmbeddingsResponse

 class EmbeddingsClient(RequestResponse):
-    async def embed(self, text, timeout=30):
+    async def embed(self, texts, timeout=300):

        resp = await self.request(
            EmbeddingsRequest(
-                text = text
+                texts = texts
            ),
            timeout=timeout
        )

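The client change above alters the shape callers receive: one vector set per input text, rather than a single flat vector. A minimal sketch of how a caller unpacks that shape, using a stub in place of the real Pulsar-backed `EmbeddingsClient` (the stub and its dummy values are illustrative, not part of the codebase):

```python
import asyncio

class StubEmbeddingsClient:
    """Stand-in for EmbeddingsClient; returns the new batched shape."""
    async def embed(self, texts, timeout=300):
        # One vector set per input text; each set holds one vector here
        return [[[0.1] * 3] for _ in texts]

async def main():
    client = StubEmbeddingsClient()
    vectors = await client.embed(texts=["alpha", "beta"])
    assert len(vectors) == 2        # one set per text
    first = vectors[0][0]           # first vector of the first set
    assert first == [0.1, 0.1, 0.1]
    return vectors

vectors = asyncio.run(main())
```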
@@ -65,7 +65,7 @@ class EmbeddingsService(FlowProcessor):

        # Pass model from request if specified (non-empty), otherwise use default
        model = flow("model")
-        vectors = await self.on_embeddings(request.text, model=model)
+        vectors = await self.on_embeddings(request.texts, model=model)

        await flow("response").send(
            EmbeddingsResponse(
@@ -94,7 +94,7 @@ class EmbeddingsService(FlowProcessor):
                    type = "embeddings-error",
                    message = str(e),
                ),
-                vectors=None,
+                vectors=[],
            ),
            properties={"id": id}
        )

@@ -5,15 +5,15 @@ from .base import MessageTranslator

 class EmbeddingsRequestTranslator(MessageTranslator):
    """Translator for EmbeddingsRequest schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest:
        return EmbeddingsRequest(
-            text=data["text"]
+            texts=data["texts"]
        )

    def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]:
        return {
-            "text": obj.text
+            "texts": obj.texts
        }

@@ -29,12 +29,12 @@ class TextCompletionResponse:

 @dataclass
 class EmbeddingsRequest:
-    text: str = ""
+    texts: list[str] = field(default_factory=list)

 @dataclass
 class EmbeddingsResponse:
    error: Error | None = None
-    vectors: list[list[float]] = field(default_factory=list)
+    vectors: list[list[list[float]]] = field(default_factory=list)

 ############################################################################

@@ -10,7 +10,7 @@ from trustgraph.api import Api
 default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
 default_token = os.getenv("TRUSTGRAPH_TOKEN", None)

-def query(url, flow_id, text, token=None):
+def query(url, flow_id, texts, token=None):

    # Create API client
    api = Api(url=url, token=token)

@@ -19,9 +19,14 @@ def query(url, flow_id, text, token=None):

    try:
        # Call embeddings service
-        result = flow.embeddings(text=text)
+        result = flow.embeddings(texts=texts)
        vectors = result.get("vectors", [])
-        print(vectors)
+        # Print each text's vectors
+        for i, vecs in enumerate(vectors):
+            if len(texts) > 1:
+                print(f"Text {i + 1}: {vecs}")
+            else:
+                print(vecs)

    finally:
        # Clean up socket connection

@@ -53,9 +58,9 @@ def main():
    )

    parser.add_argument(
-        'text',
-        nargs=1,
-        help='Text to convert to embedding vector',
+        'texts',
+        nargs='+',
+        help='Text(s) to convert to embedding vectors',
    )

    args = parser.parse_args()

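The switch from `nargs=1` to `nargs='+'` means `args.texts` is always a list with at least one element, whether the user passes one text or many. A standalone sketch of that behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    'texts',
    nargs='+',
    help='Text(s) to convert to embedding vectors',
)

# One positional argument still yields a list
single = parser.parse_args(['quantum computing'])
assert single.texts == ['quantum computing']

# Multiple positionals collect into the same list
multi = parser.parse_args(['alpha', 'beta', 'gamma'])
assert multi.texts == ['alpha', 'beta', 'gamma']
```

With the old `nargs=1`, the caller had to unpack `args.text[0]`; with `nargs='+'` the whole list is passed straight through to the batched API.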
@@ -65,7 +70,7 @@ def main():
    query(
        url=args.url,
        flow_id=args.flow_id,
-        text=args.text[0],
+        texts=args.texts,
        token=args.token,
    )

@@ -62,11 +62,13 @@ class Processor(FlowProcessor):

        resp = await flow("embeddings-request").request(
            EmbeddingsRequest(
-                text = v.chunk
+                texts=[v.chunk]
            )
        )

-        vectors = resp.vectors
+        # vectors[0] is the vector set for the first (only) text
+        # vectors[0][0] is the first vector in that set
+        vectors = resp.vectors[0][0] if resp.vectors else []

        embeds = [
            ChunkEmbeddings(
@@ -46,17 +46,22 @@ class Processor(EmbeddingsService):
        else:
            logger.debug(f"Using cached model: {model_name}")

-    async def on_embeddings(self, text, model=None):
+    async def on_embeddings(self, texts, model=None):
+
+        if not texts:
+            return []

        use_model = model or self.default_model

        # Reload model if it has changed
        self._load_model(use_model)

-        vecs = self.embeddings.embed([text])
+        # FastEmbed processes the full batch efficiently
+        vecs = list(self.embeddings.embed(texts))

+        # Return list of vector sets, one per input text
        return [
-            v.tolist()
+            [v.tolist()]
            for v in vecs
        ]

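FastEmbed's `embed` yields one array per input text, and the hunk above wraps each array in its own single-element vector set. The reshaping can be sketched without the FastEmbed dependency, using a stub object that mimics an array's `.tolist()` (the stub and its dummy values are illustrative only):

```python
class FakeArray:
    """Stand-in for the numpy arrays yielded by FastEmbed's embed()."""
    def __init__(self, values):
        self.values = values
    def tolist(self):
        return list(self.values)

def fake_embed(texts):
    # FastEmbed returns a lazy generator, one array per text
    for i, _ in enumerate(texts):
        yield FakeArray([0.1 * (i + 1)] * 2)

texts = ["alpha", "beta"]
vecs = list(fake_embed(texts))

# Wrap each vector in its own set: one set per input text
result = [[v.tolist()] for v in vecs]

assert len(result) == len(texts)
assert result[0] == [[0.1, 0.1]]
```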
@@ -58,23 +58,25 @@ class Processor(FlowProcessor):
        v = msg.value()
        logger.info(f"Indexing {v.metadata.id}...")

        entities = []

        try:

-            for entity in v.entities:
-
-                vectors = await flow("embeddings-request").embed(
-                    text = entity.context
-                )
-
-                entities.append(
-                    EntityEmbeddings(
-                        entity=entity.entity,
-                        vectors=vectors,
-                        chunk_id=entity.chunk_id,  # Provenance: source chunk
-                    )
-                )
+            # Collect all contexts for batch embedding
+            contexts = [entity.context for entity in v.entities]
+
+            # Single batch embedding call
+            all_vectors = await flow("embeddings-request").embed(
+                texts=contexts
+            )
+
+            # Pair results with entities
+            entities = [
+                EntityEmbeddings(
+                    entity=entity.entity,
+                    vectors=vectors[0],  # First vector from the set
+                    chunk_id=entity.chunk_id,  # Provenance: source chunk
+                )
+                for entity, vectors in zip(v.entities, all_vectors)
+            ]

            # Send in batches to avoid oversized messages
            for i in range(0, len(entities), self.batch_size):

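The pairing step above relies on `zip` preserving the one-to-one order between the entities and the batched vector sets. A standalone sketch of that logic, with simplified stand-ins for the processor's message types:

```python
from dataclasses import dataclass

@dataclass
class Entity:
    """Simplified stand-in for the incoming entity record."""
    entity: str
    context: str
    chunk_id: str

@dataclass
class EntityEmbeddings:
    """Simplified stand-in for the outgoing embedding record."""
    entity: str
    vectors: list
    chunk_id: str

items = [
    Entity("e1", "ctx one", "c1"),
    Entity("e2", "ctx two", "c2"),
]

# Simulated batch response: one vector set per context, in order
all_vectors = [[[0.1, 0.2]], [[0.3, 0.4]]]

# Pair results with entities, taking the first vector of each set
entities = [
    EntityEmbeddings(
        entity=e.entity,
        vectors=vs[0],
        chunk_id=e.chunk_id,
    )
    for e, vs in zip(items, all_vectors)
]

assert entities[0].vectors == [0.1, 0.2]
assert entities[1].chunk_id == "c2"
```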
@@ -30,16 +30,24 @@ class Processor(EmbeddingsService):
        self.client = Client(host=ollama)
        self.default_model = model

-    async def on_embeddings(self, text, model=None):
+    async def on_embeddings(self, texts, model=None):
+
+        if not texts:
+            return []

        use_model = model or self.default_model

+        # Ollama handles batch input efficiently
        embeds = self.client.embed(
            model = use_model,
-            input = text
+            input = texts
        )

-        return embeds.embeddings
+        # Return list of vector sets, one per input text
+        return [
+            [embedding]
+            for embedding in embeds.embeddings
+        ]

    @staticmethod
    def add_args(parser):

@@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
        embeddings_list = []

        try:
-            for text, (index_name, index_value) in texts_to_embed.items():
-                vectors = await flow("embeddings-request").embed(text=text)
+            # Collect texts and metadata for batch embedding
+            texts = list(texts_to_embed.keys())
+            metadata = list(texts_to_embed.values())
+
+            # Single batch embedding call
+            all_vectors = await flow("embeddings-request").embed(texts=texts)
+
+            # Pair results with metadata
+            for text, (index_name, index_value), vectors in zip(
+                texts, metadata, all_vectors
+            ):
                embeddings_list.append(
                    RowIndexEmbedding(
                        index_name=index_name,
                        index_value=index_value,
                        text=text,
-                        vectors=vectors
+                        vectors=vectors[0]  # First vector from the set
                    )
                )