From 1a506d79fb56ebc6a745be23596d0159414bf7ae Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Wed, 17 Jul 2024 15:33:50 +0100 Subject: [PATCH] Added producer/consumer --- trustgraph/base/processor.py | 84 +++++++++++++++++++++++++++++++ trustgraph/llm/ollama_text/llm.py | 28 +++++------ 2 files changed, 96 insertions(+), 16 deletions(-) diff --git a/trustgraph/base/processor.py b/trustgraph/base/processor.py index be3403ee..4e0ba6f4 100644 --- a/trustgraph/base/processor.py +++ b/trustgraph/base/processor.py @@ -140,3 +140,87 @@ class Consumer(BaseProcessor): help=f'Queue subscriber name (default: {default_subscriber})' ) +class ConsumerProducer(BaseProcessor): + + def __init__( + self, + pulsar_host=None, + log_level=LogLevel.INFO, + input_queue="input", + output_queue="output", + subscriber="subscriber", + request_schema=None, + response_schema=None, + ): + + super(ConsumerProducer, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + ) + + if request_schema == None: + raise RuntimeError("request_schema must be specified") + + if response_schema == None: + raise RuntimeError("response_schema must be specified") + + self.consumer = self.client.subscribe( + input_queue, subscriber, + schema=JsonSchema(request_schema), + ) + + self.producer = self.client.create_producer( + topic=output_queue, + schema=JsonSchema(response_schema), + ) + + def run(self): + + while True: + + msg = self.consumer.receive() + + try: + + resp = self.handle(msg) + + # Acknowledge successful processing of the message + self.consumer.acknowledge(msg) + + except Exception as e: + + print("Exception:", e, flush=True) + + # Message failed to be processed + self.consumer.negative_acknowledge(msg) + + def send(self, msg, properties={}): + + print(msg) + self.producer.send(msg, properties) + + @staticmethod + def add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ): + + BaseProcessor.add_args(parser) + + parser.add_argument( + '-i', '--input-queue', + default=default_input_queue, + help=f'Input queue (default: {default_input_queue})' + ) + + parser.add_argument( + '-s', '--subscriber', + default=default_subscriber, + help=f'Queue subscriber name (default: {default_subscriber})' + ) + + parser.add_argument( + '-o', '--output-queue', + default=default_output_queue, + help=f'Output queue (default: {default_output_queue})' + ) diff --git a/trustgraph/llm/ollama_text/llm.py b/trustgraph/llm/ollama_text/llm.py index 2646b736..406d87df 100755 --- a/trustgraph/llm/ollama_text/llm.py +++ b/trustgraph/llm/ollama_text/llm.py @@ -15,7 +15,7 @@ import time from ... schema import TextCompletionRequest, TextCompletionResponse from ... log_level import LogLevel -from ... base import Consumer +from ... base import ConsumerProducer default_input_queue = 'llm-complete-text' default_output_queue = 'llm-complete-text-response' @@ -23,7 +23,7 @@ default_subscriber = 'llm-ollama-text' default_model = 'gemma2' default_ollama = 'http://localhost:11434' -class Processor(Consumer): +class Processor(ConsumerProducer): def __init__( self, @@ -39,14 +39,11 @@ class Processor(Consumer): super(Processor, self).__init__( pulsar_host=pulsar_host, log_level=log_level, - input_queue=default_input_queue, - subscriber=default_subscriber, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, request_schema=TextCompletionRequest, - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(TextCompletionResponse), + response_schema=TextCompletionResponse, ) self.llm = Ollama(base_url=ollama, model=model) @@ -65,20 +62,19 @@ class Processor(Consumer): response = self.llm.invoke(prompt) print("Send response...", flush=True) + r = TextCompletionResponse(response=response) - self.producer.send(r, properties={"id": id}) + + self.send(r, properties={"id": id}) print("Done.", flush=True) @staticmethod def add_args(parser): - Consumer.add_args(parser, default_input_queue, default_subscriber) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, ) parser.add_argument(