Improve request/response handling (#18)

* Request/response error handling with common client

* Fixup error handling change
This commit is contained in:
cybermaggedon 2024-08-22 17:02:18 +01:00 committed by GitHub
parent 19c826c387
commit 1297cdb1d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1038 additions and 908 deletions

View file

@ -1,9 +1,11 @@
import pulsar
import _pulsar
from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient
import pulsar
import _pulsar
import hashlib
import uuid
import time
@ -14,7 +16,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class EmbeddingsClient:
class EmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
@ -24,72 +26,23 @@ class EmbeddingsClient:
pulsar_host="pulsar://pulsar:6650",
):
self.client = None
if input_queue == None:
input_queue=embeddings_request_queue
if output_queue == None:
output_queue=embeddings_response_queue
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
super(EmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=EmbeddingsRequest,
output_schema=EmbeddingsResponse,
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(EmbeddingsRequest),
chunking_enabled=True,
)
def request(self, text, timeout=30):
return self.call(text=text, timeout=timeout).vectors
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(EmbeddingsResponse),
)
def request(self, text, timeout=10):
id = str(uuid.uuid4())
r = EmbeddingsRequest(
text=text
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().vectors
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,7 @@ import time
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import graph_embeddings_request_queue
from .. schema import graph_embeddings_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphEmbeddingsClient:
class GraphEmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
@ -32,65 +33,18 @@ class GraphEmbeddingsClient:
if output_queue == None:
output_queue = graph_embeddings_response_queue
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
super(GraphEmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphEmbeddingsRequest,
output_schema=GraphEmbeddingsResponse,
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(GraphEmbeddingsRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(GraphEmbeddingsResponse),
)
def request(self, vectors, limit=10, timeout=500):
id = str(uuid.uuid4())
r = GraphEmbeddingsRequest(
vectors=vectors,
limit=limit,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().entities
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()
def request(self, vectors, limit=10, timeout=30):
return self.call(
vectors=vectors, limit=limit, timeout=timeout
).entities

View file

@ -4,6 +4,7 @@ import _pulsar
from pulsar.schema import JsonSchema
from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue, graph_rag_response_queue
from . base import BaseClient
import hashlib
import uuid
@ -15,71 +16,36 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphRagClient:
class GraphRagClient(BaseClient):
def __init__(
self, log_level=ERROR, subscriber=None,
self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if subscriber == None:
subscriber = str(uuid.uuid4())
if input_queue == None:
input_queue = graph_rag_request_queue
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
if output_queue == None:
output_queue = graph_rag_response_queue
self.producer = self.client.create_producer(
topic=graph_rag_request_queue,
schema=JsonSchema(GraphRagQuery),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
graph_rag_response_queue, subscriber,
schema=JsonSchema(GraphRagResponse),
super(GraphRagClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphRagQuery,
output_schema=GraphRagResponse,
)
def request(self, query, timeout=500):
id = str(uuid.uuid4())
r = GraphRagQuery(
query=query
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().response
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()
return self.call(
query=query, timeout=timeout
).response

View file

@ -9,6 +9,8 @@ import time
from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue
from .. exceptions import *
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class LlmClient:
class LlmClient(BaseClient):
def __init__(
self, log_level=ERROR,
@ -26,71 +28,19 @@ class LlmClient:
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = text_completion_request_queue
if input_queue is None: input_queue = text_completion_request_queue
if output_queue is None: output_queue = text_completion_response_queue
if output_queue == None:
output_queue = text_completion_response_queue
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(TextCompletionRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(TextCompletionResponse),
super(LlmClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=TextCompletionRequest,
output_schema=TextCompletionResponse,
)
def request(self, prompt, timeout=30):
id = str(uuid.uuid4())
r = TextCompletionRequest(
prompt=prompt
)
end_time = time.time() + timeout
self.producer.send(r, properties={ "id": id })
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().response
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()
return self.call(prompt=prompt, timeout=timeout).response

View file

@ -9,6 +9,7 @@ import time
from .. schema import PromptRequest, PromptResponse, Fact
from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class PromptClient:
class PromptClient(BaseClient):
def __init__(
self, log_level=ERROR,
@ -32,133 +33,35 @@ class PromptClient:
if output_queue == None:
output_queue = prompt_response_queue
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(PromptRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(PromptResponse),
super(PromptClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=PromptRequest,
output_schema=PromptResponse,
)
def request_definitions(self, chunk, timeout=30):
id = str(uuid.uuid4())
r = PromptRequest(
kind="extract-definitions",
chunk=chunk,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().definitions
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
return self.call(kind="extract-definitions", chunk=chunk,
timeout=timeout).definitions
def request_relationships(self, chunk, timeout=30):
id = str(uuid.uuid4())
r = PromptRequest(
kind="extract-relationships",
chunk=chunk,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().relationships
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
return self.call(kind="extract-relationships", chunk=chunk,
timeout=timeout).relationships
def request_kg_prompt(self, query, kg, timeout=30):
id = str(uuid.uuid4())
r = PromptRequest(
return self.call(
kind="kg-prompt",
query=query,
kg=[
Fact(s=v[0], p=v[1], o=v[2])
for v in kg
],
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().answer
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()
timeout=timeout
).answer

View file

@ -10,6 +10,7 @@ import time
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import triples_request_queue
from .. schema import triples_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
@ -17,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class TriplesQueryClient:
class TriplesQueryClient(BaseClient):
def __init__(
self, log_level=ERROR,
@ -33,23 +34,14 @@ class TriplesQueryClient:
if output_queue == None:
output_queue = triples_response_queue
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(TriplesQueryRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(TriplesQueryResponse),
super(TriplesQueryClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=TriplesQueryRequest,
output_schema=TriplesQueryResponse,
)
def create_value(self, ent):
@ -61,48 +53,12 @@ class TriplesQueryClient:
return Value(value=ent, is_uri=False)
def request(self, s, p, o, limit=10, timeout=500):
id = str(uuid.uuid4())
r = TriplesQueryRequest(
def request(self, s, p, o, limit=10, timeout=30):
return self.call(
s=self.create_value(s),
p=self.create_value(p),
o=self.create_value(o),
limit=limit,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().triples
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()
timeout=timeout,
).triples

View file

@ -6,7 +6,7 @@ Input is text, output is embeddings vector.
from langchain_huggingface import HuggingFaceEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import EmbeddingsRequest, EmbeddingsResponse, Error
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -48,15 +48,37 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
try:
text = v.text
embeds = self.embeddings.embed_documents([text])
print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=embeds)
r = EmbeddingsResponse(vectors=embeds, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = EmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -2,3 +2,13 @@
class TooManyRequests(Exception):
pass
class LlmError(Exception):
pass
class ParseError(Exception):
pass

View file

@ -6,7 +6,7 @@ Language service abstracts prompt engineering from LLM.
import json
from .... schema import Definition, Relationship, Triple
from .... schema import PromptRequest, PromptResponse
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
@ -89,6 +89,8 @@ class Processor(ConsumerProducer):
def handle_extract_definitions(self, id, v):
try:
prompt = to_definitions(v.chunk)
ans = self.llm.request(prompt)
@ -118,13 +120,33 @@ class Processor(ConsumerProducer):
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(definitions=output)
r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
def handle_extract_relationships(self, id, v):
try:
prompt = to_relationships(v.chunk)
ans = self.llm.request(prompt)
@ -154,13 +176,33 @@ class Processor(ConsumerProducer):
print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(relationships=output)
r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
def handle_kg_prompt(self, id, v):
try:
prompt = to_kg_query(v.query, v.kg)
print(prompt)
@ -170,11 +212,29 @@ class Processor(ConsumerProducer):
print(ans)
print("Send response...", flush=True)
r = PromptResponse(answer=ans)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -7,7 +7,7 @@ serverless endpoint service. Input is prompt, output is response.
import requests
import json
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
@ -89,6 +89,9 @@ class Processor(ConsumerProducer):
if resp.status_code == 429:
raise TooManyRequests()
if resp.status_code != 200:
raise RuntimeError("LLM failure")
result = resp.json()
message_content = result['choices'][0]['message']['content']
@ -110,6 +113,8 @@ class Processor(ConsumerProducer):
v.prompt
)
try:
response = self.call_llm(prompt)
print("Send response...", flush=True)
@ -120,6 +125,38 @@ class Processor(ConsumerProducer):
r = TextCompletionResponse(response=resp)
self.producer.send(r, properties={"id": id})
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
)
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
)
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
print("Done.", flush=True)
@staticmethod

View file

@ -7,7 +7,7 @@ Input is prompt, output is response. Mistral is default.
import boto3
import json
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
@ -130,6 +130,8 @@ class Processor(ConsumerProducer):
accept = 'application/json'
contentType = 'application/json'
try:
# FIXME: Consider catching request limits and raise TooManyRequests
# See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType)
@ -160,11 +162,50 @@ class Processor(ConsumerProducer):
resp = outputtext.replace("```", "")
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
r = TextCompletionResponse(
error=None,
response=resp
)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what Bedrock throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
import anthropic
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
@ -65,6 +66,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt
try:
# FIXME: Rate limits?
response = message = self.claude.messages.create(
model=self.model,
@ -93,6 +96,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
import cohere
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
@ -61,7 +62,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt
# FIXME: Deal with rate limits?
try:
output = self.cohere.chat(
model=self.model,
message=prompt,
@ -84,6 +86,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -7,11 +7,12 @@ Input is prompt, output is response.
from langchain_community.llms import Ollama
from prometheus_client import Histogram, Info, Counter
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
@ -66,7 +67,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt
# FIXME: Rate limits?
try:
response = self.llm.invoke(prompt)
print("Send response...", flush=True)
@ -80,6 +82,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
from openai import OpenAI
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
@ -65,6 +66,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt
try:
# FIXME: Rate limits
resp = self.openai.chat.completions.create(
model=self.model,
@ -97,6 +100,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -21,7 +21,7 @@ from vertexai.preview.generative_models import (
Tool,
)
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
@ -136,7 +136,12 @@ class Processor(ConsumerProducer):
resp = resp.replace("```", "")
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
r = TextCompletionResponse(
error=None,
response=resp,
)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
@ -144,12 +149,39 @@ class Processor(ConsumerProducer):
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
except google.api_core.exceptions.ResourceExhausted:
except google.api_core.exceptions.ResourceExhausted as e:
# 429 / rate limits case
raise TooManyRequests
print("Send rate limit response...", flush=True)
# Let other exceptions fall through
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -5,7 +5,8 @@ entities
"""
from .... direct.milvus import TripleVectors
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, Value
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
@ -47,6 +48,8 @@ class Processor(ConsumerProducer):
def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
@ -75,11 +78,29 @@ class Processor(ConsumerProducer):
entities = ents2
print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities)
r = GraphEmbeddingsResponse(entities=entities, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -5,7 +5,7 @@ null. Output is a list of triples.
"""
from .... direct.cassandra import TrustGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import triples_request_queue
from .... schema import triples_response_queue
@ -48,6 +48,8 @@ class Processor(ConsumerProducer):
def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
@ -128,11 +130,29 @@ class Processor(ConsumerProducer):
]
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples)
r = TriplesQueryResponse(triples=triples, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -6,7 +6,7 @@ null. Output is a list of triples.
from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import triples_request_queue
from .... schema import triples_response_queue
@ -57,6 +57,8 @@ class Processor(ConsumerProducer):
def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
@ -296,6 +298,24 @@ class Processor(ConsumerProducer):
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -4,7 +4,7 @@ Simple RAG service, performs query using graph RAG an LLM.
Input is query, output is response.
"""
from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
@ -99,10 +99,11 @@ class Processor(ConsumerProducer):
def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
@ -110,11 +111,29 @@ class Processor(ConsumerProducer):
response = self.rag.query(v.query)
print("Send response...", flush=True)
r = GraphRagResponse(response = response)
r = GraphRagResponse(response = response, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphRagResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):

View file

@ -8,6 +8,12 @@ def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
############################################################################
class Error(Record):
type = String()
message = String()
############################################################################
class Value(Record):
value = String()
is_uri = Boolean()
@ -78,6 +84,7 @@ class GraphEmbeddingsRequest(Record):
limit = Integer()
class GraphEmbeddingsResponse(Record):
error = Error()
entities = Array(Value())
graph_embeddings_request_queue = topic(
@ -110,6 +117,7 @@ class TriplesQueryRequest(Record):
limit = Integer()
class TriplesQueryResponse(Record):
error = Error()
triples = Array(Triple())
triples_request_queue = topic(
@ -131,6 +139,7 @@ class TextCompletionRequest(Record):
prompt = String()
class TextCompletionResponse(Record):
error = Error()
response = String()
text_completion_request_queue = topic(
@ -148,6 +157,7 @@ class EmbeddingsRequest(Record):
text = String()
class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double()))
embeddings_request_queue = topic(
@ -165,6 +175,7 @@ class GraphRagQuery(Record):
query = String()
class GraphRagResponse(Record):
error = Error()
response = String()
graph_rag_request_queue = topic(
@ -207,6 +218,7 @@ class PromptRequest(Record):
kg = Array(Fact())
class PromptResponse(Record):
error = Error()
answer = String()
definitions = Array(Definition())
relationships = Array(Relationship())