Batch embeddings (#668)

Base Service (trustgraph-base/trustgraph/base/embeddings_service.py):
- Changed on_request to use request.texts

FastEmbed Processor
(trustgraph-flow/trustgraph/embeddings/fastembed/processor.py):
- on_embeddings(texts, model=None) now processes full batch efficiently
- Returns [[v.tolist()] for v in vecs] - list of vector sets

Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py):
- on_embeddings(texts, model=None) passes list directly to Ollama
- Returns [[embedding] for embedding in embeds.embeddings]

EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py):
- embed(texts, timeout=300) accepts list of texts

Tests Updated:
- test_fastembed_dynamic_model.py - 4 tests updated for new interface
- test_ollama_dynamic_model.py - 4 tests updated for new interface

Updated CLI, SDK and APIs
This commit is contained in:
cybermaggedon 2026-03-08 18:36:54 +00:00 committed by GitHub
parent 3bf8a65409
commit 0a2ce47a88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 785 additions and 79 deletions

View file

@@ -62,11 +62,13 @@ class Processor(FlowProcessor):
resp = await flow("embeddings-request").request(
EmbeddingsRequest(
text = v.chunk
texts=[v.chunk]
)
)
vectors = resp.vectors
# vectors[0] is the vector set for the first (only) text
# vectors[0][0] is the first vector in that set
vectors = resp.vectors[0][0] if resp.vectors else []
embeds = [
ChunkEmbeddings(

View file

@@ -46,17 +46,22 @@ class Processor(EmbeddingsService):
else:
logger.debug(f"Using cached model: {model_name}")
async def on_embeddings(self, text, model=None):
async def on_embeddings(self, texts, model=None):
if not texts:
return []
use_model = model or self.default_model
# Reload model if it has changed
self._load_model(use_model)
vecs = self.embeddings.embed([text])
# FastEmbed processes the full batch efficiently
vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text
return [
v.tolist()
[v.tolist()]
for v in vecs
]

View file

@@ -58,23 +58,25 @@ class Processor(FlowProcessor):
v = msg.value()
logger.info(f"Indexing {v.metadata.id}...")
entities = []
try:
for entity in v.entities:
# Collect all contexts for batch embedding
contexts = [entity.context for entity in v.entities]
vectors = await flow("embeddings-request").embed(
text = entity.context
)
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(
texts=contexts
)
entities.append(
EntityEmbeddings(
entity=entity.entity,
vectors=vectors,
chunk_id=entity.chunk_id, # Provenance: source chunk
)
# Pair results with entities
entities = [
EntityEmbeddings(
entity=entity.entity,
vectors=vectors[0], # First vector from the set
chunk_id=entity.chunk_id, # Provenance: source chunk
)
for entity, vectors in zip(v.entities, all_vectors)
]
# Send in batches to avoid oversized messages
for i in range(0, len(entities), self.batch_size):

View file

@@ -30,16 +30,24 @@ class Processor(EmbeddingsService):
self.client = Client(host=ollama)
self.default_model = model
async def on_embeddings(self, text, model=None):
async def on_embeddings(self, texts, model=None):
if not texts:
return []
use_model = model or self.default_model
# Ollama handles batch input efficiently
embeds = self.client.embed(
model = use_model,
input = text
input = texts
)
return embeds.embeddings
# Return list of vector sets, one per input text
return [
[embedding]
for embedding in embeds.embeddings
]
@staticmethod
def add_args(parser):

View file

@@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
embeddings_list = []
try:
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
# Collect texts and metadata for batch embedding
texts = list(texts_to_embed.keys())
metadata = list(texts_to_embed.values())
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results with metadata
for text, (index_name, index_value), vectors in zip(
texts, metadata, all_vectors
):
embeddings_list.append(
RowIndexEmbedding(
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors
vectors=vectors[0] # First vector from the set
)
)