Improve request/response handling (#18)

* Request/response error handling with common client

* Fixup error handling change
This commit is contained in:
cybermaggedon 2024-08-22 17:02:18 +01:00 committed by GitHub
parent 19c826c387
commit 1297cdb1d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1038 additions and 908 deletions

View file

@ -1,9 +1,11 @@
import pulsar
import _pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient
import pulsar
import _pulsar
import hashlib import hashlib
import uuid import uuid
import time import time
@ -14,7 +16,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class EmbeddingsClient: class EmbeddingsClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -24,72 +26,23 @@ class EmbeddingsClient:
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
self.client = None
if input_queue == None: if input_queue == None:
input_queue=embeddings_request_queue input_queue=embeddings_request_queue
if output_queue == None: if output_queue == None:
output_queue=embeddings_response_queue output_queue=embeddings_response_queue
if subscriber == None: super(EmbeddingsClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
input_schema=EmbeddingsRequest,
output_schema=EmbeddingsResponse,
) )
self.producer = self.client.create_producer( def request(self, text, timeout=30):
topic=input_queue, return self.call(text=text, timeout=timeout).vectors
schema=JsonSchema(EmbeddingsRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(EmbeddingsResponse),
)
def request(self, text, timeout=10):
id = str(uuid.uuid4())
r = EmbeddingsRequest(
text=text
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().vectors
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,7 @@ import time
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import graph_embeddings_request_queue from .. schema import graph_embeddings_request_queue
from .. schema import graph_embeddings_response_queue from .. schema import graph_embeddings_response_queue
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class GraphEmbeddingsClient: class GraphEmbeddingsClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -32,65 +33,18 @@ class GraphEmbeddingsClient:
if output_queue == None: if output_queue == None:
output_queue = graph_embeddings_response_queue output_queue = graph_embeddings_response_queue
if subscriber == None: super(GraphEmbeddingsClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
input_schema=GraphEmbeddingsRequest,
output_schema=GraphEmbeddingsResponse,
) )
self.producer = self.client.create_producer( def request(self, vectors, limit=10, timeout=30):
topic=input_queue, return self.call(
schema=JsonSchema(GraphEmbeddingsRequest), vectors=vectors, limit=limit, timeout=timeout
chunking_enabled=True, ).entities
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(GraphEmbeddingsResponse),
)
def request(self, vectors, limit=10, timeout=500):
id = str(uuid.uuid4())
r = GraphEmbeddingsRequest(
vectors=vectors,
limit=limit,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().entities
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -4,6 +4,7 @@ import _pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from .. schema import GraphRagQuery, GraphRagResponse from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue, graph_rag_response_queue from .. schema import graph_rag_request_queue, graph_rag_response_queue
from . base import BaseClient
import hashlib import hashlib
import uuid import uuid
@ -15,71 +16,36 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class GraphRagClient: class GraphRagClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, subscriber=None, self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
if subscriber == None: if input_queue == None:
subscriber = str(uuid.uuid4()) input_queue = graph_rag_request_queue
self.client = pulsar.Client( if output_queue == None:
pulsar_host, output_queue = graph_rag_response_queue
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer( super(GraphRagClient, self).__init__(
topic=graph_rag_request_queue, log_level=log_level,
schema=JsonSchema(GraphRagQuery), subscriber=subscriber,
chunking_enabled=True, input_queue=input_queue,
) output_queue=output_queue,
pulsar_host=pulsar_host,
self.consumer = self.client.subscribe( input_schema=GraphRagQuery,
graph_rag_response_queue, subscriber, output_schema=GraphRagResponse,
schema=JsonSchema(GraphRagResponse),
) )
def request(self, query, timeout=500): def request(self, query, timeout=500):
id = str(uuid.uuid4()) return self.call(
query=query, timeout=timeout
r = GraphRagQuery( ).response
query=query
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().response
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,8 @@ import time
from .. schema import TextCompletionRequest, TextCompletionResponse from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue from .. schema import text_completion_response_queue
from .. exceptions import *
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class LlmClient: class LlmClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -26,71 +28,19 @@ class LlmClient:
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
if input_queue == None: if input_queue is None: input_queue = text_completion_request_queue
input_queue = text_completion_request_queue if output_queue is None: output_queue = text_completion_response_queue
if output_queue == None: super(LlmClient, self).__init__(
output_queue = text_completion_response_queue log_level=log_level,
subscriber=subscriber,
if subscriber == None: input_queue=input_queue,
subscriber = str(uuid.uuid4()) output_queue=output_queue,
pulsar_host=pulsar_host,
self.client = pulsar.Client( input_schema=TextCompletionRequest,
pulsar_host, output_schema=TextCompletionResponse,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(TextCompletionRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(TextCompletionResponse),
) )
def request(self, prompt, timeout=30): def request(self, prompt, timeout=30):
return self.call(prompt=prompt, timeout=timeout).response
id = str(uuid.uuid4())
r = TextCompletionRequest(
prompt=prompt
)
end_time = time.time() + timeout
self.producer.send(r, properties={ "id": id })
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().response
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,7 @@ import time
from .. schema import PromptRequest, PromptResponse, Fact from .. schema import PromptRequest, PromptResponse, Fact
from .. schema import prompt_request_queue from .. schema import prompt_request_queue
from .. schema import prompt_response_queue from .. schema import prompt_response_queue
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class PromptClient: class PromptClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -32,133 +33,35 @@ class PromptClient:
if output_queue == None: if output_queue == None:
output_queue = prompt_response_queue output_queue = prompt_response_queue
if subscriber == None: super(PromptClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
) input_schema=PromptRequest,
output_schema=PromptResponse,
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(PromptRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(PromptResponse),
) )
def request_definitions(self, chunk, timeout=30): def request_definitions(self, chunk, timeout=30):
id = str(uuid.uuid4()) return self.call(kind="extract-definitions", chunk=chunk,
timeout=timeout).definitions
r = PromptRequest(
kind="extract-definitions",
chunk=chunk,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().definitions
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def request_relationships(self, chunk, timeout=30): def request_relationships(self, chunk, timeout=30):
id = str(uuid.uuid4()) return self.call(kind="extract-relationships", chunk=chunk,
timeout=timeout).relationships
r = PromptRequest(
kind="extract-relationships",
chunk=chunk,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().relationships
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def request_kg_prompt(self, query, kg, timeout=30): def request_kg_prompt(self, query, kg, timeout=30):
id = str(uuid.uuid4()) return self.call(
r = PromptRequest(
kind="kg-prompt", kind="kg-prompt",
query=query, query=query,
kg=[ kg=[
Fact(s=v[0], p=v[1], o=v[2]) Fact(s=v[0], p=v[1], o=v[2])
for v in kg for v in kg
], ],
) timeout=timeout
).answer
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().answer
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -10,6 +10,7 @@ import time
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import triples_request_queue from .. schema import triples_request_queue
from .. schema import triples_response_queue from .. schema import triples_response_queue
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -17,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class TriplesQueryClient: class TriplesQueryClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -33,23 +34,14 @@ class TriplesQueryClient:
if output_queue == None: if output_queue == None:
output_queue = triples_response_queue output_queue = triples_response_queue
if subscriber == None: super(TriplesQueryClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
) input_schema=TriplesQueryRequest,
output_schema=TriplesQueryResponse,
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(TriplesQueryRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(TriplesQueryResponse),
) )
def create_value(self, ent): def create_value(self, ent):
@ -61,48 +53,12 @@ class TriplesQueryClient:
return Value(value=ent, is_uri=False) return Value(value=ent, is_uri=False)
def request(self, s, p, o, limit=10, timeout=500): def request(self, s, p, o, limit=10, timeout=30):
return self.call(
id = str(uuid.uuid4())
r = TriplesQueryRequest(
s=self.create_value(s), s=self.create_value(s),
p=self.create_value(p), p=self.create_value(p),
o=self.create_value(o), o=self.create_value(o),
limit=limit, limit=limit,
) timeout=timeout,
).triples
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().triples
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -6,7 +6,7 @@ Input is text, output is embeddings vector.
from langchain_huggingface import HuggingFaceEmbeddings from langchain_huggingface import HuggingFaceEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... schema import EmbeddingsRequest, EmbeddingsResponse, Error
from ... schema import embeddings_request_queue, embeddings_response_queue from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel from ... log_level import LogLevel
from ... base import ConsumerProducer from ... base import ConsumerProducer
@ -48,15 +48,37 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True) print(f"Handling input {id}...", flush=True)
try:
text = v.text text = v.text
embeds = self.embeddings.embed_documents([text]) embeds = self.embeddings.embed_documents([text])
print("Send response...", flush=True) print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=embeds) r = EmbeddingsResponse(vectors=embeds, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = EmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -2,3 +2,13 @@
class TooManyRequests(Exception): class TooManyRequests(Exception):
pass pass
class LlmError(Exception):
pass
class ParseError(Exception):
pass

View file

@ -6,7 +6,7 @@ Language service abstracts prompt engineering from LLM.
import json import json
from .... schema import Definition, Relationship, Triple from .... schema import Definition, Relationship, Triple
from .... schema import PromptRequest, PromptResponse from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
@ -89,6 +89,8 @@ class Processor(ConsumerProducer):
def handle_extract_definitions(self, id, v): def handle_extract_definitions(self, id, v):
try:
prompt = to_definitions(v.chunk) prompt = to_definitions(v.chunk)
ans = self.llm.request(prompt) ans = self.llm.request(prompt)
@ -118,13 +120,33 @@ class Processor(ConsumerProducer):
print("definition fields missing, ignored", flush=True) print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True) print("Send response...", flush=True)
r = PromptResponse(definitions=output) r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
def handle_extract_relationships(self, id, v): def handle_extract_relationships(self, id, v):
try:
prompt = to_relationships(v.chunk) prompt = to_relationships(v.chunk)
ans = self.llm.request(prompt) ans = self.llm.request(prompt)
@ -154,13 +176,33 @@ class Processor(ConsumerProducer):
print("relationship fields missing, ignored", flush=True) print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True) print("Send response...", flush=True)
r = PromptResponse(relationships=output) r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
def handle_kg_prompt(self, id, v): def handle_kg_prompt(self, id, v):
try:
prompt = to_kg_query(v.query, v.kg) prompt = to_kg_query(v.query, v.kg)
print(prompt) print(prompt)
@ -170,11 +212,29 @@ class Processor(ConsumerProducer):
print(ans) print(ans)
print("Send response...", flush=True) print("Send response...", flush=True)
r = PromptResponse(answer=ans) r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -7,7 +7,7 @@ serverless endpoint service. Input is prompt, output is response.
import requests import requests
import json import json
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -89,6 +89,9 @@ class Processor(ConsumerProducer):
if resp.status_code == 429: if resp.status_code == 429:
raise TooManyRequests() raise TooManyRequests()
if resp.status_code != 200:
raise RuntimeError("LLM failure")
result = resp.json() result = resp.json()
message_content = result['choices'][0]['message']['content'] message_content = result['choices'][0]['message']['content']
@ -110,6 +113,8 @@ class Processor(ConsumerProducer):
v.prompt v.prompt
) )
try:
response = self.call_llm(prompt) response = self.call_llm(prompt)
print("Send response...", flush=True) print("Send response...", flush=True)
@ -120,6 +125,38 @@ class Processor(ConsumerProducer):
r = TextCompletionResponse(response=resp) r = TextCompletionResponse(response=resp)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
)
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
)
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
print("Done.", flush=True) print("Done.", flush=True)
@staticmethod @staticmethod

View file

@ -7,7 +7,7 @@ Input is prompt, output is response. Mistral is default.
import boto3 import boto3
import json import json
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -130,6 +130,8 @@ class Processor(ConsumerProducer):
accept = 'application/json' accept = 'application/json'
contentType = 'application/json' contentType = 'application/json'
try:
# FIXME: Consider catching request limits and raise TooManyRequests # FIXME: Consider catching request limits and raise TooManyRequests
# See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType) response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType)
@ -160,11 +162,50 @@ class Processor(ConsumerProducer):
resp = outputtext.replace("```", "") resp = outputtext.replace("```", "")
print("Send response...", flush=True) print("Send response...", flush=True)
r = TextCompletionResponse(response=resp) r = TextCompletionResponse(
error=None,
response=resp
)
self.send(r, properties={"id": id}) self.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
# FIXME: Wrong exception, don't know what Bedrock throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
import anthropic import anthropic
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -65,6 +66,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
try:
# FIXME: Rate limits? # FIXME: Rate limits?
response = message = self.claude.messages.create( response = message = self.claude.messages.create(
model=self.model, model=self.model,
@ -93,6 +96,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True) print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
import cohere import cohere
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -61,7 +62,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
# FIXME: Deal with rate limits? try:
output = self.cohere.chat( output = self.cohere.chat(
model=self.model, model=self.model,
message=prompt, message=prompt,
@ -84,6 +86,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True) print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -7,11 +7,12 @@ Input is prompt, output is response.
from langchain_community.llms import Ollama from langchain_community.llms import Ollama
from prometheus_client import Histogram, Info, Counter from prometheus_client import Histogram, Info, Counter
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -66,7 +67,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
# FIXME: Rate limits? try:
response = self.llm.invoke(prompt) response = self.llm.invoke(prompt)
print("Send response...", flush=True) print("Send response...", flush=True)
@ -80,6 +82,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True) print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
from openai import OpenAI from openai import OpenAI
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -65,6 +66,8 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
try:
# FIXME: Rate limits # FIXME: Rate limits
resp = self.openai.chat.completions.create( resp = self.openai.chat.completions.create(
model=self.model, model=self.model,
@ -97,6 +100,42 @@ class Processor(ConsumerProducer):
print("Done.", flush=True) print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -21,7 +21,7 @@ from vertexai.preview.generative_models import (
Tool, Tool,
) )
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -136,7 +136,12 @@ class Processor(ConsumerProducer):
resp = resp.replace("```", "") resp = resp.replace("```", "")
print("Send response...", flush=True) print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
r = TextCompletionResponse(
error=None,
response=resp,
)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
@ -144,12 +149,39 @@ class Processor(ConsumerProducer):
# Acknowledge successful processing of the message # Acknowledge successful processing of the message
self.consumer.acknowledge(msg) self.consumer.acknowledge(msg)
except google.api_core.exceptions.ResourceExhausted: except google.api_core.exceptions.ResourceExhausted as e:
# 429 / rate limits case print("Send rate limit response...", flush=True)
raise TooManyRequests
# Let other exceptions fall through r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -5,7 +5,8 @@ entities
""" """
from .... direct.milvus import TripleVectors from .... direct.milvus import TripleVectors
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, Value from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer from .... base import ConsumerProducer
@ -47,6 +48,8 @@ class Processor(ConsumerProducer):
def handle(self, msg): def handle(self, msg):
try:
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
@ -75,11 +78,29 @@ class Processor(ConsumerProducer):
entities = ents2 entities = ents2
print("Send response...", flush=True) print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities) r = GraphEmbeddingsResponse(entities=entities, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -5,7 +5,7 @@ null. Output is a list of triples.
""" """
from .... direct.cassandra import TrustGraph from .... direct.cassandra import TrustGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple from .... schema import Value, Triple
from .... schema import triples_request_queue from .... schema import triples_request_queue
from .... schema import triples_response_queue from .... schema import triples_response_queue
@ -48,6 +48,8 @@ class Processor(ConsumerProducer):
def handle(self, msg): def handle(self, msg):
try:
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
@ -128,11 +130,29 @@ class Processor(ConsumerProducer):
] ]
print("Send response...", flush=True) print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples) r = TriplesQueryResponse(triples=triples, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,7 +6,7 @@ null. Output is a list of triples.
from neo4j import GraphDatabase from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple from .... schema import Value, Triple
from .... schema import triples_request_queue from .... schema import triples_request_queue
from .... schema import triples_response_queue from .... schema import triples_response_queue
@ -57,6 +57,8 @@ class Processor(ConsumerProducer):
def handle(self, msg): def handle(self, msg):
try:
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
@ -296,6 +298,24 @@ class Processor(ConsumerProducer):
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -4,7 +4,7 @@ Simple RAG service, performs query using graph RAG an LLM.
Input is query, output is response. Input is query, output is response.
""" """
from ... schema import GraphRagQuery, GraphRagResponse from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import graph_rag_request_queue, graph_rag_response_queue from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import prompt_request_queue from ... schema import prompt_request_queue
from ... schema import prompt_response_queue from ... schema import prompt_response_queue
@ -99,10 +99,11 @@ class Processor(ConsumerProducer):
def handle(self, msg): def handle(self, msg):
try:
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
id = msg.properties()["id"] id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True) print(f"Handling input {id}...", flush=True)
@ -110,11 +111,29 @@ class Processor(ConsumerProducer):
response = self.rag.query(v.query) response = self.rag.query(v.query)
print("Send response...", flush=True) print("Send response...", flush=True)
r = GraphRagResponse(response = response) r = GraphRagResponse(response = response, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphRagResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -8,6 +8,12 @@ def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
############################################################################ ############################################################################
class Error(Record):
type = String()
message = String()
############################################################################
class Value(Record): class Value(Record):
value = String() value = String()
is_uri = Boolean() is_uri = Boolean()
@ -78,6 +84,7 @@ class GraphEmbeddingsRequest(Record):
limit = Integer() limit = Integer()
class GraphEmbeddingsResponse(Record): class GraphEmbeddingsResponse(Record):
error = Error()
entities = Array(Value()) entities = Array(Value())
graph_embeddings_request_queue = topic( graph_embeddings_request_queue = topic(
@ -110,6 +117,7 @@ class TriplesQueryRequest(Record):
limit = Integer() limit = Integer()
class TriplesQueryResponse(Record): class TriplesQueryResponse(Record):
error = Error()
triples = Array(Triple()) triples = Array(Triple())
triples_request_queue = topic( triples_request_queue = topic(
@ -131,6 +139,7 @@ class TextCompletionRequest(Record):
prompt = String() prompt = String()
class TextCompletionResponse(Record): class TextCompletionResponse(Record):
error = Error()
response = String() response = String()
text_completion_request_queue = topic( text_completion_request_queue = topic(
@ -148,6 +157,7 @@ class EmbeddingsRequest(Record):
text = String() text = String()
class EmbeddingsResponse(Record): class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double())) vectors = Array(Array(Double()))
embeddings_request_queue = topic( embeddings_request_queue = topic(
@ -165,6 +175,7 @@ class GraphRagQuery(Record):
query = String() query = String()
class GraphRagResponse(Record): class GraphRagResponse(Record):
error = Error()
response = String() response = String()
graph_rag_request_queue = topic( graph_rag_request_queue = topic(
@ -207,6 +218,7 @@ class PromptRequest(Record):
kg = Array(Fact()) kg = Array(Fact())
class PromptResponse(Record): class PromptResponse(Record):
error = Error()
answer = String() answer = String()
definitions = Array(Definition()) definitions = Array(Definition())
relationships = Array(Relationship()) relationships = Array(Relationship())