trustgraph/trustgraph/llm_client.py

81 lines
2 KiB
Python
Raw Normal View History

2024-07-10 23:20:06 +01:00
#!/usr/bin/env python3
import hashlib
import time
import uuid

import pulsar
import _pulsar
from pulsar.schema import JsonSchema

from . schema import TextCompletionRequest, TextCompletionResponse
from . schema import text_completion_request_queue
from . schema import text_completion_response_queue
2024-07-10 23:20:06 +01:00
# Short aliases for the Pulsar console-logger levels, usable as the
# ``log_level`` argument of LlmClient.
# NOTE: these are taken from the private ``_pulsar`` extension module, which
# is presumably why this section was originally flagged as "Ugly".
ERROR, WARN, INFO, DEBUG = (
    _pulsar.LoggerLevel.Error,
    _pulsar.LoggerLevel.Warn,
    _pulsar.LoggerLevel.Info,
    _pulsar.LoggerLevel.Debug,
)
class LlmClient:
    """Blocking request/response client for the text-completion service.

    Each request is published to ``text_completion_request_queue`` with a
    unique ``"id"`` message property. Replies are read from
    ``text_completion_response_queue``, and the reply whose ``"id"`` property
    matches the request is returned. Replies carrying any other id are
    acknowledged and discarded.
    """

    def __init__(
        self, log_level=ERROR, subscriber=None,
        pulsar_host="pulsar://pulsar:6650",
    ):
        """Connect to Pulsar and create the request producer and reply consumer.

        Args:
            log_level: Level passed to ``pulsar.ConsoleLogger`` (one of the
                module-level ERROR/WARN/INFO/DEBUG aliases).
            subscriber: Subscription name for the response consumer. When
                omitted, a fresh random UUID string is used.
            pulsar_host: Pulsar service URL.
        """
        if subscriber is None:
            # Default to a random subscription name so separate clients do
            # not share one subscription.
            subscriber = str(uuid.uuid4())

        self.client = pulsar.Client(
            pulsar_host,
            logger=pulsar.ConsoleLogger(log_level),
        )

        self.producer = self.client.create_producer(
            topic=text_completion_request_queue,
            schema=JsonSchema(TextCompletionRequest),
            chunking_enabled=True,
        )

        self.consumer = self.client.subscribe(
            text_completion_response_queue, subscriber,
            schema=JsonSchema(TextCompletionResponse),
        )

    def request(self, prompt, timeout=500):
        """Send *prompt* and block until the matching response arrives.

        Args:
            prompt: Prompt text for the completion request.
            timeout: Overall time limit in seconds for the whole call. It is
                not reset each time an unrelated response is received.

        Returns:
            The ``response`` field of the matching TextCompletionResponse.

        Raises:
            Whatever ``consumer.receive`` raises when no message arrives
            within the remaining time (the Pulsar client's timeout error).
        """
        request_id = str(uuid.uuid4())

        self.producer.send(
            TextCompletionRequest(prompt=prompt),
            properties={"id": request_id},
        )

        deadline = time.monotonic() + timeout

        while True:
            # Shrink the receive timeout as time passes, so a stream of
            # responses for other requests cannot extend the wait
            # indefinitely. Keep at least 1 ms so that receive() itself
            # raises its usual timeout error once the budget is spent.
            # int() because timeout_millis expects an integer.
            remaining_ms = max(1, int((deadline - time.monotonic()) * 1000))
            msg = self.consumer.receive(timeout_millis=remaining_ms)

            # .get() tolerates messages published without an "id" property.
            if msg.properties().get("id") == request_id:
                response = msg.value().response
                self.consumer.acknowledge(msg)
                return response

            # Not ours: acknowledge so it is not redelivered, then keep waiting.
            self.consumer.acknowledge(msg)

    def close(self):
        """Close the consumer, flush and close the producer, then close the client.

        Safe to call more than once, and safe on a partially constructed
        instance: each resource is released only if it exists and has not
        already been released.
        """
        consumer = getattr(self, "consumer", None)
        if consumer is not None:
            self.consumer = None
            # NOTE(review): the consumer is closed but never unsubscribed, so
            # the broker-side subscription presumably persists (notably with
            # random subscriber names). Confirm this is intended.
            consumer.close()

        producer = getattr(self, "producer", None)
        if producer is not None:
            self.producer = None
            producer.flush()
            producer.close()

        client = getattr(self, "client", None)
        if client is not None:
            self.client = None
            client.close()

    def __del__(self):
        # Best-effort cleanup at garbage collection. Call close() explicitly
        # for deterministic release.
        self.close()