trustgraph/trustgraph-cli/trustgraph/cli/invoke_embeddings.py
cybermaggedon 0a2ce47a88
Batch embeddings (#668)
Base Service (trustgraph-base/trustgraph/base/embeddings_service.py):
- Changed on_request to use request.texts

FastEmbed Processor
(trustgraph-flow/trustgraph/embeddings/fastembed/processor.py):
- on_embeddings(texts, model=None) now processes full batch efficiently
- Returns [[v.tolist()] for v in vecs] - list of vector sets

Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py):
- on_embeddings(texts, model=None) passes list directly to Ollama
- Returns [[embedding] for embedding in embeds.embeddings]
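
The nested return shape used by both processors (one vector set per input text, each set holding a single vector) can be sketched as follows. This is only an illustration and not the processors' code: `fake_encode` is a hypothetical stand-in for the real model call, and the vector values are made up.

```python
def fake_encode(texts):
    # Hypothetical stand-in for a model's batch call: one flat vector per text.
    return [[float(len(t)), float(t.count(" "))] for t in texts]


def on_embeddings(texts, model=None):
    # Encode the whole batch in one call, then wrap each vector in its own
    # list so every input text maps to a "vector set".
    vecs = fake_encode(texts)
    return [[v] for v in vecs]


vector_sets = on_embeddings(["hello world", "batch"])
print(vector_sets)  # [[[11.0, 1.0]], [[5.0, 0.0]]]
```

The outer list is indexed by input text, so `vector_sets[i]` holds the vectors for `texts[i]`.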

EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py):
- embed(texts, timeout=300) accepts list of texts

Tests Updated:
- test_fastembed_dynamic_model.py - 4 tests updated for new interface
- test_ollama_dynamic_model.py - 4 tests updated for new interface

Updated CLI, SDK and APIs
2026-03-08 18:36:54 +00:00

82 lines
1.8 KiB
Python

"""
Invokes the embeddings service to convert one or more texts to vector
embeddings. Prints the vectors returned for each text.
"""
import argparse
import os

from trustgraph.api import Api

default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)


def query(url, flow_id, texts, token=None):

    # Create API client
    api = Api(url=url, token=token)
    socket = api.socket()
    flow = socket.flow(flow_id)

    try:
        # Call embeddings service
        result = flow.embeddings(texts=texts)
        vectors = result.get("vectors", [])

        # Print each text's vectors
        for i, vecs in enumerate(vectors):
            if len(texts) > 1:
                print(f"Text {i + 1}: {vecs}")
            else:
                print(vecs)

    finally:
        # Clean up socket connection
        socket.close()


def main():

    parser = argparse.ArgumentParser(
        prog='tg-invoke-embeddings',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'API URL (default: {default_url})',
    )

    parser.add_argument(
        '-t', '--token',
        default=default_token,
        help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
    )

    parser.add_argument(
        '-f', '--flow-id',
        default="default",
        help='Flow ID (default: default)'
    )

    parser.add_argument(
        'texts',
        nargs='+',
        help='Text(s) to convert to embedding vectors',
    )

    args = parser.parse_args()

    try:
        query(
            url=args.url,
            flow_id=args.flow_id,
            texts=args.texts,
            token=args.token,
        )
    except Exception as e:
        print("Exception:", e, flush=True)


if __name__ == "__main__":
    main()
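
The output formatting in `query` can be exercised without a running service. The sketch below mirrors the print loop on a made-up response dictionary. `format_vectors` is a hypothetical helper that is not part of the CLI, and the assumption that `"vectors"` holds one entry per input text is inferred from the loop above.

```python
def format_vectors(texts, result):
    # Mirror the CLI's print loop, but return the lines instead of printing.
    vectors = result.get("vectors", [])
    lines = []
    for i, vecs in enumerate(vectors):
        if len(texts) > 1:
            # Several inputs: label each line with its 1-based position.
            lines.append(f"Text {i + 1}: {vecs}")
        else:
            # Single input: emit the vector set on its own.
            lines.append(str(vecs))
    return lines


# Made-up response: one vector set per input text.
sample = {"vectors": [[[0.1, 0.2]], [[0.3, 0.4]]]}
for line in format_vectors(["first", "second"], sample):
    print(line)
```

With the real CLI this corresponds to an invocation such as `tg-invoke-embeddings "first" "second"`, which prints one `Text N:` line per argument. A single argument prints the vector set without a label.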