Improve request/response handling (#18)

* Request/response error handling with common client

* Fixup error handling change
This commit is contained in:
cybermaggedon 2024-08-22 17:02:18 +01:00 committed by GitHub
parent 19c826c387
commit 1297cdb1d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1038 additions and 908 deletions

View file

@ -1,9 +1,11 @@
import pulsar
import _pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient
import pulsar
import _pulsar
import hashlib import hashlib
import uuid import uuid
import time import time
@ -14,7 +16,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class EmbeddingsClient: class EmbeddingsClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -24,72 +26,23 @@ class EmbeddingsClient:
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
self.client = None
if input_queue == None: if input_queue == None:
input_queue=embeddings_request_queue input_queue=embeddings_request_queue
if output_queue == None: if output_queue == None:
output_queue=embeddings_response_queue output_queue=embeddings_response_queue
if subscriber == None: super(EmbeddingsClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
input_schema=EmbeddingsRequest,
output_schema=EmbeddingsResponse,
) )
self.producer = self.client.create_producer( def request(self, text, timeout=30):
topic=input_queue, return self.call(text=text, timeout=timeout).vectors
schema=JsonSchema(EmbeddingsRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(EmbeddingsResponse),
)
def request(self, text, timeout=10):
id = str(uuid.uuid4())
r = EmbeddingsRequest(
text=text
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().vectors
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,7 @@ import time
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import graph_embeddings_request_queue from .. schema import graph_embeddings_request_queue
from .. schema import graph_embeddings_response_queue from .. schema import graph_embeddings_response_queue
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class GraphEmbeddingsClient: class GraphEmbeddingsClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -31,66 +32,19 @@ class GraphEmbeddingsClient:
if output_queue == None: if output_queue == None:
output_queue = graph_embeddings_response_queue output_queue = graph_embeddings_response_queue
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(GraphEmbeddingsRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(GraphEmbeddingsResponse),
)
def request(self, vectors, limit=10, timeout=500):
id = str(uuid.uuid4())
r = GraphEmbeddingsRequest(
vectors=vectors,
limit=limit,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().entities
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"): super(GraphEmbeddingsClient, self).__init__(
self.producer.flush() log_level=log_level,
self.producer.close() subscriber=subscriber,
input_queue=input_queue,
self.client.close() output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphEmbeddingsRequest,
output_schema=GraphEmbeddingsResponse,
)
def request(self, vectors, limit=10, timeout=30):
return self.call(
vectors=vectors, limit=limit, timeout=timeout
).entities

View file

@ -4,6 +4,7 @@ import _pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from .. schema import GraphRagQuery, GraphRagResponse from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue, graph_rag_response_queue from .. schema import graph_rag_request_queue, graph_rag_response_queue
from . base import BaseClient
import hashlib import hashlib
import uuid import uuid
@ -15,71 +16,36 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class GraphRagClient: class GraphRagClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, subscriber=None, self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
if subscriber == None: if input_queue == None:
subscriber = str(uuid.uuid4()) input_queue = graph_rag_request_queue
self.client = pulsar.Client( if output_queue == None:
pulsar_host, output_queue = graph_rag_response_queue
logger=pulsar.ConsoleLogger(log_level),
) super(GraphRagClient, self).__init__(
log_level=log_level,
self.producer = self.client.create_producer( subscriber=subscriber,
topic=graph_rag_request_queue, input_queue=input_queue,
schema=JsonSchema(GraphRagQuery), output_queue=output_queue,
chunking_enabled=True, pulsar_host=pulsar_host,
) input_schema=GraphRagQuery,
output_schema=GraphRagResponse,
self.consumer = self.client.subscribe(
graph_rag_response_queue, subscriber,
schema=JsonSchema(GraphRagResponse),
) )
def request(self, query, timeout=500): def request(self, query, timeout=500):
id = str(uuid.uuid4()) return self.call(
query=query, timeout=timeout
r = GraphRagQuery( ).response
query=query
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().response
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,8 @@ import time
from .. schema import TextCompletionRequest, TextCompletionResponse from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue from .. schema import text_completion_response_queue
from .. exceptions import *
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class LlmClient: class LlmClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -26,71 +28,19 @@ class LlmClient:
pulsar_host="pulsar://pulsar:6650", pulsar_host="pulsar://pulsar:6650",
): ):
if input_queue == None: if input_queue is None: input_queue = text_completion_request_queue
input_queue = text_completion_request_queue if output_queue is None: output_queue = text_completion_response_queue
if output_queue == None: super(LlmClient, self).__init__(
output_queue = text_completion_response_queue log_level=log_level,
subscriber=subscriber,
if subscriber == None: input_queue=input_queue,
subscriber = str(uuid.uuid4()) output_queue=output_queue,
pulsar_host=pulsar_host,
self.client = pulsar.Client( input_schema=TextCompletionRequest,
pulsar_host, output_schema=TextCompletionResponse,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(TextCompletionRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(TextCompletionResponse),
) )
def request(self, prompt, timeout=30): def request(self, prompt, timeout=30):
return self.call(prompt=prompt, timeout=timeout).response
id = str(uuid.uuid4())
r = TextCompletionRequest(
prompt=prompt
)
end_time = time.time() + timeout
self.producer.send(r, properties={ "id": id })
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().response
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
# self.consumer.unsubscribe()
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -9,6 +9,7 @@ import time
from .. schema import PromptRequest, PromptResponse, Fact from .. schema import PromptRequest, PromptResponse, Fact
from .. schema import prompt_request_queue from .. schema import prompt_request_queue
from .. schema import prompt_response_queue from .. schema import prompt_response_queue
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class PromptClient: class PromptClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -32,133 +33,35 @@ class PromptClient:
if output_queue == None: if output_queue == None:
output_queue = prompt_response_queue output_queue = prompt_response_queue
if subscriber == None: super(PromptClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
) input_schema=PromptRequest,
output_schema=PromptResponse,
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(PromptRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(PromptResponse),
) )
def request_definitions(self, chunk, timeout=30): def request_definitions(self, chunk, timeout=30):
id = str(uuid.uuid4()) return self.call(kind="extract-definitions", chunk=chunk,
timeout=timeout).definitions
r = PromptRequest(
kind="extract-definitions",
chunk=chunk,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().definitions
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def request_relationships(self, chunk, timeout=30): def request_relationships(self, chunk, timeout=30):
id = str(uuid.uuid4()) return self.call(kind="extract-relationships", chunk=chunk,
timeout=timeout).relationships
r = PromptRequest(
kind="extract-relationships",
chunk=chunk,
)
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().relationships
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def request_kg_prompt(self, query, kg, timeout=30): def request_kg_prompt(self, query, kg, timeout=30):
id = str(uuid.uuid4()) return self.call(
r = PromptRequest(
kind="kg-prompt", kind="kg-prompt",
query=query, query=query,
kg=[ kg=[
Fact(s=v[0], p=v[1], o=v[2]) Fact(s=v[0], p=v[1], o=v[2])
for v in kg for v in kg
], ],
) timeout=timeout
).answer
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().answer
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -10,6 +10,7 @@ import time
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import triples_request_queue from .. schema import triples_request_queue
from .. schema import triples_response_queue from .. schema import triples_response_queue
from . base import BaseClient
# Ugly # Ugly
ERROR=_pulsar.LoggerLevel.Error ERROR=_pulsar.LoggerLevel.Error
@ -17,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug DEBUG=_pulsar.LoggerLevel.Debug
class TriplesQueryClient: class TriplesQueryClient(BaseClient):
def __init__( def __init__(
self, log_level=ERROR, self, log_level=ERROR,
@ -33,23 +34,14 @@ class TriplesQueryClient:
if output_queue == None: if output_queue == None:
output_queue = triples_response_queue output_queue = triples_response_queue
if subscriber == None: super(TriplesQueryClient, self).__init__(
subscriber = str(uuid.uuid4()) log_level=log_level,
subscriber=subscriber,
self.client = pulsar.Client( input_queue=input_queue,
pulsar_host, output_queue=output_queue,
logger=pulsar.ConsoleLogger(log_level), pulsar_host=pulsar_host,
) input_schema=TriplesQueryRequest,
output_schema=TriplesQueryResponse,
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(TriplesQueryRequest),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(TriplesQueryResponse),
) )
def create_value(self, ent): def create_value(self, ent):
@ -61,48 +53,12 @@ class TriplesQueryClient:
return Value(value=ent, is_uri=False) return Value(value=ent, is_uri=False)
def request(self, s, p, o, limit=10, timeout=500): def request(self, s, p, o, limit=10, timeout=30):
return self.call(
id = str(uuid.uuid4())
r = TriplesQueryRequest(
s=self.create_value(s), s=self.create_value(s),
p=self.create_value(p), p=self.create_value(p),
o=self.create_value(o), o=self.create_value(o),
limit=limit, limit=limit,
) timeout=timeout,
).triples
self.producer.send(r, properties={ "id": id })
end_time = time.time() + timeout
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=5000)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
resp = msg.value().triples
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -6,7 +6,7 @@ Input is text, output is embeddings vector.
from langchain_huggingface import HuggingFaceEmbeddings from langchain_huggingface import HuggingFaceEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... schema import EmbeddingsRequest, EmbeddingsResponse, Error
from ... schema import embeddings_request_queue, embeddings_response_queue from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel from ... log_level import LogLevel
from ... base import ConsumerProducer from ... base import ConsumerProducer
@ -48,14 +48,36 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True) print(f"Handling input {id}...", flush=True)
text = v.text try:
embeds = self.embeddings.embed_documents([text])
print("Send response...", flush=True) text = v.text
r = EmbeddingsResponse(vectors=embeds) embeds = self.embeddings.embed_documents([text])
self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=embeds, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = EmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -2,3 +2,13 @@
class TooManyRequests(Exception): class TooManyRequests(Exception):
pass pass
class LlmError(Exception):
pass
class ParseError(Exception):
pass

View file

@ -6,7 +6,7 @@ Language service abstracts prompt engineering from LLM.
import json import json
from .... schema import Definition, Relationship, Triple from .... schema import Definition, Relationship, Triple
from .... schema import PromptRequest, PromptResponse from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
@ -89,91 +89,151 @@ class Processor(ConsumerProducer):
def handle_extract_definitions(self, id, v): def handle_extract_definitions(self, id, v):
prompt = to_definitions(v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try: try:
defs = json.loads(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = [] prompt = to_definitions(v.chunk)
for defn in defs: ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try: try:
e = defn["entity"] defs = json.loads(ans)
d = defn["definition"]
output.append(
Definition(
name=e, definition=d
)
)
except: except:
print("definition fields missing, ignored", flush=True) print("JSON parse error, ignored", flush=True)
defs = []
print("Send response...", flush=True) output = []
r = PromptResponse(definitions=output)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True) for defn in defs:
try:
e = defn["entity"]
d = defn["definition"]
output.append(
Definition(
name=e, definition=d
)
)
except:
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
def handle_extract_relationships(self, id, v): def handle_extract_relationships(self, id, v):
prompt = to_relationships(v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try: try:
defs = json.loads(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = [] prompt = to_relationships(v.chunk)
for defn in defs: ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try: try:
output.append( defs = json.loads(ans)
Relationship( except:
s = defn["subject"], print("JSON parse error, ignored", flush=True)
p = defn["predicate"], defs = []
o = defn["object"],
o_entity = defn["object-entity"], output = []
for defn in defs:
try:
output.append(
Relationship(
s = defn["subject"],
p = defn["predicate"],
o = defn["object"],
o_entity = defn["object-entity"],
)
) )
)
except Exception as e: except Exception as e:
print("relationship fields missing, ignored", flush=True) print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True) print("Send response...", flush=True)
r = PromptResponse(relationships=output) r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
def handle_kg_prompt(self, id, v): def handle_kg_prompt(self, id, v):
prompt = to_kg_query(v.query, v.kg) try:
print(prompt) prompt = to_kg_query(v.query, v.kg)
ans = self.llm.request(prompt) print(prompt)
print(ans) ans = self.llm.request(prompt)
print("Send response...", flush=True) print(ans)
r = PromptResponse(answer=ans)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Send response...", flush=True)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -7,7 +7,7 @@ serverless endpoint service. Input is prompt, output is response.
import requests import requests
import json import json
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -89,6 +89,9 @@ class Processor(ConsumerProducer):
if resp.status_code == 429: if resp.status_code == 429:
raise TooManyRequests() raise TooManyRequests()
if resp.status_code != 200:
raise RuntimeError("LLM failure")
result = resp.json() result = resp.json()
message_content = result['choices'][0]['message']['content'] message_content = result['choices'][0]['message']['content']
@ -110,15 +113,49 @@ class Processor(ConsumerProducer):
v.prompt v.prompt
) )
response = self.call_llm(prompt) try:
print("Send response...", flush=True) response = self.call_llm(prompt)
resp = response.replace("```json", "") print("Send response...", flush=True)
resp = response.replace("```", "")
r = TextCompletionResponse(response=resp) resp = response.replace("```json", "")
self.producer.send(r, properties={"id": id}) resp = response.replace("```", "")
r = TextCompletionResponse(response=resp)
self.producer.send(r, properties={"id": id})
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
)
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
)
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
print("Done.", flush=True) print("Done.", flush=True)

View file

@ -7,7 +7,7 @@ Input is prompt, output is response. Mistral is default.
import boto3 import boto3
import json import json
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -130,40 +130,81 @@ class Processor(ConsumerProducer):
accept = 'application/json' accept = 'application/json'
contentType = 'application/json' contentType = 'application/json'
# FIXME: Consider catching request limits and raise TooManyRequests try:
# See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType)
# Mistral Response Structure
if self.model.startswith("mistral"):
response_body = json.loads(response.get("body").read())
outputtext = response_body['outputs'][0]['text']
# Claude Response Structure # FIXME: Consider catching request limits and raise TooManyRequests
elif self.model.startswith("anthropic"): # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
model_response = json.loads(response["body"].read()) response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType)
outputtext = model_response['content'][0]['text']
# Llama 3.1 Response Structure # Mistral Response Structure
elif self.model.startswith("meta"): if self.model.startswith("mistral"):
model_response = json.loads(response["body"].read()) response_body = json.loads(response.get("body").read())
outputtext = model_response["generation"] outputtext = response_body['outputs'][0]['text']
# Use Mistral as default # Claude Response Structure
else: elif self.model.startswith("anthropic"):
response_body = json.loads(response.get("body").read()) model_response = json.loads(response["body"].read())
outputtext = response_body['outputs'][0]['text'] outputtext = model_response['content'][0]['text']
print(outputtext, flush=True)
resp = outputtext.replace("```json", "") # Llama 3.1 Response Structure
resp = outputtext.replace("```", "") elif self.model.startswith("meta"):
model_response = json.loads(response["body"].read())
print("Send response...", flush=True) outputtext = model_response["generation"]
r = TextCompletionResponse(response=resp)
self.send(r, properties={"id": id})
print("Done.", flush=True) # Use Mistral as default
else:
response_body = json.loads(response.get("body").read())
outputtext = response_body['outputs'][0]['text']
print(outputtext, flush=True)
resp = outputtext.replace("```json", "")
resp = outputtext.replace("```", "")
print("Send response...", flush=True)
r = TextCompletionResponse(
error=None,
response=resp
)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what Bedrock throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
import anthropic import anthropic
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -65,33 +66,71 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
# FIXME: Rate limits? try:
response = message = self.claude.messages.create(
model=self.model,
max_tokens=self.max_output,
temperature=self.temperature,
system = "You are a helpful chatbot.",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
)
resp = response.content[0].text # FIXME: Rate limits?
print(resp, flush=True) response = message = self.claude.messages.create(
model=self.model,
max_tokens=self.max_output,
temperature=self.temperature,
system = "You are a helpful chatbot.",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
)
print("Send response...", flush=True) resp = response.content[0].text
r = TextCompletionResponse(response=resp) print(resp, flush=True)
self.send(r, properties={"id": id})
print("Done.", flush=True) print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
import cohere import cohere
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -61,28 +62,65 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
# FIXME: Deal with rate limits? try:
output = self.cohere.chat(
model=self.model,
message=prompt,
preamble = "You are a helpful AI-assistant.",
temperature=self.temperature,
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
resp = output.text output = self.cohere.chat(
print(resp, flush=True) model=self.model,
message=prompt,
preamble = "You are a helpful AI-assistant.",
temperature=self.temperature,
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
resp = resp.replace("```json", "") resp = output.text
resp = resp.replace("```", "") print(resp, flush=True)
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
self.send(r, properties={"id": id})
print("Done.", flush=True) resp = resp.replace("```json", "")
resp = resp.replace("```", "")
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -7,11 +7,12 @@ Input is prompt, output is response.
from langchain_community.llms import Ollama from langchain_community.llms import Ollama
from prometheus_client import Histogram, Info, Counter from prometheus_client import Histogram, Info, Counter
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -66,19 +67,56 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
# FIXME: Rate limits? try:
response = self.llm.invoke(prompt)
print("Send response...", flush=True) response = self.llm.invoke(prompt)
resp = response.replace("```json", "") print("Send response...", flush=True)
resp = response.replace("```", "")
r = TextCompletionResponse(response=resp) resp = response.replace("```json", "")
resp = response.replace("```", "")
self.send(r, properties={"id": id}) r = TextCompletionResponse(response=resp)
print("Done.", flush=True) self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,11 +6,12 @@ Input is prompt, output is response.
from openai import OpenAI from openai import OpenAI
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
@ -65,37 +66,75 @@ class Processor(ConsumerProducer):
prompt = v.prompt prompt = v.prompt
# FIXME: Rate limits try:
resp = self.openai.chat.completions.create(
model=self.model, # FIXME: Rate limits
messages=[ resp = self.openai.chat.completions.create(
{ model=self.model,
"role": "user", messages=[
"content": [ {
{ "role": "user",
"type": "text", "content": [
"text": prompt {
} "type": "text",
] "text": prompt
}
]
}
],
temperature=self.temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
} }
], )
temperature=self.temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
print(resp.choices[0].message.content, flush=True) print(resp.choices[0].message.content, flush=True)
print("Send response...", flush=True) print("Send response...", flush=True)
r = TextCompletionResponse(response=resp.choices[0].message.content) r = TextCompletionResponse(response=resp.choices[0].message.content)
self.send(r, properties={"id": id}) self.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -21,7 +21,7 @@ from vertexai.preview.generative_models import (
Tool, Tool,
) )
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... schema import text_completion_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -136,7 +136,12 @@ class Processor(ConsumerProducer):
resp = resp.replace("```", "") resp = resp.replace("```", "")
print("Send response...", flush=True) print("Send response...", flush=True)
r = TextCompletionResponse(response=resp)
r = TextCompletionResponse(
error=None,
response=resp,
)
self.producer.send(r, properties={"id": id}) self.producer.send(r, properties={"id": id})
print("Done.", flush=True) print("Done.", flush=True)
@ -144,12 +149,39 @@ class Processor(ConsumerProducer):
# Acknowledge successful processing of the message # Acknowledge successful processing of the message
self.consumer.acknowledge(msg) self.consumer.acknowledge(msg)
except google.api_core.exceptions.ResourceExhausted: except google.api_core.exceptions.ResourceExhausted as e:
# 429 / rate limits case print("Send rate limit response...", flush=True)
raise TooManyRequests
# Let other exceptions fall through r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -5,7 +5,8 @@ entities
""" """
from .... direct.milvus import TripleVectors from .... direct.milvus import TripleVectors
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, Value from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer from .... base import ConsumerProducer
@ -47,38 +48,58 @@ class Processor(ConsumerProducer):
def handle(self, msg):
    """Answer a graph-embeddings query with the entities it matches.

    Every vector in the request is searched in the vector store, the
    distinct entity names found are converted with ``self.create_value``
    and published as a ``GraphEmbeddingsResponse`` carrying the sender's
    request id.  If processing fails, an error response is published
    instead and the message is acknowledged.

    Args:
        msg: Consumed message.  Its value provides ``vectors`` and
            ``limit``; its properties provide the sender-produced "id".

    Raises:
        KeyError: If the message has no "id" property (there is then no
            way to address a response).
    """

    # Sender-produced ID.  Read before the try block so the error path
    # below always has an id to address its response with; previously a
    # failure in msg.value() left the name unbound inside the handler.
    req_id = msg.properties()["id"]

    try:

        v = msg.value()

        print(f"Handling input {req_id}...", flush=True)

        # A set de-duplicates entities hit by more than one vector.
        found = set()

        for vec in v.vectors:
            for hit in self.vecstore.search(vec, limit=v.limit):
                found.add(hit["entity"]["entity"])

        entities = [self.create_value(ent) for ent in found]

        print("Send response...", flush=True)
        r = GraphEmbeddingsResponse(entities=entities, error=None)
        self.producer.send(r, properties={"id": req_id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # The error type string is kept unchanged: clients may match on it.
        # The payload field of this record is `entities`, not `response`.
        r = GraphEmbeddingsResponse(
            error=Error(
                type="llm-error",
                message=str(e),
            ),
            entities=None,
        )

        self.producer.send(r, properties={"id": req_id})

        self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -5,7 +5,7 @@ null. Output is a list of triples.
""" """
from .... direct.cassandra import TrustGraph from .... direct.cassandra import TrustGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple from .... schema import Value, Triple
from .... schema import triples_request_queue from .... schema import triples_request_queue
from .... schema import triples_response_queue from .... schema import triples_response_queue
@ -48,90 +48,110 @@ class Processor(ConsumerProducer):
def handle(self, msg):
    """Answer a triples query using the ``self.tg`` triple store.

    The request may fix any combination of subject (``s``), predicate
    (``p``) and object (``o``); unset parts are ``None`` and act as
    wildcards.  The matching store accessor is called, the results are
    turned into ``Triple`` records and published as a
    ``TriplesQueryResponse`` carrying the sender's request id.  If
    processing fails, an error response is published instead and the
    message is acknowledged.

    Args:
        msg: Consumed message.  Its value provides ``s``, ``p``, ``o``
            and ``limit``; its properties provide the sender-produced
            "id".

    Raises:
        KeyError: If the message has no "id" property (there is then no
            way to address a response).
    """

    # Sender-produced ID.  Read before the try block so the error path
    # below always has an id to address its response with; previously a
    # failure in msg.value() left the name unbound inside the handler.
    req_id = msg.properties()["id"]

    try:

        v = msg.value()

        print(f"Handling input {req_id}...", flush=True)

        has_s = v.s is not None
        has_p = v.p is not None
        has_o = v.o is not None

        triples = []

        if has_s and has_p and has_o:
            # NOTE(review): the lookup result is not inspected, so the
            # requested triple is echoed back whether or not it is
            # stored.  Kept as-is; confirm this is intended.
            self.tg.get_spo(
                v.s.value, v.p.value, v.o.value,
                limit=v.limit
            )
            triples.append((v.s.value, v.p.value, v.o.value))

        elif has_s and has_p:
            resp = self.tg.get_sp(v.s.value, v.p.value, limit=v.limit)
            triples.extend((v.s.value, v.p.value, t.o) for t in resp)

        elif has_s and has_o:
            # Argument order (object first) follows the accessor name.
            resp = self.tg.get_os(v.o.value, v.s.value, limit=v.limit)
            triples.extend((v.s.value, t.p, v.o.value) for t in resp)

        elif has_s:
            resp = self.tg.get_s(v.s.value, limit=v.limit)
            triples.extend((v.s.value, t.p, t.o) for t in resp)

        elif has_p and has_o:
            resp = self.tg.get_po(v.p.value, v.o.value, limit=v.limit)
            triples.extend((t.s, v.p.value, v.o.value) for t in resp)

        elif has_p:
            resp = self.tg.get_p(v.p.value, limit=v.limit)
            triples.extend((t.s, v.p.value, t.o) for t in resp)

        elif has_o:
            resp = self.tg.get_o(v.o.value, limit=v.limit)
            triples.extend((t.s, t.p, v.o.value) for t in resp)

        else:
            resp = self.tg.get_all(limit=v.limit)
            triples.extend((t.s, t.p, t.o) for t in resp)

        triples = [
            Triple(
                s=self.create_value(s),
                p=self.create_value(p),
                o=self.create_value(o),
            )
            for s, p, o in triples
        ]

        print("Send response...", flush=True)
        r = TriplesQueryResponse(triples=triples, error=None)
        self.producer.send(r, properties={"id": req_id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # The error type string is kept unchanged: clients may match on it.
        # The payload field of this record is `triples`, not `response`.
        r = TriplesQueryResponse(
            error=Error(
                type="llm-error",
                message=str(e),
            ),
            triples=None,
        )

        self.producer.send(r, properties={"id": req_id})

        self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -6,7 +6,7 @@ null. Output is a list of triples.
from neo4j import GraphDatabase from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple from .... schema import Value, Triple
from .... schema import triples_request_queue from .... schema import triples_request_queue
from .... schema import triples_response_queue from .... schema import triples_response_queue
@ -57,245 +57,265 @@ class Processor(ConsumerProducer):
def handle(self, msg):
    """Answer a triples query by running Cypher against ``self.io``.

    The request may fix any combination of subject (``s``), predicate
    (``p``) and object (``o``); unset parts are ``None`` and act as
    wildcards.  For each combination two queries are run, because the
    object end of a relationship is matched either as a ``Literal``
    (by ``value``) or as a ``Node`` (by ``uri``).  The collected results
    are turned into ``Triple`` records and published as a
    ``TriplesQueryResponse`` carrying the sender's request id.  If
    processing fails, an error response is published instead and the
    message is acknowledged.

    NOTE(review): none of the queries applies ``v.limit``; result sizes
    are unbounded here.  Confirm whether a LIMIT clause is wanted.

    Args:
        msg: Consumed message.  Its value provides ``s``, ``p`` and
            ``o``; its properties provide the sender-produced "id".

    Raises:
        KeyError: If the message has no "id" property (there is then no
            way to address a response).
    """

    # Sender-produced ID.  Read before the try block so the error path
    # below always has an id to address its response with; previously a
    # failure in msg.value() left the name unbound inside the handler.
    req_id = msg.properties()["id"]

    def run(cypher, **params):
        # Execute one query on the configured database and return every
        # row as a plain dict keyed by the RETURN aliases.
        records, _summary, _keys = self.io.execute_query(
            cypher, database_=self.db, **params
        )
        return [rec.data() for rec in records]

    try:

        v = msg.value()

        print(f"Handling input {req_id}...", flush=True)

        has_s = v.s is not None
        has_p = v.p is not None
        has_o = v.o is not None

        triples = []

        if has_s and has_p and has_o:

            # SPO: every part is fixed, so each row returned simply
            # confirms the triple exists.
            spo = (v.s.value, v.p.value, v.o.value)

            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
                "RETURN $src as src",
                src=v.s.value, rel=v.p.value, value=v.o.value,
            )
            triples.extend(spo for _ in rows)

            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
                "RETURN $src as src",
                src=v.s.value, rel=v.p.value, uri=v.o.value,
            )
            triples.extend(spo for _ in rows)

        elif has_s and has_p:

            # SP
            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
                "RETURN dest.value as dest",
                src=v.s.value, rel=v.p.value,
            )
            triples.extend(
                (v.s.value, v.p.value, row["dest"]) for row in rows
            )

            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
                "RETURN dest.uri as dest",
                src=v.s.value, rel=v.p.value,
            )
            triples.extend(
                (v.s.value, v.p.value, row["dest"]) for row in rows
            )

        elif has_s and has_o:

            # SO
            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
                "RETURN rel.uri as rel",
                src=v.s.value, value=v.o.value,
            )
            triples.extend(
                (v.s.value, row["rel"], v.o.value) for row in rows
            )

            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
                "RETURN rel.uri as rel",
                src=v.s.value, uri=v.o.value,
            )
            triples.extend(
                (v.s.value, row["rel"], v.o.value) for row in rows
            )

        elif has_s:

            # S
            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
                "RETURN rel.uri as rel, dest.value as dest",
                src=v.s.value,
            )
            triples.extend(
                (v.s.value, row["rel"], row["dest"]) for row in rows
            )

            rows = run(
                "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
                "RETURN rel.uri as rel, dest.uri as dest",
                src=v.s.value,
            )
            triples.extend(
                (v.s.value, row["rel"], row["dest"]) for row in rows
            )

        elif has_p and has_o:

            # PO
            rows = run(
                "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
                "RETURN src.uri as src",
                uri=v.p.value, value=v.o.value,
            )
            triples.extend(
                (row["src"], v.p.value, v.o.value) for row in rows
            )

            # The object node must be matched on its own parameter.  The
            # earlier text reused $uri (the predicate) for the node and
            # passed the object as an unreferenced `dest` parameter.
            rows = run(
                "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
                "RETURN src.uri as src",
                uri=v.p.value, dest=v.o.value,
            )
            triples.extend(
                (row["src"], v.p.value, v.o.value) for row in rows
            )

        elif has_p:

            # P
            rows = run(
                "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
                "RETURN src.uri as src, dest.value as dest",
                uri=v.p.value,
            )
            triples.extend(
                (row["src"], v.p.value, row["dest"]) for row in rows
            )

            rows = run(
                "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
                "RETURN src.uri as src, dest.uri as dest",
                uri=v.p.value,
            )
            triples.extend(
                (row["src"], v.p.value, row["dest"]) for row in rows
            )

        elif has_o:

            # O
            rows = run(
                "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
                "RETURN src.uri as src, rel.uri as rel",
                value=v.o.value,
            )
            triples.extend(
                (row["src"], row["rel"], v.o.value) for row in rows
            )

            rows = run(
                "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
                "RETURN src.uri as src, rel.uri as rel",
                uri=v.o.value,
            )
            triples.extend(
                (row["src"], row["rel"], v.o.value) for row in rows
            )

        else:

            # *
            rows = run(
                "MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
                "RETURN src.uri as src, rel.uri as rel, dest.value as dest",
            )
            triples.extend(
                (row["src"], row["rel"], row["dest"]) for row in rows
            )

            rows = run(
                "MATCH (src:Node)-[rel:Rel]->(dest:Node) "
                "RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
            )
            triples.extend(
                (row["src"], row["rel"], row["dest"]) for row in rows
            )

        triples = [
            Triple(
                s=self.create_value(s),
                p=self.create_value(p),
                o=self.create_value(o),
            )
            for s, p, o in triples
        ]

        print("Send response...", flush=True)
        # error=None is passed explicitly, matching the other query
        # services' success responses.
        r = TriplesQueryResponse(triples=triples, error=None)
        self.producer.send(r, properties={"id": req_id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # The error type string is kept unchanged: clients may match on it.
        # The payload field of this record is `triples`, not `response`.
        r = TriplesQueryResponse(
            error=Error(
                type="llm-error",
                message=str(e),
            ),
            triples=None,
        )

        self.producer.send(r, properties={"id": req_id})

        self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -4,7 +4,7 @@ Simple RAG service, performs query using graph RAG an LLM.
Input is query, output is response. Input is query, output is response.
""" """
from ... schema import GraphRagQuery, GraphRagResponse from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import graph_rag_request_queue, graph_rag_response_queue from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import prompt_request_queue from ... schema import prompt_request_queue
from ... schema import prompt_response_queue from ... schema import prompt_response_queue
@ -99,21 +99,40 @@ class Processor(ConsumerProducer):
def handle(self, msg):
    """Run a graph-RAG query and publish the answer.

    The request's ``query`` text is passed to ``self.rag.query`` and the
    result is published as a ``GraphRagResponse`` carrying the sender's
    request id.  If processing fails, an error response is published
    instead and the message is acknowledged.

    Args:
        msg: Consumed message.  Its value provides ``query``; its
            properties provide the sender-produced "id".

    Raises:
        KeyError: If the message has no "id" property (there is then no
            way to address a response).
    """

    # Sender-produced ID.  Read before the try block so the error path
    # below always has an id to address its response with; previously a
    # failure in msg.value() left the name unbound inside the handler.
    req_id = msg.properties()["id"]

    try:

        v = msg.value()

        print(f"Handling input {req_id}...", flush=True)

        answer = self.rag.query(v.query)

        print("Send response...", flush=True)
        r = GraphRagResponse(response=answer, error=None)
        self.producer.send(r, properties={"id": req_id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # The error type string is kept unchanged: clients may match on it.
        r = GraphRagResponse(
            error=Error(
                type="llm-error",
                message=str(e),
            ),
            response=None,
        )

        self.producer.send(r, properties={"id": req_id})

        self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -8,6 +8,12 @@ def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
############################################################################ ############################################################################
class Error(Record):
    # Error detail embedded in the response records so a service can report
    # a failure to the requester; responses set it to None on success.

    # Short machine-readable category; the services in this change emit
    # "rate-limit" and "llm-error".
    type = String()

    # Human-readable description; the services fill it with str(exception).
    message = String()
############################################################################
class Value(Record): class Value(Record):
value = String() value = String()
is_uri = Boolean() is_uri = Boolean()
@ -78,6 +84,7 @@ class GraphEmbeddingsRequest(Record):
limit = Integer() limit = Integer()
class GraphEmbeddingsResponse(Record): class GraphEmbeddingsResponse(Record):
error = Error()
entities = Array(Value()) entities = Array(Value())
graph_embeddings_request_queue = topic( graph_embeddings_request_queue = topic(
@ -110,6 +117,7 @@ class TriplesQueryRequest(Record):
limit = Integer() limit = Integer()
class TriplesQueryResponse(Record): class TriplesQueryResponse(Record):
error = Error()
triples = Array(Triple()) triples = Array(Triple())
triples_request_queue = topic( triples_request_queue = topic(
@ -131,6 +139,7 @@ class TextCompletionRequest(Record):
prompt = String() prompt = String()
class TextCompletionResponse(Record): class TextCompletionResponse(Record):
error = Error()
response = String() response = String()
text_completion_request_queue = topic( text_completion_request_queue = topic(
@ -148,6 +157,7 @@ class EmbeddingsRequest(Record):
text = String() text = String()
class EmbeddingsResponse(Record): class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double())) vectors = Array(Array(Double()))
embeddings_request_queue = topic( embeddings_request_queue = topic(
@ -165,6 +175,7 @@ class GraphRagQuery(Record):
query = String() query = String()
class GraphRagResponse(Record): class GraphRagResponse(Record):
error = Error()
response = String() response = String()
graph_rag_request_queue = topic( graph_rag_request_queue = topic(
@ -207,6 +218,7 @@ class PromptRequest(Record):
kg = Array(Fact()) kg = Array(Fact())
class PromptResponse(Record): class PromptResponse(Record):
error = Error()
answer = String() answer = String()
definitions = Array(Definition()) definitions = Array(Definition())
relationships = Array(Relationship()) relationships = Array(Relationship())