mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 18:06:21 +02:00
Batch embeddings (#668)
Base Service (trustgraph-base/trustgraph/base/embeddings_service.py):
- Changed on_request to use request.texts

FastEmbed Processor (trustgraph-flow/trustgraph/embeddings/fastembed/processor.py):
- on_embeddings(texts, model=None) now processes full batch efficiently
- Returns [[v.tolist()] for v in vecs] - list of vector sets

Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py):
- on_embeddings(texts, model=None) passes list directly to Ollama
- Returns [[embedding] for embedding in embeds.embeddings]

EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py):
- embed(texts, timeout=300) accepts list of texts

Tests Updated:
- test_fastembed_dynamic_model.py - 4 tests updated for new interface
- test_ollama_dynamic_model.py - 4 tests updated for new interface

Updated CLI, SDK and APIs
This commit is contained in:
parent
3bf8a65409
commit
0a2ce47a88
16 changed files with 785 additions and 79 deletions
|
|
@@ -62,11 +62,13 @@ class Processor(FlowProcessor):
|
|||
|
||||
resp = await flow("embeddings-request").request(
|
||||
EmbeddingsRequest(
|
||||
text = v.chunk
|
||||
texts=[v.chunk]
|
||||
)
|
||||
)
|
||||
|
||||
vectors = resp.vectors
|
||||
# vectors[0] is the vector set for the first (only) text
|
||||
# vectors[0][0] is the first vector in that set
|
||||
vectors = resp.vectors[0][0] if resp.vectors else []
|
||||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
|
|
|
|||
|
|
@@ -46,17 +46,22 @@ class Processor(EmbeddingsService):
|
|||
else:
|
||||
logger.debug(f"Using cached model: {model_name}")
|
||||
|
||||
async def on_embeddings(self, text, model=None):
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
|
||||
# Reload model if it has changed
|
||||
self._load_model(use_model)
|
||||
|
||||
vecs = self.embeddings.embed([text])
|
||||
# FastEmbed processes the full batch efficiently
|
||||
vecs = list(self.embeddings.embed(texts))
|
||||
|
||||
# Return list of vector sets, one per input text
|
||||
return [
|
||||
v.tolist()
|
||||
[v.tolist()]
|
||||
for v in vecs
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@@ -58,23 +58,25 @@ class Processor(FlowProcessor):
|
|||
v = msg.value()
|
||||
logger.info(f"Indexing {v.metadata.id}...")
|
||||
|
||||
entities = []
|
||||
|
||||
try:
|
||||
|
||||
for entity in v.entities:
|
||||
# Collect all contexts for batch embedding
|
||||
contexts = [entity.context for entity in v.entities]
|
||||
|
||||
vectors = await flow("embeddings-request").embed(
|
||||
text = entity.context
|
||||
)
|
||||
# Single batch embedding call
|
||||
all_vectors = await flow("embeddings-request").embed(
|
||||
texts=contexts
|
||||
)
|
||||
|
||||
entities.append(
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors,
|
||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||
)
|
||||
# Pair results with entities
|
||||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors[0], # First vector from the set
|
||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||
)
|
||||
for entity, vectors in zip(v.entities, all_vectors)
|
||||
]
|
||||
|
||||
# Send in batches to avoid oversized messages
|
||||
for i in range(0, len(entities), self.batch_size):
|
||||
|
|
|
|||
|
|
@@ -30,16 +30,24 @@ class Processor(EmbeddingsService):
|
|||
self.client = Client(host=ollama)
|
||||
self.default_model = model
|
||||
|
||||
async def on_embeddings(self, text, model=None):
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
|
||||
# Ollama handles batch input efficiently
|
||||
embeds = self.client.embed(
|
||||
model = use_model,
|
||||
input = text
|
||||
input = texts
|
||||
)
|
||||
|
||||
return embeds.embeddings
|
||||
# Return list of vector sets, one per input text
|
||||
return [
|
||||
[embedding]
|
||||
for embedding in embeds.embeddings
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
embeddings_list = []
|
||||
|
||||
try:
|
||||
for text, (index_name, index_value) in texts_to_embed.items():
|
||||
vectors = await flow("embeddings-request").embed(text=text)
|
||||
# Collect texts and metadata for batch embedding
|
||||
texts = list(texts_to_embed.keys())
|
||||
metadata = list(texts_to_embed.values())
|
||||
|
||||
# Single batch embedding call
|
||||
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
||||
|
||||
# Pair results with metadata
|
||||
for text, (index_name, index_value), vectors in zip(
|
||||
texts, metadata, all_vectors
|
||||
):
|
||||
embeddings_list.append(
|
||||
RowIndexEmbedding(
|
||||
index_name=index_name,
|
||||
index_value=index_value,
|
||||
text=text,
|
||||
vectors=vectors
|
||||
vectors=vectors[0] # First vector from the set
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue