diff --git a/trustgraph/clients/embeddings_client.py b/trustgraph/clients/embeddings_client.py index 6724acc2..9b90b758 100644 --- a/trustgraph/clients/embeddings_client.py +++ b/trustgraph/clients/embeddings_client.py @@ -1,9 +1,11 @@ -import pulsar -import _pulsar from pulsar.schema import JsonSchema from .. schema import EmbeddingsRequest, EmbeddingsResponse from .. schema import embeddings_request_queue, embeddings_response_queue +from . base import BaseClient + +import pulsar +import _pulsar import hashlib import uuid import time @@ -14,7 +16,7 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug -class EmbeddingsClient: +class EmbeddingsClient(BaseClient): def __init__( self, log_level=ERROR, @@ -24,72 +26,23 @@ class EmbeddingsClient: pulsar_host="pulsar://pulsar:6650", ): - self.client = None - if input_queue == None: input_queue=embeddings_request_queue if output_queue == None: output_queue=embeddings_response_queue - if subscriber == None: - subscriber = str(uuid.uuid4()) - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), + super(EmbeddingsClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=EmbeddingsRequest, + output_schema=EmbeddingsResponse, ) - self.producer = self.client.create_producer( - topic=input_queue, - schema=JsonSchema(EmbeddingsRequest), - chunking_enabled=True, - ) + def request(self, text, timeout=30): + return self.call(text=text, timeout=timeout).vectors - self.consumer = self.client.subscribe( - output_queue, subscriber, - schema=JsonSchema(EmbeddingsResponse), - ) - - def request(self, text, timeout=10): - - id = str(uuid.uuid4()) - - r = EmbeddingsRequest( - text=text - ) - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().vectors - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") - - def __del__(self): - - if hasattr(self, "consumer"): -# self.consumer.unsubscribe() - self.consumer.close() - - if hasattr(self, "producer"): - self.producer.flush() - self.producer.close() - - self.client.close() diff --git a/trustgraph/clients/graph_embeddings_client.py b/trustgraph/clients/graph_embeddings_client.py index 10e2dee3..3f361ecd 100644 --- a/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph/clients/graph_embeddings_client.py @@ -9,6 +9,7 @@ import time from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse from .. schema import graph_embeddings_request_queue from .. schema import graph_embeddings_response_queue +from . base import BaseClient # Ugly ERROR=_pulsar.LoggerLevel.Error @@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug -class GraphEmbeddingsClient: +class GraphEmbeddingsClient(BaseClient): def __init__( self, log_level=ERROR, @@ -31,66 +32,19 @@ class GraphEmbeddingsClient: if output_queue == None: output_queue = graph_embeddings_response_queue - - if subscriber == None: - subscriber = str(uuid.uuid4()) - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - ) - - self.producer = self.client.create_producer( - topic=input_queue, - schema=JsonSchema(GraphEmbeddingsRequest), - chunking_enabled=True, - ) - - self.consumer = self.client.subscribe( - output_queue, subscriber, - schema=JsonSchema(GraphEmbeddingsResponse), - ) - - def request(self, vectors, limit=10, timeout=500): - - id = str(uuid.uuid4()) - - r = GraphEmbeddingsRequest( - vectors=vectors, - limit=limit, - ) - - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().entities - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") - - def __del__(self): - - if hasattr(self, "consumer"): - self.consumer.close() - if hasattr(self, "producer"): - self.producer.flush() - self.producer.close() - - self.client.close() + super(GraphEmbeddingsClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=GraphEmbeddingsRequest, + output_schema=GraphEmbeddingsResponse, + ) + + def request(self, vectors, limit=10, timeout=30): + return self.call( + vectors=vectors, limit=limit, timeout=timeout + ).entities diff --git a/trustgraph/clients/graph_rag_client.py b/trustgraph/clients/graph_rag_client.py index 9ecf1695..60f8d47c 100644 --- a/trustgraph/clients/graph_rag_client.py +++ b/trustgraph/clients/graph_rag_client.py @@ -4,6 +4,7 @@ import _pulsar from pulsar.schema import JsonSchema from .. schema import GraphRagQuery, GraphRagResponse from .. schema import graph_rag_request_queue, graph_rag_response_queue +from . base import BaseClient import hashlib import uuid @@ -15,71 +16,36 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug -class GraphRagClient: +class GraphRagClient(BaseClient): def __init__( - self, log_level=ERROR, subscriber=None, + self, + log_level=ERROR, + subscriber=None, + input_queue=None, + output_queue=None, pulsar_host="pulsar://pulsar:6650", ): - if subscriber == None: - subscriber = str(uuid.uuid4()) + if input_queue == None: + input_queue = graph_rag_request_queue - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - ) - - self.producer = self.client.create_producer( - topic=graph_rag_request_queue, - schema=JsonSchema(GraphRagQuery), - chunking_enabled=True, - ) - - self.consumer = self.client.subscribe( - graph_rag_response_queue, subscriber, - schema=JsonSchema(GraphRagResponse), + if output_queue == None: + output_queue = graph_rag_response_queue + + super(GraphRagClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=GraphRagQuery, + output_schema=GraphRagResponse, ) def request(self, query, timeout=500): - id = str(uuid.uuid4()) - - r = GraphRagQuery( - query=query - ) - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().response - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") - - def __del__(self): - - if hasattr(self, "consumer"): -# self.consumer.unsubscribe() - self.consumer.close() - - if hasattr(self, "producer"): - self.producer.flush() - self.producer.close() - - self.client.close() + return self.call( + query=query, timeout=timeout + ).response diff --git a/trustgraph/clients/llm_client.py b/trustgraph/clients/llm_client.py index 23af764f..064b286d 100644 --- a/trustgraph/clients/llm_client.py +++ b/trustgraph/clients/llm_client.py @@ -9,6 +9,8 @@ import time from .. schema import TextCompletionRequest, TextCompletionResponse from .. schema import text_completion_request_queue from .. schema import text_completion_response_queue +from .. exceptions import * +from . base import BaseClient # Ugly ERROR=_pulsar.LoggerLevel.Error @@ -16,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug -class LlmClient: +class LlmClient(BaseClient): def __init__( self, log_level=ERROR, @@ -26,71 +28,19 @@ class LlmClient: pulsar_host="pulsar://pulsar:6650", ): - if input_queue == None: - input_queue = text_completion_request_queue + if input_queue is None: input_queue = text_completion_request_queue + if output_queue is None: output_queue = text_completion_response_queue - if output_queue == None: - output_queue = text_completion_response_queue - - if subscriber == None: - subscriber = str(uuid.uuid4()) - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - ) - - self.producer = self.client.create_producer( - topic=input_queue, - schema=JsonSchema(TextCompletionRequest), - chunking_enabled=True, - ) - - self.consumer = self.client.subscribe( - output_queue, subscriber, - schema=JsonSchema(TextCompletionResponse), + super(LlmClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=TextCompletionRequest, + output_schema=TextCompletionResponse, ) def request(self, prompt, timeout=30): - - id = str(uuid.uuid4()) - - r = TextCompletionRequest( - prompt=prompt - ) - - end_time = time.time() + timeout - - self.producer.send(r, properties={ "id": id }) - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().response - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") - - def __del__(self): - - if hasattr(self, "consumer"): -# self.consumer.unsubscribe() - self.consumer.close() - - if hasattr(self, "producer"): - self.producer.flush() - self.producer.close() - - self.client.close() + return self.call(prompt=prompt, timeout=timeout).response diff --git a/trustgraph/clients/prompt_client.py b/trustgraph/clients/prompt_client.py index 3b5ad6e6..6bec8839 100644 --- a/trustgraph/clients/prompt_client.py +++ b/trustgraph/clients/prompt_client.py @@ -9,6 +9,7 @@ import time from .. schema import PromptRequest, PromptResponse, Fact from .. schema import prompt_request_queue from .. schema import prompt_response_queue +from . base import BaseClient # Ugly ERROR=_pulsar.LoggerLevel.Error @@ -16,7 +17,7 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug -class PromptClient: +class PromptClient(BaseClient): def __init__( self, log_level=ERROR, @@ -32,133 +33,35 @@ class PromptClient: if output_queue == None: output_queue = prompt_response_queue - if subscriber == None: - subscriber = str(uuid.uuid4()) - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - ) - - self.producer = self.client.create_producer( - topic=input_queue, - schema=JsonSchema(PromptRequest), - chunking_enabled=True, - ) - - self.consumer = self.client.subscribe( - output_queue, subscriber, - schema=JsonSchema(PromptResponse), + super(PromptClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=PromptRequest, + output_schema=PromptResponse, ) def request_definitions(self, chunk, timeout=30): - id = str(uuid.uuid4()) - - r = PromptRequest( - kind="extract-definitions", - chunk=chunk, - ) - - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().definitions - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") + return self.call(kind="extract-definitions", chunk=chunk, + timeout=timeout).definitions def request_relationships(self, chunk, timeout=30): - id = str(uuid.uuid4()) - - r = PromptRequest( - kind="extract-relationships", - chunk=chunk, - ) - - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().relationships - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") + return self.call(kind="extract-relationships", chunk=chunk, + timeout=timeout).relationships def request_kg_prompt(self, query, kg, timeout=30): - id = str(uuid.uuid4()) - - r = PromptRequest( + return self.call( kind="kg-prompt", query=query, kg=[ Fact(s=v[0], p=v[1], o=v[2]) for v in kg ], - ) - - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().answer - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") - - def __del__(self): - - if hasattr(self, "consumer"): - self.consumer.close() - - if hasattr(self, "producer"): - self.producer.flush() - self.producer.close() - - self.client.close() + timeout=timeout + ).answer diff --git a/trustgraph/clients/triples_query_client.py b/trustgraph/clients/triples_query_client.py index ab5f788c..a3246189 100644 --- a/trustgraph/clients/triples_query_client.py +++ b/trustgraph/clients/triples_query_client.py @@ -10,6 +10,7 @@ import time from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value from .. schema import triples_request_queue from .. schema import triples_response_queue +from . base import BaseClient # Ugly ERROR=_pulsar.LoggerLevel.Error @@ -17,7 +18,7 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug -class TriplesQueryClient: +class TriplesQueryClient(BaseClient): def __init__( self, log_level=ERROR, @@ -33,23 +34,14 @@ class TriplesQueryClient: if output_queue == None: output_queue = triples_response_queue - if subscriber == None: - subscriber = str(uuid.uuid4()) - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - ) - - self.producer = self.client.create_producer( - topic=input_queue, - schema=JsonSchema(TriplesQueryRequest), - chunking_enabled=True, - ) - - self.consumer = self.client.subscribe( - output_queue, subscriber, - schema=JsonSchema(TriplesQueryResponse), + super(TriplesQueryClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=TriplesQueryRequest, + output_schema=TriplesQueryResponse, ) def create_value(self, ent): @@ -61,48 +53,12 @@ class TriplesQueryClient: return Value(value=ent, is_uri=False) - def request(self, s, p, o, limit=10, timeout=500): - - id = str(uuid.uuid4()) - - r = TriplesQueryRequest( + def request(self, s, p, o, limit=10, timeout=30): + return self.call( s=self.create_value(s), p=self.create_value(p), o=self.create_value(o), limit=limit, - ) - - self.producer.send(r, properties={ "id": id }) - - end_time = time.time() + timeout - - while time.time() < end_time: - - try: - msg = self.consumer.receive(timeout_millis=5000) - except pulsar.exceptions.Timeout: - continue - - mid = msg.properties()["id"] - - if mid == id: - resp = msg.value().triples - self.consumer.acknowledge(msg) - return resp - - # Ignore messages with wrong ID - self.consumer.acknowledge(msg) - - raise TimeoutError("Timed out waiting for response") - - def __del__(self): - - if hasattr(self, "consumer"): - self.consumer.close() - - if hasattr(self, "producer"): - self.producer.flush() - self.producer.close() - - self.client.close() + timeout=timeout, + ).triples diff --git a/trustgraph/embeddings/hf/hf.py b/trustgraph/embeddings/hf/hf.py index 20e57d53..8b9dcaab 100755 --- a/trustgraph/embeddings/hf/hf.py +++ b/trustgraph/embeddings/hf/hf.py @@ -6,7 +6,7 @@ Input is text, output is embeddings vector. from langchain_huggingface import HuggingFaceEmbeddings -from ... schema import EmbeddingsRequest, EmbeddingsResponse +from ... schema import EmbeddingsRequest, EmbeddingsResponse, Error from ... schema import embeddings_request_queue, embeddings_response_queue from ... log_level import LogLevel from ... base import ConsumerProducer @@ -48,14 +48,36 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - text = v.text - embeds = self.embeddings.embed_documents([text]) + try: - print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=embeds) - self.producer.send(r, properties={"id": id}) + text = v.text + embeds = self.embeddings.embed_documents([text]) - print("Done.", flush=True) + print("Send response...", flush=True) + r = EmbeddingsResponse(vectors=embeds, error=None) + self.producer.send(r, properties={"id": id}) + + print("Done.", flush=True) + + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = EmbeddingsResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + @staticmethod def add_args(parser): diff --git a/trustgraph/exceptions.py b/trustgraph/exceptions.py index e5647e3e..16f9956c 100644 --- a/trustgraph/exceptions.py +++ b/trustgraph/exceptions.py @@ -2,3 +2,13 @@ class TooManyRequests(Exception): pass +class LlmError(Exception): + pass + +class ParseError(Exception): + pass + + + + + diff --git a/trustgraph/model/prompt/generic/service.py b/trustgraph/model/prompt/generic/service.py index 90b28af2..c005c296 100755 --- a/trustgraph/model/prompt/generic/service.py +++ b/trustgraph/model/prompt/generic/service.py @@ -6,7 +6,7 @@ Language service abstracts prompt engineering from LLM. import json from .... schema import Definition, Relationship, Triple -from .... schema import PromptRequest, PromptResponse +from .... schema import PromptRequest, PromptResponse, Error from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue @@ -89,91 +89,151 @@ class Processor(ConsumerProducer): def handle_extract_definitions(self, id, v): - prompt = to_definitions(v.chunk) - - ans = self.llm.request(prompt) - - # Silently ignore JSON parse error try: - defs = json.loads(ans) - except: - print("JSON parse error, ignored", flush=True) - defs = [] - output = [] + prompt = to_definitions(v.chunk) - for defn in defs: + ans = self.llm.request(prompt) + # Silently ignore JSON parse error try: - e = defn["entity"] - d = defn["definition"] - - output.append( - Definition( - name=e, definition=d - ) - ) - + defs = json.loads(ans) except: - print("definition fields missing, ignored", flush=True) + print("JSON parse error, ignored", flush=True) + defs = [] - print("Send response...", flush=True) - r = PromptResponse(definitions=output) - self.producer.send(r, properties={"id": id}) + output = [] - print("Done.", flush=True) + for defn in defs: + + try: + e = defn["entity"] + d = defn["definition"] + + output.append( + Definition( + name=e, definition=d + ) + ) + + except: + print("definition fields missing, ignored", flush=True) + + print("Send response...", flush=True) + r = PromptResponse(definitions=output, error=None) + self.producer.send(r, properties={"id": id}) + + print("Done.", flush=True) + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = PromptResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + def handle_extract_relationships(self, id, v): - prompt = to_relationships(v.chunk) - - ans = self.llm.request(prompt) - - # Silently ignore JSON parse error try: - defs = json.loads(ans) - except: - print("JSON parse error, ignored", flush=True) - defs = [] - output = [] + prompt = to_relationships(v.chunk) - for defn in defs: + ans = self.llm.request(prompt) + # Silently ignore JSON parse error try: - output.append( - Relationship( - s = defn["subject"], - p = defn["predicate"], - o = defn["object"], - o_entity = defn["object-entity"], + defs = json.loads(ans) + except: + print("JSON parse error, ignored", flush=True) + defs = [] + + output = [] + + for defn in defs: + + try: + output.append( + Relationship( + s = defn["subject"], + p = defn["predicate"], + o = defn["object"], + o_entity = defn["object-entity"], + ) ) - ) - except Exception as e: - print("relationship fields missing, ignored", flush=True) + except Exception as e: + print("relationship fields missing, ignored", flush=True) - print("Send response...", flush=True) - r = PromptResponse(relationships=output) - self.producer.send(r, properties={"id": id}) + print("Send response...", flush=True) + r = PromptResponse(relationships=output, error=None) + self.producer.send(r, properties={"id": id}) - print("Done.", flush=True) + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = PromptResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) def handle_kg_prompt(self, id, v): - prompt = to_kg_query(v.query, v.kg) + try: - print(prompt) + prompt = to_kg_query(v.query, v.kg) - ans = self.llm.request(prompt) + print(prompt) - print(ans) + ans = self.llm.request(prompt) - print("Send response...", flush=True) - r = PromptResponse(answer=ans) - self.producer.send(r, properties={"id": id}) + print(ans) - print("Done.", flush=True) + print("Send response...", flush=True) + r = PromptResponse(answer=ans, error=None) + self.producer.send(r, properties={"id": id}) + + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = PromptResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/model/text_completion/azure/llm.py b/trustgraph/model/text_completion/azure/llm.py index 38484820..b0ee1592 100755 --- a/trustgraph/model/text_completion/azure/llm.py +++ b/trustgraph/model/text_completion/azure/llm.py @@ -7,7 +7,7 @@ serverless endpoint service. Input is prompt, output is response. import requests import json -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel @@ -89,6 +89,9 @@ class Processor(ConsumerProducer): if resp.status_code == 429: raise TooManyRequests() + if resp.status_code != 200: + raise RuntimeError("LLM failure") + result = resp.json() message_content = result['choices'][0]['message']['content'] @@ -110,15 +113,49 @@ class Processor(ConsumerProducer): v.prompt ) - response = self.call_llm(prompt) + try: - print("Send response...", flush=True) + response = self.call_llm(prompt) - resp = response.replace("```json", "") - resp = response.replace("```", "") + print("Send response...", flush=True) - r = TextCompletionResponse(response=resp) - self.producer.send(r, properties={"id": id}) + resp = response.replace("```json", "") + resp = response.replace("```", "") + + r = TextCompletionResponse(response=resp) + self.producer.send(r, properties={"id": id}) + + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ) + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ) + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) print("Done.", flush=True) diff --git a/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph/model/text_completion/bedrock/llm.py index cd504a16..4da0ef14 100755 --- a/trustgraph/model/text_completion/bedrock/llm.py +++ b/trustgraph/model/text_completion/bedrock/llm.py @@ -7,7 +7,7 @@ Input is prompt, output is response. Mistral is default. import boto3 import json -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel @@ -130,40 +130,81 @@ class Processor(ConsumerProducer): accept = 'application/json' contentType = 'application/json' - # FIXME: Consider catching request limits and raise TooManyRequests - # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html - response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType) - - # Mistral Response Structure - if self.model.startswith("mistral"): - response_body = json.loads(response.get("body").read()) - outputtext = response_body['outputs'][0]['text'] + try: - # Claude Response Structure - elif self.model.startswith("anthropic"): - model_response = json.loads(response["body"].read()) - outputtext = model_response['content'][0]['text'] + # FIXME: Consider catching request limits and raise TooManyRequests + # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html + response = self.bedrock.invoke_model(body=promptbody, modelId=self.model, accept=accept, contentType=contentType) - # Llama 3.1 Response Structure - elif self.model.startswith("meta"): - model_response = json.loads(response["body"].read()) - outputtext = model_response["generation"] + # Mistral Response Structure + if self.model.startswith("mistral"): + response_body = json.loads(response.get("body").read()) + outputtext = response_body['outputs'][0]['text'] - # Use Mistral as default - else: - response_body = json.loads(response.get("body").read()) - outputtext = response_body['outputs'][0]['text'] - - print(outputtext, flush=True) + # Claude Response Structure + elif self.model.startswith("anthropic"): + model_response = json.loads(response["body"].read()) + outputtext = model_response['content'][0]['text'] - resp = outputtext.replace("```json", "") - resp = outputtext.replace("```", "") - - print("Send response...", flush=True) - r = TextCompletionResponse(response=resp) - self.send(r, properties={"id": id}) + # Llama 3.1 Response Structure + elif self.model.startswith("meta"): + model_response = json.loads(response["body"].read()) + outputtext = model_response["generation"] - print("Done.", flush=True) + # Use Mistral as default + else: + response_body = json.loads(response.get("body").read()) + outputtext = response_body['outputs'][0]['text'] + + print(outputtext, flush=True) + + resp = outputtext.replace("```json", "") + resp = outputtext.replace("```", "") + + print("Send response...", flush=True) + r = TextCompletionResponse( + error=None, + response=resp + ) + + self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + + # FIXME: Wrong exception, don't know what Bedrock throws + # for a rate limit + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/model/text_completion/claude/llm.py b/trustgraph/model/text_completion/claude/llm.py index 89c8607d..141a4050 100755 --- a/trustgraph/model/text_completion/claude/llm.py +++ b/trustgraph/model/text_completion/claude/llm.py @@ -6,11 +6,12 @@ Input is prompt, output is response. import anthropic -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel from .... base import ConsumerProducer +from .... exceptions import TooManyRequests module = ".".join(__name__.split(".")[1:-1]) @@ -65,33 +66,71 @@ class Processor(ConsumerProducer): prompt = v.prompt - # FIXME: Rate limits? - response = message = self.claude.messages.create( - model=self.model, - max_tokens=self.max_output, - temperature=self.temperature, - system = "You are a helpful chatbot.", - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ] - ) + try: - resp = response.content[0].text - print(resp, flush=True) + # FIXME: Rate limits? + response = message = self.claude.messages.create( + model=self.model, + max_tokens=self.max_output, + temperature=self.temperature, + system = "You are a helpful chatbot.", + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] + ) - print("Send response...", flush=True) - r = TextCompletionResponse(response=resp) - self.send(r, properties={"id": id}) + resp = response.content[0].text + print(resp, flush=True) - print("Done.", flush=True) + print("Send response...", flush=True) + r = TextCompletionResponse(response=resp) + self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # FIXME: Wrong exception, don't know what this LLM throws + # for a rate limit + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/model/text_completion/cohere/llm.py b/trustgraph/model/text_completion/cohere/llm.py index 018a272b..42b86ec2 100755 --- a/trustgraph/model/text_completion/cohere/llm.py +++ b/trustgraph/model/text_completion/cohere/llm.py @@ -6,11 +6,12 @@ Input is prompt, output is response. import cohere -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel from .... base import ConsumerProducer +from .... exceptions import TooManyRequests module = ".".join(__name__.split(".")[1:-1]) @@ -61,28 +62,65 @@ class Processor(ConsumerProducer): prompt = v.prompt - # FIXME: Deal with rate limits? - output = self.cohere.chat( - model=self.model, - message=prompt, - preamble = "You are a helpful AI-assistant.", - temperature=self.temperature, - chat_history=[], - prompt_truncation='auto', - connectors=[] - ) + try: - resp = output.text - print(resp, flush=True) + output = self.cohere.chat( + model=self.model, + message=prompt, + preamble = "You are a helpful AI-assistant.", + temperature=self.temperature, + chat_history=[], + prompt_truncation='auto', + connectors=[] + ) - resp = resp.replace("```json", "") - resp = resp.replace("```", "") - - print("Send response...", flush=True) - r = TextCompletionResponse(response=resp) - self.send(r, properties={"id": id}) + resp = output.text + print(resp, flush=True) - print("Done.", flush=True) + resp = resp.replace("```json", "") + resp = resp.replace("```", "") + + print("Send response...", flush=True) + r = TextCompletionResponse(response=resp) + self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # FIXME: Wrong exception, don't know what this LLM throws + # for a rate limit + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/model/text_completion/ollama/llm.py b/trustgraph/model/text_completion/ollama/llm.py index a5b3d873..78a6a1af 100755 --- a/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph/model/text_completion/ollama/llm.py @@ -7,11 +7,12 @@ Input is prompt, output is response. from langchain_community.llms import Ollama from prometheus_client import Histogram, Info, Counter -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel from .... base import ConsumerProducer +from .... exceptions import TooManyRequests module = ".".join(__name__.split(".")[1:-1]) @@ -66,19 +67,56 @@ class Processor(ConsumerProducer): prompt = v.prompt - # FIXME: Rate limits? - response = self.llm.invoke(prompt) + try: - print("Send response...", flush=True) + response = self.llm.invoke(prompt) - resp = response.replace("```json", "") - resp = response.replace("```", "") + print("Send response...", flush=True) - r = TextCompletionResponse(response=resp) + resp = response.replace("```json", "") + resp = response.replace("```", "") - self.send(r, properties={"id": id}) + r = TextCompletionResponse(response=resp) - print("Done.", flush=True) + self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # FIXME: Wrong exception, don't know what this LLM throws + # for a rate limit + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/model/text_completion/openai/llm.py b/trustgraph/model/text_completion/openai/llm.py index aba92e96..129bdb47 100755 --- a/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph/model/text_completion/openai/llm.py @@ -6,11 +6,12 @@ Input is prompt, output is response. from openai import OpenAI -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel from .... base import ConsumerProducer +from .... exceptions import TooManyRequests module = ".".join(__name__.split(".")[1:-1]) @@ -65,37 +66,75 @@ class Processor(ConsumerProducer): prompt = v.prompt - # FIXME: Rate limits - resp = self.openai.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] + try: + + # FIXME: Rate limits + resp = self.openai.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ], + temperature=self.temperature, + max_tokens=self.max_output, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + response_format={ + "type": "text" } - ], - temperature=self.temperature, - max_tokens=self.max_output, - top_p=1, - frequency_penalty=0, - presence_penalty=0, - response_format={ - "type": "text" - } - ) + ) - print(resp.choices[0].message.content, flush=True) + print(resp.choices[0].message.content, flush=True) - print("Send response...", flush=True) - r = TextCompletionResponse(response=resp.choices[0].message.content) - self.send(r, properties={"id": id}) + print("Send response...", flush=True) + r = TextCompletionResponse(response=resp.choices[0].message.content) + self.send(r, properties={"id": id}) - print("Done.", flush=True) + print("Done.", flush=True) + + # FIXME: Wrong exception, don't know what this LLM throws + # for a rate limit + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph/model/text_completion/vertexai/llm.py index f9b2de22..da870eaf 100755 --- a/trustgraph/model/text_completion/vertexai/llm.py +++ b/trustgraph/model/text_completion/vertexai/llm.py @@ -21,7 +21,7 @@ from vertexai.preview.generative_models import ( Tool, ) -from .... schema import TextCompletionRequest, TextCompletionResponse +from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue from .... schema import text_completion_response_queue from .... log_level import LogLevel @@ -136,7 +136,12 @@ class Processor(ConsumerProducer): resp = resp.replace("```", "") print("Send response...", flush=True) - r = TextCompletionResponse(response=resp) + + r = TextCompletionResponse( + error=None, + response=resp, + ) + self.producer.send(r, properties={"id": id}) print("Done.", flush=True) @@ -144,12 +149,39 @@ class Processor(ConsumerProducer): # Acknowledge successful processing of the message self.consumer.acknowledge(msg) - except google.api_core.exceptions.ResourceExhausted: + except google.api_core.exceptions.ResourceExhausted as e: - # 429 / rate limits case - raise TooManyRequests + print("Send rate limit response...", flush=True) - # Let other exceptions fall through + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph/query/graph_embeddings/milvus/service.py index 74e0bee8..bb2eaa0d 100755 --- a/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph/query/graph_embeddings/milvus/service.py @@ -5,7 +5,8 @@ entities """ from .... direct.milvus import TripleVectors -from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, Value +from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .... schema import Error, Value from .... schema import graph_embeddings_request_queue from .... schema import graph_embeddings_response_queue from .... base import ConsumerProducer @@ -47,38 +48,58 @@ class Processor(ConsumerProducer): def handle(self, msg): - v = msg.value() + try: - # Sender-produced ID - id = msg.properties()["id"] + v = msg.value() - print(f"Handling input {id}...", flush=True) + # Sender-produced ID + id = msg.properties()["id"] - entities = set() + print(f"Handling input {id}...", flush=True) - for vec in v.vectors: + entities = set() - resp = self.vecstore.search(vec, limit=v.limit) + for vec in v.vectors: - for r in resp: - ent = r["entity"]["entity"] - entities.add(ent) + resp = self.vecstore.search(vec, limit=v.limit) - # Convert set to list - entities = list(entities) + for r in resp: + ent = r["entity"]["entity"] + entities.add(ent) - ents2 = [] + # Convert set to list + entities = list(entities) - for ent in entities: - ents2.append(self.create_value(ent)) + ents2 = [] - entities = ents2 + for ent in entities: + ents2.append(self.create_value(ent)) - print("Send response...", flush=True) - r = GraphEmbeddingsResponse(entities=entities) - self.producer.send(r, properties={"id": id}) + entities = ents2 - print("Done.", flush=True) + print("Send response...", flush=True) + r = GraphEmbeddingsResponse(entities=entities, error=None) + self.producer.send(r, properties={"id": id}) + + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = GraphEmbeddingsResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/query/triples/cassandra/service.py b/trustgraph/query/triples/cassandra/service.py index c40a50fc..5e1e0e3e 100755 --- a/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph/query/triples/cassandra/service.py @@ -5,7 +5,7 @@ null. Output is a list of triples. """ from .... direct.cassandra import TrustGraph -from .... schema import TriplesQueryRequest, TriplesQueryResponse +from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple from .... schema import triples_request_queue from .... schema import triples_response_queue @@ -48,90 +48,110 @@ class Processor(ConsumerProducer): def handle(self, msg): - v = msg.value() + try: - # Sender-produced ID - id = msg.properties()["id"] + v = msg.value() - print(f"Handling input {id}...", flush=True) + # Sender-produced ID + id = msg.properties()["id"] - triples = [] + print(f"Handling input {id}...", flush=True) - if v.s is not None: - if v.p is not None: - if v.o is not None: - resp = self.tg.get_spo( - v.s.value, v.p.value, v.o.value, - limit=v.limit - ) - triples.append((v.s.value, v.p.value, v.o.value)) + triples = [] + + if v.s is not None: + if v.p is not None: + if v.o is not None: + resp = self.tg.get_spo( + v.s.value, v.p.value, v.o.value, + limit=v.limit + ) + triples.append((v.s.value, v.p.value, v.o.value)) + else: + resp = self.tg.get_sp( + v.s.value, v.p.value, + limit=v.limit + ) + for t in resp: + triples.append((v.s.value, v.p.value, t.o)) else: - resp = self.tg.get_sp( - v.s.value, v.p.value, - limit=v.limit - ) - for t in resp: - triples.append((v.s.value, v.p.value, t.o)) + if v.o is not None: + resp = self.tg.get_os( + v.o.value, v.s.value, + limit=v.limit + ) + for t in resp: + triples.append((v.s.value, t.p, v.o.value)) + else: + resp = self.tg.get_s( + v.s.value, + limit=v.limit + ) + for t in resp: + triples.append((v.s.value, t.p, t.o)) else: - if v.o is not None: - resp = self.tg.get_os( - v.o.value, v.s.value, - limit=v.limit - ) - for t in resp: - triples.append((v.s.value, t.p, v.o.value)) + if v.p is not None: + if v.o is not None: + resp = self.tg.get_po( + v.p.value, v.o.value, + limit=v.limit + ) + for t in resp: + triples.append((t.s, v.p.value, v.o.value)) + else: + resp = self.tg.get_p( + v.p.value, + limit=v.limit + ) + for t in resp: + triples.append((t.s, v.p.value, t.o)) else: - resp = self.tg.get_s( - v.s.value, - limit=v.limit - ) - for t in resp: - triples.append((v.s.value, t.p, t.o)) - else: - if v.p is not None: - if v.o is not None: - resp = self.tg.get_po( - v.p.value, v.o.value, - limit=v.limit - ) - for t in resp: - triples.append((t.s, v.p.value, v.o.value)) - else: - resp = self.tg.get_p( - v.p.value, - limit=v.limit - ) - for t in resp: - triples.append((t.s, v.p.value, t.o)) - else: - if v.o is not None: - resp = self.tg.get_o( - v.o.value, - limit=v.limit - ) - for t in resp: - triples.append((t.s, t.p, v.o.value)) - else: - resp = self.tg.get_all( - limit=v.limit - ) - for t in resp: - triples.append((t.s, t.p, t.o)) + if v.o is not None: + resp = self.tg.get_o( + v.o.value, + limit=v.limit + ) + for t in resp: + triples.append((t.s, t.p, v.o.value)) + else: + resp = self.tg.get_all( + limit=v.limit + ) + for t in resp: + triples.append((t.s, t.p, t.o)) - triples = [ - Triple( - s=self.create_value(t[0]), - p=self.create_value(t[1]), - o=self.create_value(t[2]) + triples = [ + Triple( + s=self.create_value(t[0]), + p=self.create_value(t[1]), + o=self.create_value(t[2]) + ) + for t in triples + ] + + print("Send response...", flush=True) + r = TriplesQueryResponse(triples=triples, error=None) + self.producer.send(r, properties={"id": id}) + + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TriplesQueryResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, ) - for t in triples - ] - print("Send response...", flush=True) - r = TriplesQueryResponse(triples=triples) - self.producer.send(r, properties={"id": id}) + self.producer.send(r, properties={"id": id}) - print("Done.", flush=True) + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/query/triples/neo4j/service.py b/trustgraph/query/triples/neo4j/service.py index 5fca61c3..03e574cd 100755 --- a/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph/query/triples/neo4j/service.py @@ -6,7 +6,7 @@ null. Output is a list of triples. from neo4j import GraphDatabase -from .... schema import TriplesQueryRequest, TriplesQueryResponse +from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple from .... schema import triples_request_queue from .... schema import triples_response_queue @@ -57,245 +57,265 @@ class Processor(ConsumerProducer): def handle(self, msg): - v = msg.value() + try: - # Sender-produced ID - id = msg.properties()["id"] + v = msg.value() - print(f"Handling input {id}...", flush=True) + # Sender-produced ID + id = msg.properties()["id"] - triples = [] + print(f"Handling input {id}...", flush=True) - if v.s is not None: - if v.p is not None: - if v.o is not None: + triples = [] - # SPO + if v.s is not None: + if v.p is not None: + if v.o is not None: - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " - "RETURN $src as src", - src=v.s.value, rel=v.p.value, value=v.o.value, - database_=self.db, - ) + # SPO - for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " - "RETURN $src as src", - src=v.s.value, rel=v.p.value, uri=v.o.value, - database_=self.db, - ) + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) " + "RETURN $src as src", + src=v.s.value, rel=v.p.value, value=v.o.value, + database_=self.db, + ) - for rec in records: - triples.append((v.s.value, v.p.value, v.o.value)) + for rec in records: + triples.append((v.s.value, v.p.value, v.o.value)) + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) " + "RETURN $src as src", + src=v.s.value, rel=v.p.value, uri=v.o.value, + database_=self.db, + ) + + for rec in records: + triples.append((v.s.value, v.p.value, v.o.value)) + + else: + + # SP + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " + "RETURN dest.value as dest", + src=v.s.value, rel=v.p.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((v.s.value, v.p.value, data["dest"])) + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " + "RETURN dest.uri as dest", + src=v.s.value, rel=v.p.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((v.s.value, v.p.value, data["dest"])) else: - # SP - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) " - "RETURN dest.value as dest", - src=v.s.value, rel=v.p.value, - database_=self.db, - ) + if v.o is not None: - for rec in records: - data = rec.data() - triples.append((v.s.value, v.p.value, data["dest"])) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) " - "RETURN dest.uri as dest", - src=v.s.value, rel=v.p.value, - database_=self.db, - ) + # SO + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " + "RETURN rel.uri as rel", + src=v.s.value, value=v.o.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((v.s.value, data["rel"], v.o.value)) + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " + "RETURN rel.uri as rel", + src=v.s.value, uri=v.o.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((v.s.value, data["rel"], v.o.value)) + + else: + + # S + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " + "RETURN rel.uri as rel, dest.value as dest", + src=v.s.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((v.s.value, data["rel"], data["dest"])) + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " + "RETURN rel.uri as rel, dest.uri as dest", + src=v.s.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((v.s.value, data["rel"], data["dest"])) - for rec in records: - data = rec.data() - triples.append((v.s.value, v.p.value, data["dest"])) else: - if v.o is not None: + if v.p is not None: - # SO + if v.o is not None: - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) " - "RETURN rel.uri as rel", - src=v.s.value, value=v.o.value, - database_=self.db, - ) + # PO - for rec in records: - data = rec.data() - triples.append((v.s.value, data["rel"], v.o.value)) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) " - "RETURN rel.uri as rel", - src=v.s.value, uri=v.o.value, - database_=self.db, - ) + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " + "RETURN src.uri as src", + uri=v.p.value, value=v.o.value, + database_=self.db, + ) - for rec in records: - data = rec.data() - triples.append((v.s.value, data["rel"], v.o.value)) + for rec in records: + data = rec.data() + triples.append((data["src"], v.p.value, v.o.value)) + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) " + "RETURN src.uri as src", + uri=v.p.value, dest=v.o.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((data["src"], v.p.value, v.o.value)) + + else: + + # P + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " + "RETURN src.uri as src, dest.value as dest", + uri=v.p.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((data["src"], v.p.value, data["dest"])) + + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " + "RETURN src.uri as src, dest.uri as dest", + uri=v.p.value, + database_=self.db, + ) + + for rec in records: + data = rec.data() + triples.append((data["src"], v.p.value, data["dest"])) else: - # S + if v.o is not None: - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) " - "RETURN rel.uri as rel, dest.value as dest", - src=v.s.value, - database_=self.db, - ) + # O - for rec in records: - data = rec.data() - triples.append((v.s.value, data["rel"], data["dest"])) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) " - "RETURN rel.uri as rel, dest.uri as dest", - src=v.s.value, - database_=self.db, - ) + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " + "RETURN src.uri as src, rel.uri as rel", + value=v.o.value, + database_=self.db, + ) - for rec in records: - data = rec.data() - triples.append((v.s.value, data["rel"], data["dest"])) + for rec in records: + data = rec.data() + triples.append((data["src"], data["rel"], v.o.value)) + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " + "RETURN src.uri as src, rel.uri as rel", + uri=v.o.value, + database_=self.db, + ) - else: + for rec in records: + data = rec.data() + triples.append((data["src"], data["rel"], v.o.value)) - if v.p is not None: + else: - if v.o is not None: + # * - # PO + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel]->(dest:Literal) " + "RETURN src.uri as src, rel.uri as rel, dest.value as dest", + database_=self.db, + ) - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) " - "RETURN src.uri as src", - uri=v.p.value, value=v.o.value, - database_=self.db, - ) + for rec in records: + data = rec.data() + triples.append((data["src"], data["rel"], data["dest"])) - for rec in records: - data = rec.data() - triples.append((data["src"], v.p.value, v.o.value)) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) " - "RETURN src.uri as src", - uri=v.p.value, dest=v.o.value, - database_=self.db, - ) + records, summary, keys = self.io.execute_query( + "MATCH (src:Node)-[rel:Rel]->(dest:Node) " + "RETURN src.uri as src, rel.uri as rel, dest.uri as dest", + database_=self.db, + ) - for rec in records: - data = rec.data() - triples.append((data["src"], v.p.value, v.o.value)) + for rec in records: + data = rec.data() + triples.append((data["src"], data["rel"], data["dest"])) - else: + triples = [ + Triple( + s=self.create_value(t[0]), + p=self.create_value(t[1]), + o=self.create_value(t[2]) + ) + for t in triples + ] - # P + print("Send response...", flush=True) + r = TriplesQueryResponse(triples=triples) + self.producer.send(r, properties={"id": id}) - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) " - "RETURN src.uri as src, dest.value as dest", - uri=v.p.value, - database_=self.db, - ) + print("Done.", flush=True) - for rec in records: - data = rec.data() - triples.append((data["src"], v.p.value, data["dest"])) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) " - "RETURN src.uri as src, dest.uri as dest", - uri=v.p.value, - database_=self.db, - ) + except Exception as e: - for rec in records: - data = rec.data() - triples.append((data["src"], v.p.value, data["dest"])) + print(f"Exception: {e}") - else: + print("Send error response...", flush=True) - if v.o is not None: - - # O - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) " - "RETURN src.uri as src, rel.uri as rel", - value=v.o.value, - database_=self.db, - ) - - for rec in records: - data = rec.data() - triples.append((data["src"], data["rel"], v.o.value)) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) " - "RETURN src.uri as src, rel.uri as rel", - uri=v.o.value, - database_=self.db, - ) - - for rec in records: - data = rec.data() - triples.append((data["src"], data["rel"], v.o.value)) - - else: - - # * - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Literal) " - "RETURN src.uri as src, rel.uri as rel, dest.value as dest", - database_=self.db, - ) - - for rec in records: - data = rec.data() - triples.append((data["src"], data["rel"], data["dest"])) - - records, summary, keys = self.io.execute_query( - "MATCH (src:Node)-[rel:Rel]->(dest:Node) " - "RETURN src.uri as src, rel.uri as rel, dest.uri as dest", - database_=self.db, - ) - - for rec in records: - data = rec.data() - triples.append((data["src"], data["rel"], data["dest"])) - - triples = [ - Triple( - s=self.create_value(t[0]), - p=self.create_value(t[1]), - o=self.create_value(t[2]) + r = TriplesQueryResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, ) - for t in triples - ] - print("Send response...", flush=True) - r = TriplesQueryResponse(triples=triples) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) + self.producer.send(r, properties={"id": id}) + self.consumer.acknowledge(msg) + @staticmethod def add_args(parser): diff --git a/trustgraph/retrieval/graph_rag/rag.py b/trustgraph/retrieval/graph_rag/rag.py index 6ccd00a6..98697d4f 100755 --- a/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph/retrieval/graph_rag/rag.py @@ -4,7 +4,7 @@ Simple RAG service, performs query using graph RAG an LLM. Input is query, output is response. """ -from ... schema import GraphRagQuery, GraphRagResponse +from ... schema import GraphRagQuery, GraphRagResponse, Error from ... schema import graph_rag_request_queue, graph_rag_response_queue from ... schema import prompt_request_queue from ... schema import prompt_response_queue @@ -99,21 +99,40 @@ class Processor(ConsumerProducer): def handle(self, msg): - v = msg.value() + try: - # Sender-produced ID + v = msg.value() - id = msg.properties()["id"] + # Sender-produced ID + id = msg.properties()["id"] - print(f"Handling input {id}...", flush=True) + print(f"Handling input {id}...", flush=True) - response = self.rag.query(v.query) + response = self.rag.query(v.query) - print("Send response...", flush=True) - r = GraphRagResponse(response = response) - self.producer.send(r, properties={"id": id}) + print("Send response...", flush=True) + r = GraphRagResponse(response = response, error=None) + self.producer.send(r, properties={"id": id}) - print("Done.", flush=True) + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = GraphRagResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) @staticmethod def add_args(parser): diff --git a/trustgraph/schema.py b/trustgraph/schema.py index 5911064a..6b711677 100644 --- a/trustgraph/schema.py +++ b/trustgraph/schema.py @@ -8,6 +8,12 @@ def topic(topic, kind='persistent', tenant='tg', namespace='flow'): ############################################################################ +class Error(Record): + type = String() + message = String() + +############################################################################ + class Value(Record): value = String() is_uri = Boolean() @@ -78,6 +84,7 @@ class GraphEmbeddingsRequest(Record): limit = Integer() class GraphEmbeddingsResponse(Record): + error = Error() entities = Array(Value()) graph_embeddings_request_queue = topic( @@ -110,6 +117,7 @@ class TriplesQueryRequest(Record): limit = Integer() class TriplesQueryResponse(Record): + error = Error() triples = Array(Triple()) triples_request_queue = topic( @@ -131,6 +139,7 @@ class TextCompletionRequest(Record): prompt = String() class TextCompletionResponse(Record): + error = Error() response = String() text_completion_request_queue = topic( @@ -148,6 +157,7 @@ class EmbeddingsRequest(Record): text = String() class EmbeddingsResponse(Record): + error = Error() vectors = Array(Array(Double())) embeddings_request_queue = topic( @@ -165,6 +175,7 @@ class GraphRagQuery(Record): query = String() class GraphRagResponse(Record): + error = Error() response = String() graph_rag_request_queue = topic( @@ -207,6 +218,7 @@ class PromptRequest(Record): kg = Array(Fact()) class PromptResponse(Record): + error = Error() answer = String() definitions = Array(Definition()) relationships = Array(Relationship())