From 40b355305393514ff3f01de8c5bc29f1fd53bdd2 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Wed, 17 Jul 2024 15:39:16 +0100 Subject: [PATCH] producer base --- trustgraph/base/processor.py | 66 +++++++++++++++++++++++++------ trustgraph/llm/ollama_text/llm.py | 1 - 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/trustgraph/base/processor.py b/trustgraph/base/processor.py index 4e0ba6f4..90cbb132 100644 --- a/trustgraph/base/processor.py +++ b/trustgraph/base/processor.py @@ -87,7 +87,7 @@ class Consumer(BaseProcessor): log_level=LogLevel.INFO, input_queue="input", subscriber="subscriber", - request_schema=None, + input_schema=None, ): super(Consumer, self).__init__( @@ -95,12 +95,12 @@ class Consumer(BaseProcessor): log_level=log_level, ) - if request_schema == None: - raise RuntimeError("request_schema must be specified") + if input_schema == None: + raise RuntimeError("input_schema must be specified") self.consumer = self.client.subscribe( input_queue, subscriber, - schema=JsonSchema(request_schema), + schema=JsonSchema(input_schema), ) def run(self): @@ -149,8 +149,8 @@ class ConsumerProducer(BaseProcessor): input_queue="input", output_queue="output", subscriber="subscriber", - request_schema=None, - response_schema=None, + input_schema=None, + output_schema=None, ): super(ConsumerProducer, self).__init__( @@ -158,20 +158,20 @@ class ConsumerProducer(BaseProcessor): log_level=log_level, ) - if request_schema == None: - raise RuntimeError("request_schema must be specified") + if input_schema == None: + raise RuntimeError("input_schema must be specified") - if response_schema == None: - raise RuntimeError("response_schema must be specified") + if output_schema == None: + raise RuntimeError("output_schema must be specified") self.consumer = self.client.subscribe( input_queue, subscriber, - schema=JsonSchema(request_schema), + schema=JsonSchema(input_schema), ) self.producer = self.client.create_producer( topic=output_queue, - schema=JsonSchema(response_schema), + schema=JsonSchema(output_schema), ) def run(self): @@ -224,3 +224,45 @@ class ConsumerProducer(BaseProcessor): default=default_output_queue, help=f'Output queue (default: {default_output_queue})' ) + +class Producer(BaseProcessor): + + def __init__( + self, + pulsar_host=None, + log_level=LogLevel.INFO, + output_queue="output", + output_schema=None, + ): + + super(Producer, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + ) + + if output_schema == None: + raise RuntimeError("output_schema must be specified") + + self.producer = self.client.create_producer( + topic=output_queue, + schema=JsonSchema(output_schema), + ) + + def send(self, msg, properties={}): + + print(msg) + self.producer.send(msg, properties) + + @staticmethod + def add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ): + + BaseProcessor.add_args(parser) + + parser.add_argument( + '-o', '--output-queue', + default=default_output_queue, + help=f'Output queue (default: {default_output_queue})' + ) diff --git a/trustgraph/llm/ollama_text/llm.py b/trustgraph/llm/ollama_text/llm.py index 406d87df..b8669876 100755 --- a/trustgraph/llm/ollama_text/llm.py +++ b/trustgraph/llm/ollama_text/llm.py @@ -53,7 +53,6 @@ class Processor(ConsumerProducer): v = msg.value() # Sender-produced ID - id = msg.properties()["id"] print(f"Handling prompt {id}...", flush=True)