diff --git a/trustgraph/embeddings/ollama/processor.py b/trustgraph/embeddings/ollama/processor.py index e5ef34a8..5c36f86a 100755 --- a/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph/embeddings/ollama/processor.py @@ -3,31 +3,23 @@ Embeddings service, applies an embeddings model selected from HuggingFace. Input is text, output is embeddings vector. """ - -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse from langchain_community.embeddings import OllamaEmbeddings -import time from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'embeddings' default_output_queue = 'embeddings-response' default_subscriber = 'embeddings-ollama' default_model="mxbai-embed-large" default_ollama = 'http://localhost:11434' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -36,140 +28,59 @@ class Processor: ollama=default_ollama, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(EmbeddingsRequest), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(EmbeddingsResponse), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=EmbeddingsRequest, + output_schema=EmbeddingsResponse, ) self.embeddings = OllamaEmbeddings(base_url=ollama, model=model) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + # Sender-produced ID - try: + id = msg.properties()["id"] - v = msg.value() + print(f"Handling input {id}...", flush=True) - # Sender-produced ID + text = v.text + embeds = self.embeddings.embed_query([text]) - id = msg.properties()["id"] + print("Send response...", flush=True) + r = EmbeddingsResponse(vectors=[embeds]) - print(f"Handling input {id}...", flush=True) + self.producer.send(r, properties={"id": id}) - text = v.text - embeds = self.embeddings.embed_query([text]) + print("Done.", flush=True) - print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=[embeds]) + @staticmethod + def add_args(parser): - self.producer.send(r, properties={"id": id}) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - print("Done.", flush=True) + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'Embeddings model (default: {default_model})' + ) - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-r', '--ollama', + default=default_ollama, + help=f'ollama (default: {default_ollama})' + ) def run(): - parser = argparse.ArgumentParser( - prog='embeddings-ollama', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-m', '--model', - default=default_model, - help=f'Embeddings model (default: {default_model})' - ) - - parser.add_argument( - '-r', '--ollama', - default=default_ollama, - help=f'ollama (default: {default_ollama})' - ) - - args = parser.parse_args() - - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - model=args.model, - ollama=args.ollama, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) + Processor.start('embeddings-ollama', __doc__)