mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Batch embeddings (#668)
Base Service (trustgraph-base/trustgraph/base/embeddings_service.py): - Changed on_request to use request.texts FastEmbed Processor (trustgraph-flow/trustgraph/embeddings/fastembed/processor.py): - on_embeddings(texts, model=None) now processes full batch efficiently - Returns [[v.tolist()] for v in vecs] - list of vector sets Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py): - on_embeddings(texts, model=None) passes list directly to Ollama - Returns [[embedding] for embedding in embeds.embeddings] EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py): - embed(texts, timeout=300) accepts list of texts Tests Updated: - test_fastembed_dynamic_model.py - 4 tests updated for new interface - test_ollama_dynamic_model.py - 4 tests updated for new interface Updated CLI, SDK and APIs
This commit is contained in:
parent
3bf8a65409
commit
0a2ce47a88
16 changed files with 785 additions and 79 deletions
|
|
@ -10,7 +10,7 @@ from trustgraph.api import Api
|
|||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
def query(url, flow_id, text, token=None):
|
||||
def query(url, flow_id, texts, token=None):
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -19,9 +19,14 @@ def query(url, flow_id, text, token=None):
|
|||
|
||||
try:
|
||||
# Call embeddings service
|
||||
result = flow.embeddings(text=text)
|
||||
result = flow.embeddings(texts=texts)
|
||||
vectors = result.get("vectors", [])
|
||||
print(vectors)
|
||||
# Print each text's vectors
|
||||
for i, vecs in enumerate(vectors):
|
||||
if len(texts) > 1:
|
||||
print(f"Text {i + 1}: {vecs}")
|
||||
else:
|
||||
print(vecs)
|
||||
|
||||
finally:
|
||||
# Clean up socket connection
|
||||
|
|
@ -53,9 +58,9 @@ def main():
|
|||
)
|
||||
|
||||
parser.add_argument(
|
||||
'text',
|
||||
nargs=1,
|
||||
help='Text to convert to embedding vector',
|
||||
'texts',
|
||||
nargs='+',
|
||||
help='Text(s) to convert to embedding vectors',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
@ -65,7 +70,7 @@ def main():
|
|||
query(
|
||||
url=args.url,
|
||||
flow_id=args.flow_id,
|
||||
text=args.text[0],
|
||||
texts=args.texts,
|
||||
token=args.token,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue