2024-07-10 23:20:06 +01:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
|
|
import pulsar
|
|
|
|
|
import _pulsar
|
|
|
|
|
from pulsar.schema import JsonSchema
|
|
|
|
|
import hashlib
|
|
|
|
|
import uuid
|
|
|
|
|
|
2024-07-23 21:34:03 +01:00
|
|
|
from . schema import TextCompletionRequest, TextCompletionResponse
|
|
|
|
|
from . schema import text_completion_request_queue
|
|
|
|
|
from . schema import text_completion_response_queue
|
|
|
|
|
|
2024-07-10 23:20:06 +01:00
|
|
|
# Module-level aliases for the Pulsar client logger levels. They are read
# from the private ``_pulsar`` extension module; ERROR is the default
# ``log_level`` used by LlmClient.
ERROR, WARN, INFO, DEBUG = (
    _pulsar.LoggerLevel.Error,
    _pulsar.LoggerLevel.Warn,
    _pulsar.LoggerLevel.Info,
    _pulsar.LoggerLevel.Debug,
)
|
|
|
|
|
|
|
|
|
|
class LlmClient:
    """Blocking request/response client for text completion over Pulsar.

    Each request is published to ``text_completion_request_queue`` with a
    unique ``id`` message property. The client then reads
    ``text_completion_response_queue`` until it sees a response carrying
    that same ``id``.
    """

    def __init__(
            self, log_level=ERROR, subscriber=None,
            pulsar_host="pulsar://pulsar:6650",
    ):
        """Connect to Pulsar and create the request producer and response consumer.

        Args:
            log_level: Level for the Pulsar console logger (one of the
                module-level ERROR/WARN/INFO/DEBUG aliases).
            subscriber: Subscription name for the response consumer. When
                None, a random UUID string is used, giving this instance a
                subscription name of its own.
            pulsar_host: Pulsar service URL.
        """
        if subscriber is None:
            subscriber = str(uuid.uuid4())

        self.client = pulsar.Client(
            pulsar_host,
            logger=pulsar.ConsoleLogger(log_level),
        )

        self.producer = self.client.create_producer(
            topic=text_completion_request_queue,
            schema=JsonSchema(TextCompletionRequest),
            chunking_enabled=True,
        )

        self.consumer = self.client.subscribe(
            text_completion_response_queue, subscriber,
            schema=JsonSchema(TextCompletionResponse),
        )

    def request(self, prompt, timeout=500):
        """Send *prompt* and block until the matching response arrives.

        Args:
            prompt: Prompt text placed in the ``TextCompletionRequest``.
            timeout: Overall time budget in seconds for the whole request,
                including time spent skipping responses to other requests.

        Returns:
            The ``response`` field of the matching ``TextCompletionResponse``.

        Raises:
            _pulsar.Timeout: If no matching response arrives within
                *timeout* seconds. This is raised either by ``receive()``
                itself or when the deadline expires.
        """
        import time

        request_id = str(uuid.uuid4())

        self.producer.send(
            TextCompletionRequest(prompt=prompt),
            properties={"id": request_id},
        )

        # The budget covers the whole wait, not each receive() call.
        # Otherwise every unrelated message would restart the clock, and a
        # steady stream of them could block this call forever.
        deadline = time.monotonic() + timeout

        while True:
            remaining_ms = int((deadline - time.monotonic()) * 1000)
            if remaining_ms <= 0:
                raise _pulsar.Timeout(
                    "no response for request %s within %s s"
                    % (request_id, timeout)
                )

            msg = self.consumer.receive(timeout_millis=remaining_ms)

            # .get(): a message without an "id" property is treated as not
            # ours instead of aborting the request with a KeyError.
            if msg.properties().get("id") == request_id:
                response = msg.value().response
                self.consumer.acknowledge(msg)
                return response

            # Response to some other request (or uncorrelated): acknowledge
            # it so it is not redelivered, and keep waiting.
            self.consumer.acknowledge(msg)

    def __del__(self):
        """Close Pulsar resources.

        Each attribute is checked first because ``__init__`` may have failed
        before creating it.
        """
        if hasattr(self, "consumer"):
            # NOTE(review): the consumer is closed but never unsubscribed.
            # With an auto-generated (UUID) subscriber name, the
            # subscription presumably lingers on the broker. Confirm that
            # this is intended.
            self.consumer.close()

        if hasattr(self, "producer"):
            self.producer.flush()
            self.producer.close()

        if hasattr(self, "client"):
            self.client.close()
|
|
|
|
|
|