producer base

This commit is contained in:
Cyber MacGeddon 2024-07-17 15:39:16 +01:00
parent 1a506d79fb
commit 40b3553053
2 changed files with 54 additions and 13 deletions

View file

@ -87,7 +87,7 @@ class Consumer(BaseProcessor):
log_level=LogLevel.INFO, log_level=LogLevel.INFO,
input_queue="input", input_queue="input",
subscriber="subscriber", subscriber="subscriber",
request_schema=None, input_schema=None,
): ):
super(Consumer, self).__init__( super(Consumer, self).__init__(
@ -95,12 +95,12 @@ class Consumer(BaseProcessor):
log_level=log_level, log_level=log_level,
) )
if request_schema == None: if input_schema == None:
raise RuntimeError("request_schema must be specified") raise RuntimeError("input_schema must be specified")
self.consumer = self.client.subscribe( self.consumer = self.client.subscribe(
input_queue, subscriber, input_queue, subscriber,
schema=JsonSchema(request_schema), schema=JsonSchema(input_schema),
) )
def run(self): def run(self):
@ -149,8 +149,8 @@ class ConsumerProducer(BaseProcessor):
input_queue="input", input_queue="input",
output_queue="output", output_queue="output",
subscriber="subscriber", subscriber="subscriber",
request_schema=None, input_schema=None,
response_schema=None, output_schema=None,
): ):
super(ConsumerProducer, self).__init__( super(ConsumerProducer, self).__init__(
@ -158,20 +158,20 @@ class ConsumerProducer(BaseProcessor):
log_level=log_level, log_level=log_level,
) )
if request_schema == None: if input_schema == None:
raise RuntimeError("request_schema must be specified") raise RuntimeError("input_schema must be specified")
if response_schema == None: if output_schema == None:
raise RuntimeError("response_schema must be specified") raise RuntimeError("output_schema must be specified")
self.consumer = self.client.subscribe( self.consumer = self.client.subscribe(
input_queue, subscriber, input_queue, subscriber,
schema=JsonSchema(request_schema), schema=JsonSchema(input_schema),
) )
self.producer = self.client.create_producer( self.producer = self.client.create_producer(
topic=output_queue, topic=output_queue,
schema=JsonSchema(response_schema), schema=JsonSchema(output_schema),
) )
def run(self): def run(self):
@ -224,3 +224,45 @@ class ConsumerProducer(BaseProcessor):
default=default_output_queue, default=default_output_queue,
help=f'Output queue (default: {default_output_queue})' help=f'Output queue (default: {default_output_queue})'
) )
class Producer(BaseProcessor):
def __init__(
self,
pulsar_host=None,
log_level=LogLevel.INFO,
output_queue="output",
output_schema=None,
):
super(Producer, self).__init__(
pulsar_host=pulsar_host,
log_level=log_level,
)
if output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
)
def send(self, msg, properties={}):
print(msg)
self.producer.send(msg, properties)
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
BaseProcessor.add_args(parser)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -53,7 +53,6 @@ class Processor(ConsumerProducer):
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
id = msg.properties()["id"] id = msg.properties()["id"]
print(f"Handling prompt {id}...", flush=True) print(f"Handling prompt {id}...", flush=True)