From de283f25823180e207f3411e96358e5aa046380b Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Wed, 17 Jul 2024 16:03:49 +0100 Subject: [PATCH] Ported embeddings-hf --- trustgraph/embeddings/hf/hf.py | 156 ++++++------------------ trustgraph/vector/milvus_write/write.py | 9 -- 2 files changed, 34 insertions(+), 131 deletions(-) diff --git a/trustgraph/embeddings/hf/hf.py b/trustgraph/embeddings/hf/hf.py index 96ee8e8b..82273310 100755 --- a/trustgraph/embeddings/hf/hf.py +++ b/trustgraph/embeddings/hf/hf.py @@ -4,29 +4,22 @@ Embeddings service, applies an embeddings model selected from HuggingFace. Input is text, output is embeddings vector. """ -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse from langchain_huggingface import HuggingFaceEmbeddings -import time from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'embeddings' default_output_queue = 'embeddings-response' default_subscriber = 'embeddings-hf' default_model="all-MiniLM-L6-v2" -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -34,132 +27,51 @@ class Processor: model=default_model, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(EmbeddingsRequest), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(EmbeddingsResponse), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=EmbeddingsRequest, + output_schema=EmbeddingsResponse, ) self.embeddings = HuggingFaceEmbeddings(model_name=model) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + # Sender-produced ID + id = msg.properties()["id"] - try: + print(f"Handling input {id}...", flush=True) - v = msg.value() + text = v.text + embeds = self.embeddings.embed_documents([text]) - # Sender-produced ID + print("Send response...", flush=True) + r = EmbeddingsResponse(vectors=embeds) + self.producer.send(r, properties={"id": id}) - id = msg.properties()["id"] + print("Done.", flush=True) - print(f"Handling input {id}...", flush=True) + @staticmethod + def add_args(parser): - text = v.text - embeds = self.embeddings.embed_documents([text]) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=embeds) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-m', '--model', + default="all-MiniLM-L6-v2", + help=f'LLM model (default: all-MiniLM-L6-v2)' + ) def run(): - parser = argparse.ArgumentParser( - prog='llm-ollama-text', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-m', '--model', - default="all-MiniLM-L6-v2", - help=f'LLM model (default: all-MiniLM-L6-v2)' - ) - - args = parser.parse_args() - - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - model=args.model, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) + Processor.start("embeddings-hf", __doc__) diff --git a/trustgraph/vector/milvus_write/write.py b/trustgraph/vector/milvus_write/write.py index be40ae32..cf2a5f76 100755 --- a/trustgraph/vector/milvus_write/write.py +++ b/trustgraph/vector/milvus_write/write.py @@ -3,15 +3,6 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ -import pulsar -from pulsar.schema import JsonSchema -from langchain_community.document_loaders import PyPDFLoader -import tempfile -import base64 -import os -import argparse -import time - from ... schema import VectorsAssociation from ... log_level import LogLevel from ... triple_vectors import TripleVectors