From a9197d11eeeeb17b12174c9b1a9361e9041170cb Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 22 Apr 2025 20:21:38 +0100 Subject: [PATCH] Feature/configure flows (#345) - Keeps processing in different flows separate so that data can go to different stores / collections etc. - Potentially supports different processing flows - Tidies the processing API with common base-classes for e.g. LLMs, and automatic configuration of 'clients' to use the right queue names in a flow --- Makefile | 8 + tests/test-agent | 6 +- tests/test-doc-rag | 7 +- tests/test-embeddings | 9 +- tests/test-graph-rag | 15 +- tests/test-llm | 11 +- tests/test-prompt-extraction | 13 +- tests/test-prompt-question | 9 +- tests/test-triples | 4 +- trustgraph-base/trustgraph/base/__init__.py | 27 +- .../trustgraph/base/agent_client.py | 39 +++ .../trustgraph/base/agent_service.py | 100 ++++++ .../trustgraph/base/async_processor.py | 254 ++++++++++++++ .../trustgraph/base/base_processor.py | 210 ------------ trustgraph-base/trustgraph/base/consumer.py | 224 ++++++------ .../trustgraph/base/consumer_producer.py | 62 ---- .../trustgraph/base/consumer_spec.py | 36 ++ .../base/document_embeddings_client.py | 38 ++ .../base/document_embeddings_query_service.py | 84 +++++ .../base/document_embeddings_store_service.py | 50 +++ .../trustgraph/base/embeddings_client.py | 31 ++ .../trustgraph/base/embeddings_service.py | 90 +++++ trustgraph-base/trustgraph/base/flow.py | 32 ++ .../trustgraph/base/flow_processor.py | 115 +++++++ .../base/graph_embeddings_client.py | 45 +++ .../base/graph_embeddings_query_service.py | 84 +++++ .../base/graph_embeddings_store_service.py | 50 +++ .../trustgraph/base/graph_rag_client.py | 33 ++ .../trustgraph/base/llm_service.py | 114 ++++++ trustgraph-base/trustgraph/base/metrics.py | 82 +++++ trustgraph-base/trustgraph/base/producer.py | 93 ++--- .../trustgraph/base/producer_spec.py | 25 ++ .../trustgraph/base/prompt_client.py | 93 +++++ 
trustgraph-base/trustgraph/base/publisher.py | 42 ++- trustgraph-base/trustgraph/base/pubsub.py | 80 +++++ .../trustgraph/base/request_response_spec.py | 136 ++++++++ .../trustgraph/base/setting_spec.py | 19 + trustgraph-base/trustgraph/base/spec.py | 4 + trustgraph-base/trustgraph/base/subscriber.py | 100 ++++-- .../trustgraph/base/subscriber_spec.py | 30 ++ .../trustgraph/base/text_completion_client.py | 30 ++ .../trustgraph/base/triples_client.py | 61 ++++ .../trustgraph/base/triples_query_service.py | 82 +++++ .../trustgraph/base/triples_store_service.py | 47 +++ trustgraph-base/trustgraph/schema/agent.py | 7 - trustgraph-base/trustgraph/schema/config.py | 2 +- .../trustgraph/schema/documents.py | 15 - trustgraph-base/trustgraph/schema/flows.py | 66 ++++ .../model/text_completion/bedrock/llm.py | 2 +- .../trustgraph/embeddings/hf/hf.py | 76 +--- .../trustgraph/agent/react/agent_manager.py | 25 +- .../trustgraph/agent/react/service.py | 188 ++++------ .../trustgraph/agent/react/tools.py | 16 +- .../trustgraph/chunking/recursive/chunker.py | 61 ++-- .../trustgraph/chunking/token/chunker.py | 61 ++-- .../trustgraph/config/service/config.py | 215 ++++++++++++ .../trustgraph/config/service/service.py | 324 +++++------------- .../decoding/mistral_ocr/processor.py | 23 +- .../trustgraph/decoding/pdf/pdf_decoder.py | 54 +-- trustgraph-flow/trustgraph/document_rag.py | 153 --------- .../document_embeddings/embeddings.py | 93 +++-- .../embeddings/fastembed/processor.py | 60 +--- .../embeddings/graph_embeddings/embeddings.py | 86 ++--- .../trustgraph/embeddings/ollama/processor.py | 2 +- .../trustgraph/external/wikipedia/service.py | 2 +- .../extract/kg/definitions/extract.py | 155 ++++----- .../extract/kg/relationships/extract.py | 122 +++---- .../trustgraph/extract/kg/topics/extract.py | 2 +- trustgraph-flow/trustgraph/gateway/agent.py | 1 - .../gateway/document_embeddings_load.py | 7 +- .../gateway/document_embeddings_stream.py | 11 +- 
.../trustgraph/gateway/endpoint.py | 4 - .../gateway/graph_embeddings_load.py | 7 +- .../gateway/graph_embeddings_stream.py | 11 +- trustgraph-flow/trustgraph/gateway/metrics.py | 1 - trustgraph-flow/trustgraph/gateway/mux.py | 1 - .../trustgraph/gateway/requestor.py | 22 +- trustgraph-flow/trustgraph/gateway/sender.py | 9 +- trustgraph-flow/trustgraph/gateway/service.py | 3 +- .../trustgraph/gateway/triples_load.py | 7 +- .../trustgraph/gateway/triples_stream.py | 5 +- trustgraph-flow/trustgraph/graph_rag.py | 295 ---------------- .../trustgraph/librarian/service.py | 2 +- .../trustgraph/metering/counter.py | 3 +- .../model/prompt/generic/service.py | 2 +- .../model/prompt/template/prompt_manager.py | 17 +- .../model/prompt/template/service.py | 133 ++++--- .../model/text_completion/azure/llm.py | 2 +- .../model/text_completion/azure_openai/llm.py | 2 +- .../model/text_completion/claude/llm.py | 2 +- .../model/text_completion/cohere/llm.py | 2 +- .../text_completion/googleaistudio/llm.py | 2 +- .../model/text_completion/llamafile/llm.py | 2 +- .../model/text_completion/lmstudio/llm.py | 2 +- .../model/text_completion/mistral/llm.py | 2 +- .../model/text_completion/ollama/llm.py | 2 +- .../model/text_completion/openai/llm.py | 2 +- .../query/doc_embeddings/milvus/service.py | 2 +- .../query/doc_embeddings/pinecone/service.py | 2 +- .../query/doc_embeddings/qdrant/service.py | 70 +--- .../query/graph_embeddings/milvus/service.py | 2 +- .../graph_embeddings/pinecone/service.py | 2 +- .../query/graph_embeddings/qdrant/service.py | 72 +--- .../query/triples/cassandra/service.py | 121 +++---- .../query/triples/falkordb/service.py | 2 +- .../query/triples/memgraph/service.py | 2 +- .../trustgraph/query/triples/neo4j/service.py | 2 +- .../retrieval/document_rag/document_rag.py | 94 +++++ .../trustgraph/retrieval/document_rag/rag.py | 191 ++++------- .../retrieval/graph_rag/graph_rag.py | 218 ++++++++++++ .../trustgraph/retrieval/graph_rag/rag.py | 228 +++++------- 
.../storage/doc_embeddings/milvus/write.py | 2 +- .../storage/doc_embeddings/pinecone/write.py | 2 +- .../storage/doc_embeddings/qdrant/write.py | 39 +-- .../storage/graph_embeddings/milvus/write.py | 2 +- .../graph_embeddings/pinecone/write.py | 2 +- .../storage/graph_embeddings/qdrant/write.py | 38 +- .../storage/object_embeddings/milvus/write.py | 2 +- .../storage/rows/cassandra/write.py | 2 +- .../storage/triples/cassandra/write.py | 41 +-- .../storage/triples/falkordb/write.py | 2 +- .../storage/triples/memgraph/write.py | 2 +- .../trustgraph/storage/triples/neo4j/write.py | 2 +- .../trustgraph/decoding/ocr/pdf_decoder.py | 2 +- .../model/text_completion/vertexai/llm.py | 138 ++------ 125 files changed, 3751 insertions(+), 2628 deletions(-) create mode 100644 trustgraph-base/trustgraph/base/agent_client.py create mode 100644 trustgraph-base/trustgraph/base/agent_service.py create mode 100644 trustgraph-base/trustgraph/base/async_processor.py delete mode 100644 trustgraph-base/trustgraph/base/base_processor.py delete mode 100644 trustgraph-base/trustgraph/base/consumer_producer.py create mode 100644 trustgraph-base/trustgraph/base/consumer_spec.py create mode 100644 trustgraph-base/trustgraph/base/document_embeddings_client.py create mode 100644 trustgraph-base/trustgraph/base/document_embeddings_query_service.py create mode 100644 trustgraph-base/trustgraph/base/document_embeddings_store_service.py create mode 100644 trustgraph-base/trustgraph/base/embeddings_client.py create mode 100644 trustgraph-base/trustgraph/base/embeddings_service.py create mode 100644 trustgraph-base/trustgraph/base/flow.py create mode 100644 trustgraph-base/trustgraph/base/flow_processor.py create mode 100644 trustgraph-base/trustgraph/base/graph_embeddings_client.py create mode 100644 trustgraph-base/trustgraph/base/graph_embeddings_query_service.py create mode 100644 trustgraph-base/trustgraph/base/graph_embeddings_store_service.py create mode 100644 
trustgraph-base/trustgraph/base/graph_rag_client.py create mode 100644 trustgraph-base/trustgraph/base/llm_service.py create mode 100644 trustgraph-base/trustgraph/base/metrics.py create mode 100644 trustgraph-base/trustgraph/base/producer_spec.py create mode 100644 trustgraph-base/trustgraph/base/prompt_client.py create mode 100644 trustgraph-base/trustgraph/base/pubsub.py create mode 100644 trustgraph-base/trustgraph/base/request_response_spec.py create mode 100644 trustgraph-base/trustgraph/base/setting_spec.py create mode 100644 trustgraph-base/trustgraph/base/spec.py create mode 100644 trustgraph-base/trustgraph/base/subscriber_spec.py create mode 100644 trustgraph-base/trustgraph/base/text_completion_client.py create mode 100644 trustgraph-base/trustgraph/base/triples_client.py create mode 100644 trustgraph-base/trustgraph/base/triples_query_service.py create mode 100644 trustgraph-base/trustgraph/base/triples_store_service.py create mode 100644 trustgraph-base/trustgraph/schema/flows.py create mode 100644 trustgraph-flow/trustgraph/config/service/config.py delete mode 100644 trustgraph-flow/trustgraph/document_rag.py delete mode 100644 trustgraph-flow/trustgraph/graph_rag.py create mode 100644 trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py create mode 100644 trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py diff --git a/Makefile b/Makefile index 4f4de9d2..1899e602 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,14 @@ container: update-package-versions ${DOCKER} build -f containers/Containerfile.ocr \ -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . +some-containers: + ${DOCKER} build -f containers/Containerfile.base \ + -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . + ${DOCKER} build -f containers/Containerfile.flow \ + -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . + ${DOCKER} build -f containers/Containerfile.vertexai \ + -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . 
+ basic-containers: update-package-versions ${DOCKER} build -f containers/Containerfile.base \ -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . diff --git a/tests/test-agent b/tests/test-agent index 4782bbae..b1420098 100755 --- a/tests/test-agent +++ b/tests/test-agent @@ -20,7 +20,11 @@ def output(text, prefix="> ", width=78): ) print(out) -p = AgentClient(pulsar_host="pulsar://localhost:6650") +p = AgentClient( + pulsar_host="pulsar://pulsar:6650", + input_queue = "non-persistent://tg/request/agent:0000", + output_queue = "non-persistent://tg/response/agent:0000", +) q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese." diff --git a/tests/test-doc-rag b/tests/test-doc-rag index 718157b6..b7382bf5 100755 --- a/tests/test-doc-rag +++ b/tests/test-doc-rag @@ -3,7 +3,12 @@ import pulsar from trustgraph.clients.document_rag_client import DocumentRagClient -rag = DocumentRagClient(pulsar_host="pulsar://localhost:6650") +rag = DocumentRagClient( + pulsar_host="pulsar://localhost:6650", + subscriber="test1", + input_queue = "non-persistent://tg/request/document-rag:default", + output_queue = "non-persistent://tg/response/document-rag:default", +) query=""" What was the cause of the space shuttle disaster?""" diff --git a/tests/test-embeddings b/tests/test-embeddings index 3855fcf0..5fcd31e6 100755 --- a/tests/test-embeddings +++ b/tests/test-embeddings @@ -3,7 +3,12 @@ import pulsar from trustgraph.clients.embeddings_client import EmbeddingsClient -embed = EmbeddingsClient(pulsar_host="pulsar://localhost:6650") +embed = EmbeddingsClient( + pulsar_host="pulsar://pulsar:6650", + input_queue="non-persistent://tg/request/embeddings:default", + output_queue="non-persistent://tg/response/embeddings:default", + subscriber="test1", +) prompt="Write a funny 
limerick about a llama" @@ -11,5 +16,3 @@ resp = embed.request(prompt) print(resp) - - diff --git a/tests/test-graph-rag b/tests/test-graph-rag index 036f73f4..b62f890c 100755 --- a/tests/test-graph-rag +++ b/tests/test-graph-rag @@ -3,11 +3,18 @@ import pulsar from trustgraph.clients.graph_rag_client import GraphRagClient -rag = GraphRagClient(pulsar_host="pulsar://localhost:6650") +rag = GraphRagClient( + pulsar_host="pulsar://localhost:6650", + subscriber="test1", + input_queue = "non-persistent://tg/request/graph-rag:default", + output_queue = "non-persistent://tg/response/graph-rag:default", +) -query=""" -This knowledge graph describes the Space Shuttle disaster. -Present 20 facts which are present in the knowledge graph.""" +#query=""" +#This knowledge graph describes the Space Shuttle disaster. +#Present 20 facts which are present in the knowledge graph.""" + +query = "How many cats does Mark have?" resp = rag.request(query) diff --git a/tests/test-llm b/tests/test-llm index 4e86387a..aaae30a6 100755 --- a/tests/test-llm +++ b/tests/test-llm @@ -3,14 +3,17 @@ import pulsar from trustgraph.clients.llm_client import LlmClient -llm = LlmClient(pulsar_host="pulsar://localhost:6650") +llm = LlmClient( + pulsar_host="pulsar://pulsar:6650", + input_queue="non-persistent://tg/request/text-completion:default", + output_queue="non-persistent://tg/response/text-completion:default", + subscriber="test1", +) system = "You are a lovely assistant." 
-prompt="Write a funny limerick about a llama" +prompt="what is 2 + 2 == 5" resp = llm.request(system, prompt) print(resp) - - diff --git a/tests/test-prompt-extraction b/tests/test-prompt-extraction index c73bd2e2..20aaaf50 100755 --- a/tests/test-prompt-extraction +++ b/tests/test-prompt-extraction @@ -3,7 +3,12 @@ import json from trustgraph.clients.prompt_client import PromptClient -p = PromptClient(pulsar_host="pulsar://localhost:6650") +p = PromptClient( + pulsar_host="pulsar://localhost:6650", + input_queue="non-persistent://tg/request/prompt:default", + output_queue="non-persistent://tg/response/prompt:default", + subscriber="test1", +) chunk=""" The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including: @@ -31,8 +36,8 @@ The Space Shuttle's last mission was in 2011. 
q = "Tell me some facts in the knowledge graph" resp = p.request( - id="extract-definition", - terms = { + id="extract-definitions", + variables = { "text": chunk, } ) @@ -40,7 +45,7 @@ resp = p.request( print(resp) for fact in resp: - print(fact["term"], "::") + print(fact["entity"], "::") print(fact["definition"]) print() diff --git a/tests/test-prompt-question b/tests/test-prompt-question index 50660965..78ba72aa 100755 --- a/tests/test-prompt-question +++ b/tests/test-prompt-question @@ -3,13 +3,18 @@ import pulsar from trustgraph.clients.prompt_client import PromptClient -p = PromptClient(pulsar_host="pulsar://localhost:6650") +p = PromptClient( + pulsar_host="pulsar://localhost:6650", + input_queue="non-persistent://tg/request/prompt:default", + output_queue="non-persistent://tg/response/prompt:default", + subscriber="test1", +) question = """What is the square root of 16?""" resp = p.request( id="question", - terms = { + variables = { "question": question } ) diff --git a/tests/test-triples b/tests/test-triples index 05263d0d..e804d844 100755 --- a/tests/test-triples +++ b/tests/test-triples @@ -3,7 +3,9 @@ import pulsar from trustgraph.clients.triples_query_client import TriplesQueryClient -tq = TriplesQueryClient(pulsar_host="pulsar://localhost:6650") +tq = TriplesQueryClient( + pulsar_host="pulsar://localhost:6650", +) e = "http://trustgraph.ai/e/shuttle" diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 3a58d51e..2accbb21 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -1,8 +1,31 @@ -from . base_processor import BaseProcessor +from . pubsub import PulsarClient +from . async_processor import AsyncProcessor from . consumer import Consumer from . producer import Producer -from . consumer_producer import ConsumerProducer from . publisher import Publisher from . subscriber import Subscriber +from . 
metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics +from . flow_processor import FlowProcessor +from . consumer_spec import ConsumerSpec +from . setting_spec import SettingSpec +from . producer_spec import ProducerSpec +from . subscriber_spec import SubscriberSpec +from . request_response_spec import RequestResponseSpec +from . llm_service import LlmService, LlmResult +from . embeddings_service import EmbeddingsService +from . embeddings_client import EmbeddingsClientSpec +from . text_completion_client import TextCompletionClientSpec +from . prompt_client import PromptClientSpec +from . triples_store_service import TriplesStoreService +from . graph_embeddings_store_service import GraphEmbeddingsStoreService +from . document_embeddings_store_service import DocumentEmbeddingsStoreService +from . triples_query_service import TriplesQueryService +from . graph_embeddings_query_service import GraphEmbeddingsQueryService +from . document_embeddings_query_service import DocumentEmbeddingsQueryService +from . graph_embeddings_client import GraphEmbeddingsClientSpec +from . triples_client import TriplesClientSpec +from . document_embeddings_client import DocumentEmbeddingsClientSpec +from . agent_service import AgentService +from . graph_rag_client import GraphRagClientSpec diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py new file mode 100644 index 00000000..76e1adff --- /dev/null +++ b/trustgraph-base/trustgraph/base/agent_client.py @@ -0,0 +1,39 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import AgentRequest, AgentResponse +from .. 
knowledge import Uri, Literal + +class AgentClient(RequestResponse): + async def request(self, recipient, question, plan=None, state=None, + history=[], timeout=300): + + resp = await self.request( + AgentRequest( + question = question, + plan = plan, + state = state, + history = history, + ), + recipient=recipient, + timeout=timeout, + ) + + print(resp, flush=True) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp + +class GraphEmbeddingsClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(GraphEmbeddingsClientSpec, self).__init__( + request_name = request_name, + request_schema = GraphEmbeddingsRequest, + response_name = response_name, + response_schema = GraphEmbeddingsResponse, + impl = GraphEmbeddingsClient, + ) + diff --git a/trustgraph-base/trustgraph/base/agent_service.py b/trustgraph-base/trustgraph/base/agent_service.py new file mode 100644 index 00000000..0dbe728e --- /dev/null +++ b/trustgraph-base/trustgraph/base/agent_service.py @@ -0,0 +1,100 @@ + +""" +Agent manager service completion base class +""" + +import time +from prometheus_client import Histogram + +from .. schema import AgentRequest, AgentResponse, Error +from .. exceptions import TooManyRequests +from .. 
base import FlowProcessor, ConsumerSpec, ProducerSpec + +default_ident = "agent-manager" + +class AgentService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(AgentService, self).__init__(**params | { "id": id }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = AgentRequest, + handler = self.on_request + ) + ) + + self.register_specification( + ProducerSpec( + name = "next", + schema = AgentRequest + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = AgentResponse + ) + ) + + async def on_request(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + async def respond(resp): + + await flow("response").send( + resp, + properties={"id": id} + ) + + async def next(resp): + + await flow("next").send( + resp, + properties={"id": id} + ) + + await self.agent_request( + request = request, respond = respond, next = next, + flow = flow + ) + + except TooManyRequests as e: + raise e + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"on_request Exception: {e}") + + print("Send error response...", flush=True) + + await flow.producer["response"].send( + AgentResponse( + error=Error( + type = "agent-error", + message = str(e), + ), + thought = None, + observation = None, + answer = None, + ), + properties={"id": id} + ) + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py new file mode 100644 index 00000000..80440b36 --- /dev/null +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -0,0 +1,254 @@ + +# Base class for processors. 
Implements: +# - Pulsar client, subscribe and consume basic +# - the async startup logic +# - Initialising metrics + +import asyncio +import argparse +import _pulsar +import time +import uuid +from prometheus_client import start_http_server, Info + +from .. schema import ConfigPush, config_push_queue +from .. log_level import LogLevel +from .. exceptions import TooManyRequests +from . pubsub import PulsarClient +from . producer import Producer +from . consumer import Consumer +from . metrics import ProcessorMetrics + +default_config_queue = config_push_queue + +# Async processor +class AsyncProcessor: + + def __init__(self, **params): + + # Store the identity + self.id = params.get("id") + + # Register a pulsar client + self.pulsar_client = PulsarClient(**params) + + # Initialise metrics, records the parameters + ProcessorMetrics(id=self.id).info({ + k: str(params[k]) + for k in params + if k != "id" + }) + + # The processor runs all activity in a taskgroup, it's mandatory + # that this is provded + self.taskgroup = params.get("taskgroup") + if self.taskgroup is None: + raise RuntimeError("Essential taskgroup missing") + + # Get the configuration topic + self.config_push_queue = params.get( + "config_push_queue", default_config_queue + ) + + # This records registered configuration handlers + self.config_handlers = [] + + # Create a random ID for this subscription to the configuration + # service + config_subscriber_id = str(uuid.uuid4()) + + # Subscribe to config queue + self.config_sub_task = Consumer( + + taskgroup = self.taskgroup, + client = self.client, + subscriber = config_subscriber_id, + flow = None, + + topic = self.config_push_queue, + schema = ConfigPush, + + handler = self.on_config_change, + + # This causes new subscriptions to view the entire history of + # configuration + start_of_messages = True + ) + + self.running = True + + # This is called to start dynamic behaviour. 
An over-ride point for + # extra functionality + async def start(self): + await self.config_sub_task.start() + + # This is called to stop all threads. An over-ride point for extra + # functionality + def stop(self): + self.client.close() + self.running = False + + # Returns the pulsar host + @property + def pulsar_host(self): return self.client.pulsar_host + + # Returns the pulsar client + @property + def client(self): return self.pulsar_client.client + + # Register a new event handler for configuration change + def register_config_handler(self, handler): + self.config_handlers.append(handler) + + # Called when a new configuration message push occurs + async def on_config_change(self, message, consumer): + + # Get configuration data and version number + config = message.value().config + version = message.value().version + + # Acknowledge the message + consumer.acknowledge(message) + + # Invoke message handlers + print("Config change event", config, version, flush=True) + for ch in self.config_handlers: + await ch(config, version) + + # This is the 'main' body of the handler. It is a point to override + # if needed. By default does nothing. Processors are implemented + # by adding consumer/producer functionality so maybe nothing is needed + # in the run() body + async def run(self): + while self.running: + await asyncio.sleep(2) + + # Startup fabric. This runs in 'async' mode, creates a taskgroup and + # runs the producer. + @classmethod + async def launch_async(cls, args): + + try: + + # Create a taskgroup. This seems complicated, when an exception + # occurs, unhandled it looks like it cancels all threads in the + # taskgroup. Needs the exception to be caught in the right + # place. + async with asyncio.TaskGroup() as tg: + + + # Create a processor instance, and include the taskgroup + # as a paramter. 
A processor identity ident is used as + # - subscriber name + # - an identifier for flow configuration + p = cls(**args | { "taskgroup": tg }) + + # Start the processor + await p.start() + + # Run the processor + task = tg.create_task(p.run()) + + # The taskgroup causes everything to wait until + # all threads have stopped + + # This is here to output a debug message, shouldn't be needed. + except Exception as e: + print("Exception, closing taskgroup", flush=True) + raise e + + # Startup fabric. launch calls launch_async in async mode. + @classmethod + def launch(cls, ident, doc): + + # Start assembling CLI arguments + parser = argparse.ArgumentParser( + prog=ident, + description=doc + ) + + parser.add_argument( + '--id', + default=ident, + help=f'Configuration identity (default: {ident})', + ) + + # Invoke the class-specific add_args, which manages adding all the + # command-line arguments + cls.add_args(parser) + + # Parse arguments + args = parser.parse_args() + args = vars(args) + + # Debug + print(args, flush=True) + + # Start the Prometheus metrics service if needed + if args["metrics"]: + start_http_server(args["metrics_port"]) + + # Loop forever, exception handler + while True: + + print("Starting...", flush=True) + + try: + + # Launch the processor in an asyncio handler + asyncio.run(cls.launch_async( + args + )) + + except KeyboardInterrupt: + print("Keyboard interrupt.", flush=True) + return + + except _pulsar.Interrupted: + print("Pulsar Interrupted.", flush=True) + return + + # Exceptions from a taskgroup come in as an exception group + except ExceptionGroup as e: + + print("Exception group:", flush=True) + + for se in e.exceptions: + print(" Type:", type(se), flush=True) + print(f" Exception: {se}", flush=True) + + except Exception as e: + print("Type:", type(e), flush=True) + print("Exception:", e, flush=True) + + # Retry occurs here + print("Will retry...", flush=True) + time.sleep(4) + print("Retrying...", flush=True) + + # The command-line 
arguments are built using a stack of add_args + # invocations + @staticmethod + def add_args(parser): + + PulsarClient.add_args(parser) + + parser.add_argument( + '--config-push-queue', + default=default_config_queue, + help=f'Config push queue {default_config_queue}', + ) + + parser.add_argument( + '--metrics', + action=argparse.BooleanOptionalAction, + default=True, + help=f'Metrics enabled (default: true)', + ) + + parser.add_argument( + '-P', '--metrics-port', + type=int, + default=8000, + help=f'Pulsar host (default: 8000)', + ) diff --git a/trustgraph-base/trustgraph/base/base_processor.py b/trustgraph-base/trustgraph/base/base_processor.py deleted file mode 100644 index 05cdb940..00000000 --- a/trustgraph-base/trustgraph/base/base_processor.py +++ /dev/null @@ -1,210 +0,0 @@ - -import asyncio -import os -import argparse -import pulsar -from pulsar.schema import JsonSchema -import _pulsar -import time -import uuid -from prometheus_client import start_http_server, Info - -from .. schema import ConfigPush, config_push_queue -from .. 
log_level import LogLevel - -default_config_queue = config_push_queue -config_subscriber_id = str(uuid.uuid4()) - -class BaseProcessor: - - default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') - default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) - - def __init__(self, **params): - - self.client = None - - if not hasattr(__class__, "params_metric"): - __class__.params_metric = Info( - 'params', 'Parameters configuration' - ) - - # FIXME: Maybe outputs information it should not - __class__.params_metric.info({ - k: str(params[k]) - for k in params - }) - - pulsar_host = params.get("pulsar_host", self.default_pulsar_host) - pulsar_listener = params.get("pulsar_listener", None) - pulsar_api_key = params.get("pulsar_api_key", None) - log_level = params.get("log_level", LogLevel.INFO) - - self.config_push_queue = params.get( - "config_push_queue", - default_config_queue - ) - - self.pulsar_host = pulsar_host - self.pulsar_api_key = pulsar_api_key - - if pulsar_api_key: - auth = pulsar.AuthenticationToken(pulsar_api_key) - self.client = pulsar.Client( - pulsar_host, - authentication=auth, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - else: - self.client = pulsar.Client( - pulsar_host, - listener_name=pulsar_listener, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.pulsar_listener = pulsar_listener - - self.config_subscriber = self.client.subscribe( - self.config_push_queue, config_subscriber_id, - consumer_type=pulsar.ConsumerType.Shared, - initial_position=pulsar.InitialPosition.Earliest, - schema=JsonSchema(ConfigPush), - ) - - def __del__(self): - - if hasattr(self, "client"): - if self.client: - self.client.close() - - @staticmethod - def add_args(parser): - - parser.add_argument( - '-p', '--pulsar-host', - default=__class__.default_pulsar_host, - help=f'Pulsar host (default: {__class__.default_pulsar_host})', - ) - - parser.add_argument( - '--pulsar-api-key', - default=__class__.default_pulsar_api_key, - 
help=f'Pulsar API key', - ) - - parser.add_argument( - '--config-push-queue', - default=default_config_queue, - help=f'Config push queue {default_config_queue}', - ) - - parser.add_argument( - '--pulsar-listener', - help=f'Pulsar listener (default: none)', - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '--metrics', - action=argparse.BooleanOptionalAction, - default=True, - help=f'Metrics enabled (default: true)', - ) - - parser.add_argument( - '-P', '--metrics-port', - type=int, - default=8000, - help=f'Pulsar host (default: 8000)', - ) - - async def start(self): - pass - - async def run_config_queue(self): - - if self.module == "config.service": - print("I am config-svc, not looking at config queue", flush=True) - return - - print("Config thread running", flush=True) - - while True: - - try: - msg = await asyncio.to_thread( - self.config_subscriber.receive, timeout_millis=2000 - ) - except pulsar.Timeout: - continue - - v = msg.value() - print("Got config version", v.version, flush=True) - - await self.on_config(v.version, v.config) - - async def on_config(self, version, config): - pass - - async def run(self): - raise RuntimeError("Something should have implemented the run method") - - @classmethod - async def launch_async(cls, args, prog): - p = cls(**args) - p.module = prog - await p.start() - - task1 = asyncio.create_task(p.run_config_queue()) - task2 = asyncio.create_task(p.run()) - - await asyncio.gather(task1, task2) - - @classmethod - def launch(cls, prog, doc): - - parser = argparse.ArgumentParser( - prog=prog, - description=doc - ) - - cls.add_args(parser) - - args = parser.parse_args() - args = vars(args) - - print(args) - - if args["metrics"]: - start_http_server(args["metrics_port"]) - - while True: - - try: - - asyncio.run(cls.launch_async(args, prog)) - - except KeyboardInterrupt: - print("Keyboard 
interrupt.") - return - - except _pulsar.Interrupted: - print("Pulsar Interrupted.") - return - - except Exception as e: - - print(type(e)) - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(4) - diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index fdbe5531..57b940ac 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -1,93 +1,136 @@ -import asyncio from pulsar.schema import JsonSchema import pulsar -from prometheus_client import Histogram, Info, Counter, Enum +import _pulsar +import asyncio import time -from . base_processor import BaseProcessor from .. exceptions import TooManyRequests -default_rate_limit_retry = 10 -default_rate_limit_timeout = 7200 +class Consumer: -class Consumer(BaseProcessor): + def __init__( + self, taskgroup, flow, client, topic, subscriber, schema, + handler, + metrics = None, + start_of_messages=False, + rate_limit_retry_time = 10, rate_limit_timeout = 7200, + reconnect_time = 5, + ): - def __init__(self, **params): + self.taskgroup = taskgroup + self.flow = flow + self.client = client + self.topic = topic + self.subscriber = subscriber + self.schema = schema + self.handler = handler - if not hasattr(__class__, "state_metric"): - __class__.state_metric = Enum( - 'processor_state', 'Processor state', - states=['starting', 'running', 'stopped'] - ) - __class__.state_metric.state('starting') + self.rate_limit_retry_time = rate_limit_retry_time + self.rate_limit_timeout = rate_limit_timeout - __class__.state_metric.state('starting') + self.reconnect_time = 5 - super(Consumer, self).__init__(**params) + self.start_of_messages = start_of_messages - self.input_queue = params.get("input_queue") - self.subscriber = params.get("subscriber") - self.input_schema = params.get("input_schema") + self.running = True + self.task = None - self.rate_limit_retry = params.get( - "rate_limit_retry", 
default_rate_limit_retry - ) - self.rate_limit_timeout = params.get( - "rate_limit_timeout", default_rate_limit_timeout - ) + self.metrics = metrics - if self.input_schema == None: - raise RuntimeError("input_schema must be specified") + self.consumer = None - if not hasattr(__class__, "request_metric"): - __class__.request_metric = Histogram( - 'request_latency', 'Request latency (seconds)' - ) + def __del__(self): + self.running = False - if not hasattr(__class__, "pubsub_metric"): - __class__.pubsub_metric = Info( - 'pubsub', 'Pub/sub configuration' - ) + if hasattr(self, "consumer"): + if self.consumer: + self.consumer.close() - if not hasattr(__class__, "processing_metric"): - __class__.processing_metric = Counter( - 'processing_count', 'Processing count', ["status"] - ) + async def stop(self): - if not hasattr(__class__, "rate_limit_metric"): - __class__.rate_limit_metric = Counter( - 'rate_limit_count', 'Rate limit event count', - ) + self.running = False + await self.task - __class__.pubsub_metric.info({ - "input_queue": self.input_queue, - "subscriber": self.subscriber, - "input_schema": self.input_schema.__name__, - "rate_limit_retry": str(self.rate_limit_retry), - "rate_limit_timeout": str(self.rate_limit_timeout), - }) + async def start(self): - self.consumer = self.client.subscribe( - self.input_queue, self.subscriber, - consumer_type=pulsar.ConsumerType.Shared, - schema=JsonSchema(self.input_schema), - ) + self.running = True - print("Initialised consumer.", flush=True) + # Puts it in the stopped state, the run thread should set running + if self.metrics: + self.metrics.state("stopped") + + self.task = self.taskgroup.create_task(self.run()) async def run(self): - __class__.state_metric.state('running') + while self.running: - while True: + if self.metrics: + self.metrics.state("stopped") - msg = await asyncio.to_thread(self.consumer.receive) + try: + + print(self.topic, "subscribing...", flush=True) + + if self.start_of_messages: + pos = 
pulsar.InitialPosition.Earliest + else: + pos = pulsar.InitialPosition.Latest + + self.consumer = await asyncio.to_thread( + self.client.subscribe, + topic = self.topic, + subscription_name = self.subscriber, + schema = JsonSchema(self.schema), + initial_position = pos, + consumer_type = pulsar.ConsumerType.Shared, + ) + + except Exception as e: + + print("consumer subs Exception:", e, flush=True) + await asyncio.sleep(self.reconnect_time) + continue + + print(self.topic, "subscribed", flush=True) + + if self.metrics: + self.metrics.state("running") + + try: + + await self.consume() + + if self.metrics: + self.metrics.state("stopped") + + except Exception as e: + + print("consumer loop exception:", e, flush=True) + self.consumer.close() + self.consumer = None + await asyncio.sleep(self.reconnect_time) + continue + + async def consume(self): + + while self.running: + + try: + msg = await asyncio.to_thread( + self.consumer.receive, + timeout_millis=2000 + ) + except _pulsar.Timeout: + continue + except Exception as e: + raise e expiry = time.time() + self.rate_limit_timeout # This loop is for retry on rate-limit / resource limits - while True: + while self.running: if time.time() > expiry: @@ -97,20 +140,31 @@ class Consumer(BaseProcessor): # be retried self.consumer.negative_acknowledge(msg) - __class__.processing_metric.labels(status="error").inc() + if self.metrics: + self.metrics.process("error") # Break out of retry loop, processes next message break try: - with __class__.request_metric.time(): - await self.handle(msg) + print("Handle...", flush=True) + + if self.metrics: + + with self.metrics.record_time(): + await self.handler(msg, self, self.flow) + + else: + await self.handler(msg, self.consumer) + + print("Handled.", flush=True) # Acknowledge successful processing of the message self.consumer.acknowledge(msg) - __class__.processing_metric.labels(status="success").inc() + if self.metrics: + self.metrics.process("success") # Break out of retry loop break @@ 
-119,55 +173,25 @@ class Consumer(BaseProcessor): print("TooManyRequests: will retry...", flush=True) - __class__.rate_limit_metric.inc() + if self.metrics: + self.metrics.rate_limit() # Sleep - time.sleep(self.rate_limit_retry) + await asyncio.sleep(self.rate_limit_retry_time) # Contine from retry loop, just causes a reprocessing continue - + except Exception as e: - print("Exception:", e, flush=True) + print("consume exception:", e, flush=True) # Message failed to be processed, this causes it to # be retried self.consumer.negative_acknowledge(msg) - __class__.processing_metric.labels(status="error").inc() + if self.metrics: + self.metrics.process("error") # Break out of retry loop, processes next message break - - @staticmethod - def add_args(parser, default_input_queue, default_subscriber): - - BaseProcessor.add_args(parser) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '--rate-limit-retry', - type=int, - default=default_rate_limit_retry, - help=f'Rate limit retry (default: {default_rate_limit_retry})' - ) - - parser.add_argument( - '--rate-limit-timeout', - type=int, - default=default_rate_limit_timeout, - help=f'Rate limit timeout (default: {default_rate_limit_timeout})' - ) - diff --git a/trustgraph-base/trustgraph/base/consumer_producer.py b/trustgraph-base/trustgraph/base/consumer_producer.py deleted file mode 100644 index 1006f9b5..00000000 --- a/trustgraph-base/trustgraph/base/consumer_producer.py +++ /dev/null @@ -1,62 +0,0 @@ - -from pulsar.schema import JsonSchema -import pulsar -from prometheus_client import Histogram, Info, Counter, Enum -import time - -from . consumer import Consumer -from .. 
exceptions import TooManyRequests - -class ConsumerProducer(Consumer): - - def __init__(self, **params): - - super(ConsumerProducer, self).__init__(**params) - - self.output_queue = params.get("output_queue") - self.output_schema = params.get("output_schema") - - if not hasattr(__class__, "output_metric"): - __class__.output_metric = Counter( - 'output_count', 'Output items created' - ) - - __class__.pubsub_metric.info({ - "input_queue": self.input_queue, - "output_queue": self.output_queue, - "subscriber": self.subscriber, - "input_schema": self.input_schema.__name__, - "output_schema": self.output_schema.__name__, - "rate_limit_retry": str(self.rate_limit_retry), - "rate_limit_timeout": str(self.rate_limit_timeout), - }) - - if self.output_schema == None: - raise RuntimeError("output_schema must be specified") - - self.producer = self.client.create_producer( - topic=self.output_queue, - schema=JsonSchema(self.output_schema), - chunking_enabled=True, - ) - - print("Initialised consumer/producer.") - - async def send(self, msg, properties={}): - self.producer.send(msg, properties) - __class__.output_metric.inc() - - @staticmethod - def add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ): - - Consumer.add_args(parser, default_input_queue, default_subscriber) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py new file mode 100644 index 00000000..aaeca677 --- /dev/null +++ b/trustgraph-base/trustgraph/base/consumer_spec.py @@ -0,0 +1,36 @@ + +from . metrics import ConsumerMetrics +from . consumer import Consumer +from . 
spec import Spec + +class ConsumerSpec(Spec): + def __init__(self, name, schema, handler): + self.name = name + self.schema = schema + self.handler = handler + + def add(self, flow, processor, definition): + + consumer_metrics = ConsumerMetrics( + flow.id, f"{flow.name}-{self.name}" + ) + + consumer = Consumer( + taskgroup = processor.taskgroup, + flow = flow, + client = processor.client, + topic = definition[self.name], + subscriber = processor.id + "--" + self.name, + schema = self.schema, + handler = self.handler, + metrics = consumer_metrics, + ) + + # Consumer handle gets access to producers and other + # metadata + consumer.id = flow.id + consumer.name = self.name + consumer.flow = flow + + flow.consumer[self.name] = consumer + diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py new file mode 100644 index 00000000..86370c52 --- /dev/null +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -0,0 +1,38 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse +from .. 
knowledge import Uri, Literal + +class DocumentEmbeddingsClient(RequestResponse): + async def query(self, vectors, limit=20, user="trustgraph", + collection="default", timeout=30): + + resp = await self.request( + DocumentEmbeddingsRequest( + vectors = vectors, + limit = limit, + user = user, + collection = collection + ), + timeout=timeout + ) + + print(resp, flush=True) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.documents + +class DocumentEmbeddingsClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(DocumentEmbeddingsClientSpec, self).__init__( + request_name = request_name, + request_schema = DocumentEmbeddingsRequest, + response_name = response_name, + response_schema = DocumentEmbeddingsResponse, + impl = DocumentEmbeddingsClient, + ) + diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py new file mode 100644 index 00000000..0dee7001 --- /dev/null +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -0,0 +1,84 @@ + +""" +Document embeddings query service. Input is vectors. Output is list of +embeddings. +""" + +from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse +from .. schema import Error, Value + +from . flow_processor import FlowProcessor +from . consumer_spec import ConsumerSpec +from . 
producer_spec import ProducerSpec + +default_ident = "ge-query" + +class DocumentEmbeddingsQueryService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(DocumentEmbeddingsQueryService, self).__init__( + **params | { "id": id } + ) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = DocumentEmbeddingsRequest, + handler = self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = DocumentEmbeddingsResponse, + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + print(f"Handling input {id}...", flush=True) + + docs = await self.query_document_embeddings(request) + + print("Send response...", flush=True) + r = DocumentEmbeddingsResponse(documents=docs, error=None) + await flow("response").send(r, properties={"id": id}) + + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = DocumentEmbeddingsResponse( + error=Error( + type = "document-embeddings-query-error", + message = str(e), + ), + response=None, + ) + + await flow("response").send(r, properties={"id": id}) + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + +def run(): + + Processor.launch(default_ident, __doc__) + diff --git a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py new file mode 100644 index 00000000..fbf58869 --- /dev/null +++ b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py @@ -0,0 +1,50 @@ + +""" +Document embeddings store base class +""" + +from .. schema import DocumentEmbeddings +from .. base import FlowProcessor, ConsumerSpec +from .. 
exceptions import TooManyRequests + +default_ident = "document-embeddings-write" + +class DocumentEmbeddingsStoreService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(DocumentEmbeddingsStoreService, self).__init__( + **params | { "id": id } + ) + + self.register_specification( + ConsumerSpec( + name = "input", + schema = DocumentEmbeddings, + handler = self.on_message + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + + await self.store_document_embeddings(request) + + except TooManyRequests as e: + raise e + + except Exception as e: + + print(f"Exception: {e}") + raise e + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + diff --git a/trustgraph-base/trustgraph/base/embeddings_client.py b/trustgraph-base/trustgraph/base/embeddings_client.py new file mode 100644 index 00000000..ceb08eb2 --- /dev/null +++ b/trustgraph-base/trustgraph/base/embeddings_client.py @@ -0,0 +1,31 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. 
schema import EmbeddingsRequest, EmbeddingsResponse + +class EmbeddingsClient(RequestResponse): + async def embed(self, text, timeout=30): + + resp = await self.request( + EmbeddingsRequest( + text = text + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.vectors + +class EmbeddingsClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(EmbeddingsClientSpec, self).__init__( + request_name = request_name, + request_schema = EmbeddingsRequest, + response_name = response_name, + response_schema = EmbeddingsResponse, + impl = EmbeddingsClient, + ) + diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py new file mode 100644 index 00000000..c6befdb7 --- /dev/null +++ b/trustgraph-base/trustgraph/base/embeddings_service.py @@ -0,0 +1,90 @@ + +""" +Embeddings resolution base class +""" + +import time +from prometheus_client import Histogram + +from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error +from .. exceptions import TooManyRequests +from .. 
base import FlowProcessor, ConsumerSpec, ProducerSpec + +default_ident = "embeddings" + +class EmbeddingsService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(EmbeddingsService, self).__init__(**params | { "id": id }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = EmbeddingsRequest, + handler = self.on_request + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = EmbeddingsResponse + ) + ) + + async def on_request(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print("Handling request", id, "...", flush=True) + + vectors = await self.on_embeddings(request.text) + + await flow("response").send( + EmbeddingsResponse( + error = None, + vectors = vectors, + ), + properties={"id": id} + ) + + print("Handled.", flush=True) + + except TooManyRequests as e: + raise e + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + + print(f"Exception: {e}", flush=True) + + print("Send error response...", flush=True) + + await flow.producer["response"].send( + EmbeddingsResponse( + error=Error( + type = "embeddings-error", + message = str(e), + ), + vectors=None, + ), + properties={"id": id} + ) + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + + + diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py new file mode 100644 index 00000000..9cda34a0 --- /dev/null +++ b/trustgraph-base/trustgraph/base/flow.py @@ -0,0 +1,32 @@ + +import asyncio + +class Flow: + def __init__(self, id, flow, processor, defn): + + self.id = id + self.name = flow + + self.producer = {} + + # Consumers and publishers. Is this a bit untidy? 
+ self.consumer = {} + + self.setting = {} + + for spec in processor.specifications: + spec.add(self, processor, defn) + + async def start(self): + for c in self.consumer.values(): + await c.start() + + async def stop(self): + for c in self.consumer.values(): + await c.stop() + + def __call__(self, key): + if key in self.producer: return self.producer[key] + if key in self.consumer: return self.consumer[key] + if key in self.setting: return self.setting[key].value + return None diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py new file mode 100644 index 00000000..e6460fe3 --- /dev/null +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -0,0 +1,115 @@ + +# Base class for processor with management of flows in & out which are managed +# by configuration. This is probably all processor types, except for the +# configuration service which can't manage itself. + +import json + +from pulsar.schema import JsonSchema + +from .. schema import Error +from .. schema import config_request_queue, config_response_queue +from .. schema import config_push_queue +from .. log_level import LogLevel +from . async_processor import AsyncProcessor +from . 
flow import Flow + +# Parent class for configurable processors, configured with flows by +# the config service +class FlowProcessor(AsyncProcessor): + + def __init__(self, **params): + + # Initialise base class + super(FlowProcessor, self).__init__(**params) + + # Register configuration handler + self.register_config_handler(self.on_configure_flows) + + # Initialise flow information state + self.flows = {} + + # These can be overriden by a derived class: + + # Array of specifications: ConsumerSpec, ProducerSpec, SettingSpec + self.specifications = [] + + print("Service initialised.") + + # Register a configuration variable + def register_specification(self, spec): + self.specifications.append(spec) + + # Start processing for a new flow + async def start_flow(self, flow, defn): + self.flows[flow] = Flow(self.id, flow, self, defn) + await self.flows[flow].start() + print("Started flow: ", flow) + + # Stop processing for a new flow + async def stop_flow(self, flow): + if flow in self.flows: + await self.flows[flow].stop() + del self.flows[flow] + print("Stopped flow: ", flow, flush=True) + + # Event handler - called for a configuration change + async def on_configure_flows(self, config, version): + + print("Got config version", version, flush=True) + + # Skip over invalid data + if "flows-active" not in config: return + + # Check there's configuration information for me + if self.id in config["flows-active"]: + + # Get my flow config + flow_config = json.loads(config["flows-active"][self.id]) + + else: + + print("No configuration settings for me.", flush=True) + flow_config = {} + + # Get list of flows which should be running and are currently + # running + wanted_flows = flow_config.keys() + current_flows = self.flows.keys() + + # Start all the flows which arent currently running + for flow in wanted_flows: + if flow not in current_flows: + await self.start_flow(flow, flow_config[flow]) + + # Stop all the unwanted flows which are due to be stopped + for flow in 
current_flows: + if flow not in wanted_flows: + await self.stop_flow(flow) + + print("Handled config update") + + # Start threads, just call parent + async def start(self): + await super(FlowProcessor, self).start() + + @staticmethod + def add_args(parser): + + AsyncProcessor.add_args(parser) + + # parser.add_argument( + # '--rate-limit-retry', + # type=int, + # default=default_rate_limit_retry, + # help=f'Rate limit retry (default: {default_rate_limit_retry})' + # ) + + # parser.add_argument( + # '--rate-limit-timeout', + # type=int, + # default=default_rate_limit_timeout, + # help=f'Rate limit timeout (default: {default_rate_limit_timeout})' + # ) + + diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py new file mode 100644 index 00000000..e89364f2 --- /dev/null +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -0,0 +1,45 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .. 
knowledge import Uri, Literal + +def to_value(x): + if x.is_uri: return Uri(x.value) + return Literal(x.value) + +class GraphEmbeddingsClient(RequestResponse): + async def query(self, vectors, limit=20, user="trustgraph", + collection="default", timeout=30): + + resp = await self.request( + GraphEmbeddingsRequest( + vectors = vectors, + limit = limit, + user = user, + collection = collection + ), + timeout=timeout + ) + + print(resp, flush=True) + + if resp.error: + raise RuntimeError(resp.error.message) + + return [ + to_value(v) + for v in resp.entities + ] + +class GraphEmbeddingsClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(GraphEmbeddingsClientSpec, self).__init__( + request_name = request_name, + request_schema = GraphEmbeddingsRequest, + response_name = response_name, + response_schema = GraphEmbeddingsResponse, + impl = GraphEmbeddingsClient, + ) + diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py new file mode 100644 index 00000000..fb2e8dc5 --- /dev/null +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -0,0 +1,84 @@ + +""" +Graph embeddings query service. Input is vectors. Output is list of +embeddings. +""" + +from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .. schema import Error, Value + +from . flow_processor import FlowProcessor +from . consumer_spec import ConsumerSpec +from . 
producer_spec import ProducerSpec + +default_ident = "ge-query" + +class GraphEmbeddingsQueryService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(GraphEmbeddingsQueryService, self).__init__( + **params | { "id": id } + ) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = GraphEmbeddingsRequest, + handler = self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = GraphEmbeddingsResponse, + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + print(f"Handling input {id}...", flush=True) + + entities = await self.query_graph_embeddings(request) + + print("Send response...", flush=True) + r = GraphEmbeddingsResponse(entities=entities, error=None) + await flow("response").send(r, properties={"id": id}) + + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = GraphEmbeddingsResponse( + error=Error( + type = "graph-embeddings-query-error", + message = str(e), + ), + response=None, + ) + + await flow("response").send(r, properties={"id": id}) + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + +def run(): + + Processor.launch(default_ident, __doc__) + diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py new file mode 100644 index 00000000..911b90c1 --- /dev/null +++ b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py @@ -0,0 +1,50 @@ + +""" +Graph embeddings store base class +""" + +from .. schema import GraphEmbeddings +from .. base import FlowProcessor, ConsumerSpec +from .. 
exceptions import TooManyRequests + +default_ident = "graph-embeddings-write" + +class GraphEmbeddingsStoreService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(GraphEmbeddingsStoreService, self).__init__( + **params | { "id": id } + ) + + self.register_specification( + ConsumerSpec( + name = "input", + schema = GraphEmbeddings, + handler = self.on_message + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + + await self.store_graph_embeddings(request) + + except TooManyRequests as e: + raise e + + except Exception as e: + + print(f"Exception: {e}") + raise e + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py new file mode 100644 index 00000000..c4f3f7ab --- /dev/null +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -0,0 +1,33 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. 
schema import GraphRagQuery, GraphRagResponse + +class GraphRagClient(RequestResponse): + async def rag(self, query, user="trustgraph", collection="default", + timeout=600): + resp = await self.request( + GraphRagQuery( + query = query, + user = user, + collection = collection, + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.response + +class GraphRagClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(GraphRagClientSpec, self).__init__( + request_name = request_name, + request_schema = GraphRagQuery, + response_name = response_name, + response_schema = GraphRagResponse, + impl = GraphRagClient, + ) + diff --git a/trustgraph-base/trustgraph/base/llm_service.py b/trustgraph-base/trustgraph/base/llm_service.py new file mode 100644 index 00000000..39323db7 --- /dev/null +++ b/trustgraph-base/trustgraph/base/llm_service.py @@ -0,0 +1,114 @@ + +""" +LLM text completion base class +""" + +import time +from prometheus_client import Histogram + +from .. schema import TextCompletionRequest, TextCompletionResponse, Error +from .. exceptions import TooManyRequests +from .. 
base import FlowProcessor, ConsumerSpec, ProducerSpec + +default_ident = "text-completion" + +class LlmResult: + __slots__ = ["text", "in_token", "out_token", "model"] + +class LlmService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(LlmService, self).__init__(**params | { "id": id }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = TextCompletionRequest, + handler = self.on_request + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = TextCompletionResponse + ) + ) + + if not hasattr(__class__, "text_completion_metric"): + __class__.text_completion_metric = Histogram( + 'text_completion_duration', + 'Text completion duration (seconds)', + ["id", "flow"], + buckets=[ + 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, + 120.0 + ] + ) + + async def on_request(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + with __class__.text_completion_metric.labels( + id=self.id, + flow=f"{flow.name}-{consumer.name}", + ).time(): + + response = await self.generate_content( + request.system, request.prompt + ) + + await flow("response").send( + TextCompletionResponse( + error=None, + response=response.text, + in_token=response.in_token, + out_token=response.out_token, + model=response.model + ), + properties={"id": id} + ) + + except TooManyRequests as e: + raise e + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + await flow.producer["response"].send( + TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ), + 
properties={"id": id} + ) + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + diff --git a/trustgraph-base/trustgraph/base/metrics.py b/trustgraph-base/trustgraph/base/metrics.py new file mode 100644 index 00000000..5d87849f --- /dev/null +++ b/trustgraph-base/trustgraph/base/metrics.py @@ -0,0 +1,82 @@ + +from prometheus_client import start_http_server, Info, Enum, Histogram +from prometheus_client import Counter + +class ConsumerMetrics: + + def __init__(self, id, flow=None): + + self.id = id + self.flow = flow + + if not hasattr(__class__, "state_metric"): + __class__.state_metric = Enum( + 'consumer_state', 'Consumer state', + ["id", "flow"], + states=['stopped', 'running'] + ) + if not hasattr(__class__, "request_metric"): + __class__.request_metric = Histogram( + 'request_latency', 'Request latency (seconds)', + ["id", "flow"], + ) + if not hasattr(__class__, "processing_metric"): + __class__.processing_metric = Counter( + 'processing_count', 'Processing count', + ["id", "flow", "status"] + ) + if not hasattr(__class__, "rate_limit_metric"): + __class__.rate_limit_metric = Counter( + 'rate_limit_count', 'Rate limit event count', + ["id", "flow"] + ) + + def process(self, status): + __class__.processing_metric.labels( + id=self.id, flow=self.flow, status=status + ).inc() + + def rate_limit(self): + __class__.rate_limit_metric.labels( + id=self.id, flow=self.flow + ).inc() + + def state(self, state): + __class__.state_metric.labels( + id=self.id, flow=self.flow + ).state(state) + + def record_time(self): + return __class__.request_metric.labels( + id=self.id, flow=self.flow + ).time() + +class ProducerMetrics: + def __init__(self, id, flow=None): + + self.id = id + self.flow = flow + + if not hasattr(__class__, "output_metric"): + __class__.output_metric = Counter( + 'output_count', 'Output items created', + ["id", "flow"] + ) + + def inc(self): + __class__.output_metric.labels(id=self.id, flow=self.flow).inc() + +class 
ProcessorMetrics: + def __init__(self, id): + + self.id = id + + if not hasattr(__class__, "processor_metric"): + __class__.processor_metric = Info( + 'processor', 'Processor configuration', + ["id"] + ) + + def info(self, info): + __class__.processor_metric.labels(id=self.id).info(info) + diff --git a/trustgraph-base/trustgraph/base/producer.py b/trustgraph-base/trustgraph/base/producer.py index bc2d7791..bb665924 100644 --- a/trustgraph-base/trustgraph/base/producer.py +++ b/trustgraph-base/trustgraph/base/producer.py @@ -1,56 +1,69 @@ from pulsar.schema import JsonSchema -from prometheus_client import Info, Counter +import asyncio -from . base_processor import BaseProcessor +class Producer: -class Producer(BaseProcessor): + def __init__(self, client, topic, schema, metrics=None): + self.client = client + self.topic = topic + self.schema = schema - def __init__(self, **params): + self.metrics = metrics - output_queue = params.get("output_queue") - output_schema = params.get("output_schema") + self.running = True + self.producer = None - if not hasattr(__class__, "output_metric"): - __class__.output_metric = Counter( - 'output_count', 'Output items created' - ) + def __del__(self): - if not hasattr(__class__, "pubsub_metric"): - __class__.pubsub_metric = Info( - 'pubsub', 'Pub/sub configuration' - ) + self.running = False - __class__.pubsub_metric.info({ - "output_queue": output_queue, - "output_schema": output_schema.__name__, - }) + if hasattr(self, "producer"): + if self.producer: + self.producer.close() - super(Producer, self).__init__(**params) + async def start(self): + self.running = True - if output_schema == None: - raise RuntimeError("output_schema must be specified") - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(output_schema), - chunking_enabled=True, - ) + async def stop(self): + self.running = False async def send(self, msg, properties={}): - self.producer.send(msg, properties) - 
__class__.output_metric.inc() - @staticmethod - def add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ): + if not self.running: return - BaseProcessor.add_args(parser) + while self.running and self.producer is None: + + try: + print("Connect publisher to", self.topic, "...", flush=True) + self.producer = self.client.create_producer( + topic = self.topic, + schema = JsonSchema(self.schema) + ) + print("Connected to", self.topic, flush=True) + except Exception as e: + print("Exception:", e, flush=True) + await asyncio.sleep(2) + + if not self.running: break + + while self.running: + + try: + + await asyncio.to_thread( + self.producer.send, + msg, properties + ) + + if self.metrics: + self.metrics.inc() + + # Delivery success, break out of loop + break + + except Exception as e: + print("Exception:", e, flush=True) + self.producer.close() + self.producer = None - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) diff --git a/trustgraph-base/trustgraph/base/producer_spec.py b/trustgraph-base/trustgraph/base/producer_spec.py new file mode 100644 index 00000000..9007f48b --- /dev/null +++ b/trustgraph-base/trustgraph/base/producer_spec.py @@ -0,0 +1,25 @@ + +from . producer import Producer +from . metrics import ProducerMetrics +from . 
spec import Spec + +class ProducerSpec(Spec): + def __init__(self, name, schema): + self.name = name + self.schema = schema + + def add(self, flow, processor, definition): + + producer_metrics = ProducerMetrics( + flow.id, f"{flow.name}-{self.name}" + ) + + producer = Producer( + client = processor.client, + topic = definition[self.name], + schema = self.schema, + metrics = producer_metrics, + ) + + flow.producer[self.name] = producer + diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py new file mode 100644 index 00000000..9e8ab033 --- /dev/null +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -0,0 +1,93 @@ + +import json + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import PromptRequest, PromptResponse + +class PromptClient(RequestResponse): + + async def prompt(self, id, variables, timeout=600): + + resp = await self.request( + PromptRequest( + id = id, + terms = { + k: json.dumps(v) + for k, v in variables.items() + } + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + if resp.text: return resp.text + + return json.loads(resp.object) + + async def extract_definitions(self, text, timeout=600): + return await self.prompt( + id = "extract-definitions", + variables = { "text": text }, + timeout = timeout, + ) + + async def extract_relationships(self, text, timeout=600): + return await self.prompt( + id = "extract-relationships", + variables = { "text": text }, + timeout = timeout, + ) + + async def kg_prompt(self, query, kg, timeout=600): + return await self.prompt( + id = "kg-prompt", + variables = { + "query": query, + "knowledge": [ + { "s": v[0], "p": v[1], "o": v[2] } + for v in kg + ] + }, + timeout = timeout, + ) + + async def document_prompt(self, query, documents, timeout=600): + return await self.prompt( + id = "document-prompt", + variables = { + "query": query, + "documents": documents, + }, + timeout 
= timeout, + ) + + async def agent_react(self, variables, timeout=600): + return await self.prompt( + id = "agent-react", + variables = variables, + timeout = timeout, + ) + + async def question(self, question, timeout=600): + return await self.prompt( + id = "question", + variables = { + "question": question, + }, + timeout = timeout, + ) + +class PromptClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(PromptClientSpec, self).__init__( + request_name = request_name, + request_schema = PromptRequest, + response_name = response_name, + response_schema = PromptResponse, + impl = PromptClient, + ) + diff --git a/trustgraph-base/trustgraph/base/publisher.py b/trustgraph-base/trustgraph/base/publisher.py index 2da63331..ce9e364e 100644 --- a/trustgraph-base/trustgraph/base/publisher.py +++ b/trustgraph-base/trustgraph/base/publisher.py @@ -1,47 +1,52 @@ -import queue +from pulsar.schema import JsonSchema + +import asyncio import time import pulsar -import threading class Publisher: - def __init__(self, pulsar_client, topic, schema=None, max_size=10, + def __init__(self, client, topic, schema=None, max_size=10, chunking_enabled=True): - self.client = pulsar_client + self.client = client self.topic = topic self.schema = schema - self.q = queue.Queue(maxsize=max_size) + self.q = asyncio.Queue(maxsize=max_size) self.chunking_enabled = chunking_enabled self.running = True - def start(self): - self.task = threading.Thread(target=self.run) - self.task.start() + async def start(self): + self.task = asyncio.create_task(self.run()) - def stop(self): + async def stop(self): self.running = False - def join(self): - self.stop() - self.task.join() + async def join(self): + await self.stop() + await self.task - def run(self): + async def run(self): while self.running: try: producer = self.client.create_producer( topic=self.topic, - schema=self.schema, + schema=JsonSchema(self.schema), chunking_enabled=self.chunking_enabled, ) while 
self.running: try: - id, item = self.q.get(timeout=0.5) - except queue.Empty: + id, item = await asyncio.wait_for( + self.q.get(), + timeout=0.5 + ) + except asyncio.TimeoutError: + continue + except asyncio.QueueEmpty: continue if id: @@ -55,7 +60,6 @@ class Publisher: # If handler drops out, sleep a retry time.sleep(2) - def send(self, id, msg): - self.q.put((id, msg)) + async def send(self, id, item): + await self.q.put((id, item)) - diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py new file mode 100644 index 00000000..b9f233d4 --- /dev/null +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -0,0 +1,80 @@ + +import os +import pulsar +import uuid +from pulsar.schema import JsonSchema + +from .. log_level import LogLevel + +class PulsarClient: + + default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') + default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) + + def __init__(self, **params): + + self.client = None + + pulsar_host = params.get("pulsar_host", self.default_pulsar_host) + pulsar_listener = params.get("pulsar_listener", None) + pulsar_api_key = params.get( + "pulsar_api_key", + self.default_pulsar_api_key + ) + log_level = params.get("log_level", LogLevel.INFO) + + self.pulsar_host = pulsar_host + self.pulsar_api_key = pulsar_api_key + + if pulsar_api_key: + auth = pulsar.AuthenticationToken(pulsar_api_key) + self.client = pulsar.Client( + pulsar_host, + authentication=auth, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) + else: + self.client = pulsar.Client( + pulsar_host, + listener_name=pulsar_listener, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) + + self.pulsar_listener = pulsar_listener + + def close(self): + self.client.close() + + def __del__(self): + + if hasattr(self, "client"): + if self.client: + self.client.close() + + @staticmethod + def add_args(parser): + + parser.add_argument( + '-p', '--pulsar-host', + default=__class__.default_pulsar_host, + 
help=f'Pulsar host (default: {__class__.default_pulsar_host})', + ) + + parser.add_argument( + '--pulsar-api-key', + default=__class__.default_pulsar_api_key, + help=f'Pulsar API key', + ) + + parser.add_argument( + '--pulsar-listener', + help=f'Pulsar listener (default: none)', + ) + + parser.add_argument( + '-l', '--log-level', + type=LogLevel, + default=LogLevel.INFO, + choices=list(LogLevel), + help=f'Output queue (default: info)' + ) diff --git a/trustgraph-base/trustgraph/base/request_response_spec.py b/trustgraph-base/trustgraph/base/request_response_spec.py new file mode 100644 index 00000000..dcfcbf9b --- /dev/null +++ b/trustgraph-base/trustgraph/base/request_response_spec.py @@ -0,0 +1,136 @@ + +import uuid +import asyncio + +from . subscriber import Subscriber +from . producer import Producer +from . spec import Spec +from . metrics import ConsumerMetrics, ProducerMetrics + +class RequestResponse(Subscriber): + + def __init__( + self, client, subscription, consumer_name, + request_topic, request_schema, + request_metrics, + response_topic, response_schema, + response_metrics, + ): + + super(RequestResponse, self).__init__( + client = client, + subscription = subscription, + consumer_name = consumer_name, + topic = response_topic, + schema = response_schema, + ) + + self.producer = Producer( + client = client, + topic = request_topic, + schema = request_schema, + metrics = request_metrics, + ) + + async def start(self): + await self.producer.start() + await super(RequestResponse, self).start() + + async def stop(self): + await self.producer.stop() + await super(RequestResponse, self).stop() + + async def request(self, req, timeout=300, recipient=None): + + id = str(uuid.uuid4()) + + print("Request", id, "...", flush=True) + + q = await self.subscribe(id) + + try: + + await self.producer.send( + req, + properties={"id": id} + ) + + except Exception as e: + + print("Exception:", e) + raise e + + + try: + + while True: + + resp = await asyncio.wait_for( + 
q.get(), + timeout=timeout + ) + + print("Got response.", flush=True) + + if recipient is None: + + # If no recipient handler, just return the first + # response we get + return resp + else: + + # Recipient handler gets to decide when we're done b + # returning a boolean + fin = await recipient(resp) + + # If done, return the last result otherwise loop round for + # next response + if fin: + return resp + else: + continue + + except Exception as e: + + print("Exception:", e) + raise e + + finally: + + await self.unsubscribe(id) + +# This deals with the request/response case. The caller needs to +# use another service in request/response mode. Uses two topics: +# - we send on the request topic as a producer +# - we receive on the response topic as a subscriber +class RequestResponseSpec(Spec): + def __init__( + self, request_name, request_schema, response_name, + response_schema, impl=RequestResponse + ): + self.request_name = request_name + self.request_schema = request_schema + self.response_name = response_name + self.response_schema = response_schema + self.impl = impl + + def add(self, flow, processor, definition): + + producer_metrics = ProducerMetrics( + flow.id, f"{flow.name}-{self.response_name}" + ) + + rr = self.impl( + client = processor.client, + subscription = flow.id, + consumer_name = flow.id, + request_topic = definition[self.request_name], + request_schema = self.request_schema, + request_metrics = producer_metrics, + response_topic = definition[self.response_name], + response_schema = self.response_schema, + response_metrics = None, + ) + + flow.consumer[self.request_name] = rr + diff --git a/trustgraph-base/trustgraph/base/setting_spec.py b/trustgraph-base/trustgraph/base/setting_spec.py new file mode 100644 index 00000000..5c5152b2 --- /dev/null +++ b/trustgraph-base/trustgraph/base/setting_spec.py @@ -0,0 +1,19 @@ + +from . 
spec import Spec + +class Setting: + def __init__(self, value): + self.value = value + async def start(self): + pass + async def stop(self): + pass + +class SettingSpec(Spec): + def __init__(self, name): + self.name = name + + def add(self, flow, processor, definition): + + flow.config[self.name] = Setting(definition[self.name]) + diff --git a/trustgraph-base/trustgraph/base/spec.py b/trustgraph-base/trustgraph/base/spec.py new file mode 100644 index 00000000..4d0d937b --- /dev/null +++ b/trustgraph-base/trustgraph/base/spec.py @@ -0,0 +1,4 @@ + +class Spec: + pass + diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 30ade3ee..a8ff58f7 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -1,14 +1,14 @@ -import queue -import pulsar -import threading +from pulsar.schema import JsonSchema +import asyncio +import _pulsar import time class Subscriber: - def __init__(self, pulsar_client, topic, subscription, consumer_name, + def __init__(self, client, topic, subscription, consumer_name, schema=None, max_size=100): - self.client = pulsar_client + self.client = client self.topic = topic self.subscription = subscription self.consumer_name = consumer_name @@ -16,35 +16,50 @@ class Subscriber: self.q = {} self.full = {} self.max_size = max_size - self.lock = threading.Lock() + self.lock = asyncio.Lock() self.running = True - def start(self): - self.task = threading.Thread(target=self.run) - self.task.start() - - def stop(self): + async def __del__(self): self.running = False - def join(self): - self.task.join() + async def start(self): + self.task = asyncio.create_task(self.run()) - def run(self): + async def stop(self): + self.running = False + + async def join(self): + await self.stop() + await self.task + + async def run(self): while self.running: try: consumer = self.client.subscribe( - topic=self.topic, - subscription_name=self.subscription, - 
consumer_name=self.consumer_name, - schema=self.schema, + topic = self.topic, + subscription_name = self.subscription, + consumer_name = self.consumer_name, + schema = JsonSchema(self.schema), ) + print("Subscriber running...", flush=True) + while self.running: - msg = consumer.receive() + try: + msg = await asyncio.to_thread( + consumer.receive, + timeout_millis=2000 + ) + except _pulsar.Timeout: + continue + except Exception as e: + print("Exception:", e, flush=True) + print(type(e)) + raise e # Acknowledge successful reception of the message consumer.acknowledge(msg) @@ -56,57 +71,68 @@ class Subscriber: value = msg.value() - with self.lock: + async with self.lock: + + # FIXME: Hard-coded timeouts if id in self.q: + try: # FIXME: Timeout means data goes missing - self.q[id].put(value, timeout=0.5) - except: - pass + await asyncio.wait_for( + self.q[id].put(value), + timeout=2 + ) + except Exception as e: + print("Q Put:", e, flush=True) for q in self.full.values(): try: # FIXME: Timeout means data goes missing - q.put(value, timeout=0.5) - except: - pass + await asyncio.wait_for( + q.put(value), + timeout=2 + ) + except Exception as e: + print("Q Put:", e, flush=True) except Exception as e: - print("Exception:", e, flush=True) + print("Subscriber exception:", e, flush=True) + + consumer.close() # If handler drops out, sleep a retry time.sleep(2) - def subscribe(self, id): + async def subscribe(self, id): - with self.lock: + async with self.lock: - q = queue.Queue(maxsize=self.max_size) + q = asyncio.Queue(maxsize=self.max_size) self.q[id] = q return q - def unsubscribe(self, id): + async def unsubscribe(self, id): - with self.lock: + async with self.lock: if id in self.q: # self.q[id].shutdown(immediate=True) del self.q[id] - def subscribe_all(self, id): + async def subscribe_all(self, id): - with self.lock: + async with self.lock: - q = queue.Queue(maxsize=self.max_size) + q = asyncio.Queue(maxsize=self.max_size) self.full[id] = q return q - def 
unsubscribe_all(self, id): + async def unsubscribe_all(self, id): - with self.lock: + async with self.lock: if id in self.full: # self.full[id].shutdown(immediate=True) diff --git a/trustgraph-base/trustgraph/base/subscriber_spec.py b/trustgraph-base/trustgraph/base/subscriber_spec.py new file mode 100644 index 00000000..2f89290b --- /dev/null +++ b/trustgraph-base/trustgraph/base/subscriber_spec.py @@ -0,0 +1,30 @@ + +from . metrics import ConsumerMetrics +from . subscriber import Subscriber +from . spec import Spec + +class SubscriberSpec(Spec): + + def __init__(self, name, schema): + self.name = name + self.schema = schema + + def add(self, flow, processor, definition): + + # FIXME: Metrics not used + subscriber_metrics = ConsumerMetrics( + flow.id, f"{flow.name}-{self.name}" + ) + + subscriber = Subscriber( + client = processor.client, + topic = definition[self.name], + subscription = flow.id, + consumer_name = flow.id, + schema = self.schema, + ) + + # Put it in the consumer map, does that work? + # It means it gets start/stop call. + flow.consumer[self.name] = subscriber + diff --git a/trustgraph-base/trustgraph/base/text_completion_client.py b/trustgraph-base/trustgraph/base/text_completion_client.py new file mode 100644 index 00000000..aba2fada --- /dev/null +++ b/trustgraph-base/trustgraph/base/text_completion_client.py @@ -0,0 +1,30 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. 
schema import TextCompletionRequest, TextCompletionResponse + +class TextCompletionClient(RequestResponse): + async def text_completion(self, system, prompt, timeout=600): + resp = await self.request( + TextCompletionRequest( + system = system, prompt = prompt + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.response + +class TextCompletionClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(TextCompletionClientSpec, self).__init__( + request_name = request_name, + request_schema = TextCompletionRequest, + response_name = response_name, + response_schema = TextCompletionResponse, + impl = TextCompletionClient, + ) + diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py new file mode 100644 index 00000000..c9f747b5 --- /dev/null +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -0,0 +1,61 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value +from .. 
knowledge import Uri, Literal + +class Triple: + def __init__(self, s, p, o): + self.s = s + self.p = p + self.o = o + +def to_value(x): + if x.is_uri: return Uri(x.value) + return Literal(x.value) + +def from_value(x): + if x is None: return None + if isinstance(x, Uri): + return Value(value=str(x), is_uri=True) + else: + return Value(value=str(x), is_uri=False) + +class TriplesClient(RequestResponse): + async def query(self, s=None, p=None, o=None, limit=20, + user="trustgraph", collection="default", + timeout=30): + + resp = await self.request( + TriplesQueryRequest( + s = from_value(s), + p = from_value(p), + o = from_value(o), + limit = limit, + user = user, + collection = collection, + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + triples = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + + return triples + +class TriplesClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(TriplesClientSpec, self).__init__( + request_name = request_name, + request_schema = TriplesQueryRequest, + response_name = response_name, + response_schema = TriplesQueryResponse, + impl = TriplesClient, + ) + diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py new file mode 100644 index 00000000..37acc622 --- /dev/null +++ b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -0,0 +1,82 @@ + +""" +Triples query service. Input is a (s, p, o) triple, some values may be +null. Output is a list of triples. +""" + +from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error +from .. schema import Value, Triple + +from . flow_processor import FlowProcessor +from . consumer_spec import ConsumerSpec +from . 
producer_spec import ProducerSpec + +default_ident = "triples-query" + +class TriplesQueryService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(TriplesQueryService, self).__init__(**params | { "id": id }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = TriplesQueryRequest, + handler = self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = TriplesQueryResponse, + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + print(f"Handling input {id}...", flush=True) + + triples = await self.query_triples(request) + + print("Send response...", flush=True) + r = TriplesQueryResponse(triples=triples, error=None) + await flow("response").send(r, properties={"id": id}) + + print("Done.", flush=True) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TriplesQueryResponse( + error = Error( + type = "triples-query-error", + message = str(e), + ), + triples = None, + ) + + await flow("response").send(r, properties={"id": id}) + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + +def run(): + + TriplesQueryService.launch(default_ident, __doc__) + diff --git a/trustgraph-base/trustgraph/base/triples_store_service.py b/trustgraph-base/trustgraph/base/triples_store_service.py new file mode 100644 index 00000000..74f95f57 --- /dev/null +++ b/trustgraph-base/trustgraph/base/triples_store_service.py @@ -0,0 +1,47 @@ + +""" +Triples store base class +""" + +from .. schema import Triples +from .. 
base import FlowProcessor, ConsumerSpec + +default_ident = "triples-write" + +class TriplesStoreService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + + super(TriplesStoreService, self).__init__(**params | { "id": id }) + + self.register_specification( + ConsumerSpec( + name = "input", + schema = Triples, + handler = self.on_message + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + + await self.store_triples(request) + + except TooManyRequests as e: + raise e + + except Exception as e: + + print(f"Exception: {e}") + raise e + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + diff --git a/trustgraph-base/trustgraph/schema/agent.py b/trustgraph-base/trustgraph/schema/agent.py index 9bcdde51..ee20a9aa 100644 --- a/trustgraph-base/trustgraph/schema/agent.py +++ b/trustgraph-base/trustgraph/schema/agent.py @@ -26,12 +26,5 @@ class AgentResponse(Record): thought = String() observation = String() -agent_request_queue = topic( - 'agent', kind='non-persistent', namespace='request' -) -agent_response_queue = topic( - 'agent', kind='non-persistent', namespace='response' -) - ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/config.py b/trustgraph-base/trustgraph/schema/config.py index efe49182..3be63aa3 100644 --- a/trustgraph-base/trustgraph/schema/config.py +++ b/trustgraph-base/trustgraph/schema/config.py @@ -2,7 +2,7 @@ from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer from . topic import topic -from . types import Error, RowSchema +from . 
types import Error ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py index fd0049ee..e479371d 100644 --- a/trustgraph-base/trustgraph/schema/documents.py +++ b/trustgraph-base/trustgraph/schema/documents.py @@ -11,8 +11,6 @@ class Document(Record): metadata = Metadata() data = Bytes() -document_ingest_queue = topic('document-load') - ############################################################################ # Text documents / text from PDF @@ -21,8 +19,6 @@ class TextDocument(Record): metadata = Metadata() text = Bytes() -text_ingest_queue = topic('text-document-load') - ############################################################################ # Chunks of text @@ -31,8 +27,6 @@ class Chunk(Record): metadata = Metadata() chunk = Bytes() -chunk_ingest_queue = topic('chunk-load') - ############################################################################ # Document embeddings are embeddings associated with a chunk @@ -46,8 +40,6 @@ class DocumentEmbeddings(Record): metadata = Metadata() chunks = Array(ChunkEmbeddings()) -document_embeddings_store_queue = topic('document-embeddings-store') - ############################################################################ # Doc embeddings query @@ -62,10 +54,3 @@ class DocumentEmbeddingsResponse(Record): error = Error() documents = Array(Bytes()) -document_embeddings_request_queue = topic( - 'doc-embeddings', kind='non-persistent', namespace='request' -) -document_embeddings_response_queue = topic( - 'doc-embeddings', kind='non-persistent', namespace='response', -) - diff --git a/trustgraph-base/trustgraph/schema/flows.py b/trustgraph-base/trustgraph/schema/flows.py new file mode 100644 index 00000000..5ac51a37 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/flows.py @@ -0,0 +1,66 @@ + +from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer + +from . 
topic import topic +from . types import Error + +############################################################################ + +# Flow service: +# list_classes() -> (classname[]) +# get_class(classname) -> (class) +# put_class(class) -> (class) +# delete_class(classname) -> () +# +# list_flows() -> (flowid[]) +# get_flow(flowid) -> (flow) +# start_flow(flowid, classname) -> () +# stop_flow(flowid) -> () + +# Flow service, manages flow classes and running flows +class FlowRequest(Record): + + operation = String() # list_classes, get_class, put_class, delete_class + # list_flows, get_flow, start_flow, stop_flow + + # get_class, put_class, delete_class, start_flow + class_name = String() + + # put_class + class_definition = String() + + # start_flow + description = String() + + # get_flow, start_flow, stop_flow + flow_id = String() + +class FlowResponse(Record): + + # list_classes + class_names = Array(String()) + + # list_flows + flow_ids = Array(String()) + + # get_class + class_definition = String() + + # get_flow + flow = String() + + # get_flow + description = String() + + # Everything + error = Error() + +flow_request_queue = topic( + 'flow', kind='non-persistent', namespace='request' +) +flow_response_queue = topic( + 'flow', kind='non-persistent', namespace='response' +) + +############################################################################ + diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py index 9b8818a2..572e01b7 100755 --- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py +++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py @@ -17,7 +17,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... 
exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py index 2e44821e..0ab3cef9 100755 --- a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py +++ b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py @@ -4,89 +4,37 @@ Embeddings service, applies an embeddings model selected from HuggingFace. Input is text, output is embeddings vector. """ +from ... base import EmbeddingsService + from langchain_huggingface import HuggingFaceEmbeddings -from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error -from trustgraph.schema import embeddings_request_queue -from trustgraph.schema import embeddings_response_queue -from trustgraph.log_level import LogLevel -from trustgraph.base import ConsumerProducer +default_ident = "embeddings" -module = ".".join(__name__.split(".")[1:-1]) - -default_input_queue = embeddings_request_queue -default_output_queue = embeddings_response_queue -default_subscriber = module default_model="all-MiniLM-L6-v2" -class Processor(ConsumerProducer): +class Processor(EmbeddingsService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) model = params.get("model", default_model) super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": EmbeddingsRequest, - "output_schema": EmbeddingsResponse, - } + **params | { "model": model } ) + print("Get model...", flush=True) self.embeddings = HuggingFaceEmbeddings(model_name=model) - async def handle(self, msg): + async def on_embeddings(self, 
text): - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - - try: - - text = v.text - embeds = self.embeddings.embed_documents([text]) - - print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=embeds, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) - - - except Exception as e: - - print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = EmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) - + embeds = self.embeddings.embed_documents([text]) + print("Done.", flush=True) + return embeds @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + EmbeddingsService.add_args(parser) parser.add_argument( '-m', '--model', @@ -96,5 +44,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index a195bd80..d20b86f7 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -8,12 +8,11 @@ logger = logging.getLogger(__name__) class AgentManager: - def __init__(self, context, tools, additional_context=None): - self.context = context + def __init__(self, tools, additional_context=None): self.tools = tools self.additional_context = additional_context - def reason(self, question, history): + async def reason(self, question, history, context): tools = self.tools @@ -56,10 +55,7 @@ class AgentManager: logger.info(f"prompt: {variables}") - obj = self.context.prompt.request( - "agent-react", - variables - ) + obj = await 
context("prompt-request").agent_react(variables) print(json.dumps(obj, indent=4), flush=True) @@ -85,9 +81,13 @@ class AgentManager: return a - async def react(self, question, history, think, observe): + async def react(self, question, history, think, observe, context): - act = self.reason(question, history) + act = await self.reason( + question = question, + history = history, + context = context, + ) logger.info(f"act: {act}") if isinstance(act, Final): @@ -104,7 +104,12 @@ class AgentManager: else: raise RuntimeError(f"No action for {act.name}!") - resp = action.implementation.invoke(**act.arguments) + print("TOOL>>>", act) + resp = await action.implementation(context).invoke( + **act.arguments + ) + + print("RSETUL", resp) resp = resp.strip() diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 224efe3c..beb17fd4 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -6,103 +6,68 @@ import json import re import sys -from pulsar.schema import JsonSchema +from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec +from ... base import GraphRagClientSpec -from ... base import ConsumerProducer -from ... schema import Error -from ... schema import AgentRequest, AgentResponse, AgentStep -from ... schema import agent_request_queue, agent_response_queue -from ... schema import prompt_request_queue as pr_request_queue -from ... schema import prompt_response_queue as pr_response_queue -from ... schema import graph_rag_request_queue as gr_request_queue -from ... schema import graph_rag_response_queue as gr_response_queue -from ... clients.prompt_client import PromptClient -from ... clients.llm_client import LlmClient -from ... clients.graph_rag_client import GraphRagClient +from ... schema import AgentRequest, AgentResponse, AgentStep, Error from . tools import KnowledgeQueryImpl, TextCompletionImpl from . 
agent_manager import AgentManager from . types import Final, Action, Tool, Argument -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "agent-manager" +default_max_iterations = 10 -default_input_queue = agent_request_queue -default_output_queue = agent_response_queue -default_subscriber = module -default_max_iterations = 15 - -class Processor(ConsumerProducer): +class Processor(AgentService): def __init__(self, **params): + id = params.get("id") + self.max_iterations = int( params.get("max_iterations", default_max_iterations) ) - tools = {} - - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - prompt_request_queue = params.get( - "prompt_request_queue", pr_request_queue - ) - prompt_response_queue = params.get( - "prompt_response_queue", pr_response_queue - ) - graph_rag_request_queue = params.get( - "graph_rag_request_queue", gr_request_queue - ) - graph_rag_response_queue = params.get( - "graph_rag_response_queue", gr_response_queue - ) - self.config_key = params.get("config_type", "agent") super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": AgentRequest, - "output_schema": AgentResponse, - "prompt_request_queue": prompt_request_queue, - "prompt_response_queue": prompt_response_queue, - "graph_rag_request_queue": gr_request_queue, - "graph_rag_response_queue": gr_response_queue, + "id": id, + "max_iterations": self.max_iterations, + "config_type": self.config_key, } ) - self.prompt = PromptClient( - subscriber=subscriber, - input_queue=prompt_request_queue, - output_queue=prompt_response_queue, - pulsar_host = self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - ) - - self.graph_rag = GraphRagClient( - subscriber=subscriber, - input_queue=graph_rag_request_queue, - output_queue=graph_rag_response_queue, - 
pulsar_host = self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - ) - - # Need to be able to feed requests to myself - self.recursive_input = self.client.create_producer( - topic=input_queue, - schema=JsonSchema(AgentRequest), - ) - self.agent = AgentManager( - context=self, tools=[], additional_context="", ) - async def on_config(self, version, config): + self.config_handlers.append(self.on_tools_config) + + self.register_specification( + TextCompletionClientSpec( + request_name = "text-completion-request", + response_name = "text-completion-response", + ) + ) + + self.register_specification( + GraphRagClientSpec( + request_name = "graph-rag-request", + response_name = "graph-rag-response", + ) + ) + + self.register_specification( + PromptClientSpec( + request_name = "prompt-request", + response_name = "prompt-response", + ) + ) + + async def on_tools_config(self, config, version): print("Loading configuration version", version) @@ -138,9 +103,9 @@ class Processor(ConsumerProducer): impl_id = data.get("type") if impl_id == "knowledge-query": - impl = KnowledgeQueryImpl(self) + impl = KnowledgeQueryImpl elif impl_id == "text-completion": - impl = TextCompletionImpl(self) + impl = TextCompletionImpl else: raise RuntimeError( f"Tool-kind {impl_id} not known" @@ -155,7 +120,6 @@ class Processor(ConsumerProducer): ) self.agent = AgentManager( - context=self, tools=tools, additional_context=additional ) @@ -164,19 +128,14 @@ class Processor(ConsumerProducer): except Exception as e: - print("Exception:", e, flush=True) + print("on_tools_config Exception:", e, flush=True) print("Configuration reload failed", flush=True) - async def handle(self, msg): + async def agent_request(self, request, respond, next, flow): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - if v.history: + if request.history: history = [ Action( thought=h.thought, @@ -184,12 +143,12 @@ class Processor(ConsumerProducer): arguments=h.arguments, 
observation=h.observation ) - for h in v.history + for h in request.history ] else: history = [] - print(f"Question: {v.question}", flush=True) + print(f"Question: {request.question}", flush=True) if len(history) >= self.max_iterations: raise RuntimeError("Too many agent iterations") @@ -207,7 +166,7 @@ class Processor(ConsumerProducer): observation=None, ) - await self.send(r, properties={"id": id}) + await respond(r) async def observe(x): @@ -220,15 +179,21 @@ class Processor(ConsumerProducer): observation=x, ) - await self.send(r, properties={"id": id}) + await respond(r) - act = await self.agent.react(v.question, history, think, observe) + act = await self.agent.react( + question = request.question, + history = history, + think = think, + observe = observe, + context = flow, + ) print(f"Action: {act}", flush=True) - print("Send response...", flush=True) + if isinstance(act, Final): - if type(act) == Final: + print("Send final response...", flush=True) r = AgentResponse( answer=act.final, @@ -236,18 +201,20 @@ class Processor(ConsumerProducer): thought=None, ) - await self.send(r, properties={"id": id}) + await respond(r) print("Done.", flush=True) return + print("Send next...", flush=True) + history.append(act) r = AgentRequest( - question=v.question, - plan=v.plan, - state=v.state, + question=request.question, + plan=request.plan, + state=request.state, history=[ AgentStep( thought=h.thought, @@ -259,7 +226,7 @@ class Processor(ConsumerProducer): ] ) - self.recursive_input.send(r, properties={"id": id}) + await next(r) print("Done.", flush=True) @@ -267,7 +234,7 @@ class Processor(ConsumerProducer): except Exception as e: - print(f"Exception: {e}") + print(f"agent_request Exception: {e}") print("Send error response...", flush=True) @@ -279,39 +246,12 @@ class Processor(ConsumerProducer): response=None, ) - await self.send(r, properties={"id": id}) + await respond(r) @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, 
default_subscriber, - default_output_queue, - ) - - parser.add_argument( - '--prompt-request-queue', - default=pr_request_queue, - help=f'Prompt request queue (default: {pr_request_queue})', - ) - - parser.add_argument( - '--prompt-response-queue', - default=pr_response_queue, - help=f'Prompt response queue (default: {pr_response_queue})', - ) - - parser.add_argument( - '--graph-rag-request-queue', - default=gr_request_queue, - help=f'Graph RAG request queue (default: {gr_request_queue})', - ) - - parser.add_argument( - '--graph-rag-response-queue', - default=gr_response_queue, - help=f'Graph RAG response queue (default: {gr_response_queue})', - ) + AgentService.add_args(parser) parser.add_argument( '--max-iterations', @@ -327,5 +267,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 023abc02..31568b25 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -4,16 +4,22 @@ class KnowledgeQueryImpl: def __init__(self, context): self.context = context - def invoke(self, **arguments): - return self.context.graph_rag.request(arguments.get("question")) + async def invoke(self, **arguments): + client = self.context("graph-rag-request") + print("Graph RAG question...", flush=True) + return await client.rag( + arguments.get("question") + ) # This tool implementation knows how to do text completion. This uses # the prompt service, rather than talking to TextCompletion directly. 
class TextCompletionImpl: def __init__(self, context): self.context = context - def invoke(self, **arguments): - return self.context.prompt.request( - "question", { "question": arguments.get("question") } + async def invoke(self, **arguments): + client = self.context("prompt-request") + print("Prompt question...", flush=True) + return await client.question( + arguments.get("question") ) diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index 82f333b5..aa48cc57 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -7,40 +7,27 @@ as text as separate output objects. from langchain_text_splitters import RecursiveCharacterTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk, Metadata -from ... schema import text_ingest_queue, chunk_ingest_queue -from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... schema import TextDocument, Chunk +from ... 
base import FlowProcessor, ConsumerSpec, ProducerSpec -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "chunker" -default_input_queue = text_ingest_queue -default_output_queue = chunk_ingest_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id", default_ident) chunk_size = params.get("chunk_size", 2000) chunk_overlap = params.get("chunk_overlap", 100) super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TextDocument, - "output_schema": Chunk, - } + **params | { "id": id } ) if not hasattr(__class__, "chunk_metric"): __class__.chunk_metric = Histogram( 'chunk_size', 'Chunk size', + ["id", "flow"], buckets=[100, 160, 250, 400, 650, 1000, 1600, 2500, 4000, 6400, 10000, 16000] ) @@ -52,7 +39,24 @@ class Processor(ConsumerProducer): is_separator_regex=False, ) - async def handle(self, msg): + self.register_specification( + ConsumerSpec( + name = "input", + schema = TextDocument, + handler = self.on_message, + ) + ) + + self.register_specification( + ProducerSpec( + name = "output", + schema = Chunk, + ) + ) + + print("Chunker initialised", flush=True) + + async def on_message(self, msg, consumer, flow): v = msg.value() print(f"Chunking {v.metadata.id}...", flush=True) @@ -63,24 +67,25 @@ class Processor(ConsumerProducer): for ix, chunk in enumerate(texts): + print("Chunk", len(chunk.page_content), flush=True) + r = Chunk( metadata=v.metadata, chunk=chunk.page_content.encode("utf-8"), ) - __class__.chunk_metric.observe(len(chunk.page_content)) + __class__.chunk_metric.labels( + id=consumer.id, flow=consumer.flow + ).observe(len(chunk.page_content)) - await 
self.send(r) + await flow("output").send(r) print("Done.", flush=True) @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + FlowProcessor.add_args(parser) parser.add_argument( '-z', '--chunk-size', @@ -98,5 +103,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index c625b48c..ff217350 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -7,40 +7,27 @@ as text as separate output objects. from langchain_text_splitters import TokenTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk, Metadata -from ... schema import text_ingest_queue, chunk_ingest_queue -from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... schema import TextDocument, Chunk +from ... 
base import FlowProcessor -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "chunker" -default_input_queue = text_ingest_queue -default_output_queue = chunk_ingest_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id") chunk_size = params.get("chunk_size", 250) chunk_overlap = params.get("chunk_overlap", 15) super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TextDocument, - "output_schema": Chunk, - } + **params | { "id": id } ) if not hasattr(__class__, "chunk_metric"): __class__.chunk_metric = Histogram( 'chunk_size', 'Chunk size', + ["id", "flow"], buckets=[100, 160, 250, 400, 650, 1000, 1600, 2500, 4000, 6400, 10000, 16000] ) @@ -51,7 +38,24 @@ class Processor(ConsumerProducer): chunk_overlap=chunk_overlap, ) - async def handle(self, msg): + self.register_specification( + ConsumerSpec( + name = "input", + schema = TextDocument, + handler = self.on_message, + ) + ) + + self.register_specification( + ProducerSpec( + name = "output", + schema = Chunk, + ) + ) + + print("Chunker initialised", flush=True) + + async def on_message(self, msg, consumer, flow): v = msg.value() print(f"Chunking {v.metadata.id}...", flush=True) @@ -62,24 +66,25 @@ class Processor(ConsumerProducer): for ix, chunk in enumerate(texts): + print("Chunk", len(chunk.page_content), flush=True) + r = Chunk( metadata=v.metadata, chunk=chunk.page_content.encode("utf-8"), ) - __class__.chunk_metric.observe(len(chunk.page_content)) + __class__.chunk_metric.labels( + id=consumer.id, flow=consumer.flow + ).observe(len(chunk.page_content)) - await self.send(r) + await flow("output").send(r) 
print("Done.", flush=True) @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + FlowProcessor.add_args(parser) parser.add_argument( '-z', '--chunk-size', @@ -97,5 +102,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py new file mode 100644 index 00000000..46ade4c3 --- /dev/null +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -0,0 +1,215 @@ + +from trustgraph.schema import ConfigResponse +from trustgraph.schema import ConfigValue, Error + +# This behaves just like a dict, should be easier to add persistent storage +# later +class ConfigurationItems(dict): + pass + +class Configuration(dict): + + # FIXME: The state is held internally. This only works if there's + # one config service. Should be more than one, and use a + # back-end state store. 
+ + def __init__(self, push): + + # Version counter + self.version = 0 + + # External function to respond to update + self.push = push + + def __getitem__(self, key): + if key not in self: + self[key] = ConfigurationItems() + return dict.__getitem__(self, key) + + async def handle_get(self, v): + + for k in v.keys: + if k.type not in self or k.key not in self[k.type]: + return ConfigResponse( + version = None, + values = None, + directory = None, + config = None, + error = Error( + type = "key-error", + message = f"Key error" + ) + ) + + values = [ + ConfigValue( + type = k.type, + key = k.key, + value = self[k.type][k.key] + ) + for k in v.keys + ] + + return ConfigResponse( + version = self.version, + values = values, + directory = None, + config = None, + error = None, + ) + + async def handle_list(self, v): + + if v.type not in self: + + return ConfigResponse( + version = None, + values = None, + directory = None, + config = None, + error = Error( + type = "key-error", + message = "No such type", + ), + ) + + return ConfigResponse( + version = self.version, + values = None, + directory = list(self[v.type].keys()), + config = None, + error = None, + ) + + async def handle_getvalues(self, v): + + if v.type not in self: + + return ConfigResponse( + version = None, + values = None, + directory = None, + config = None, + error = Error( + type = "key-error", + message = f"Key error" + ) + ) + + values = [ + ConfigValue( + type = v.type, + key = k, + value = self[v.type][k], + ) + for k in self[v.type] + ] + + return ConfigResponse( + version = self.version, + values = values, + directory = None, + config = None, + error = None, + ) + + async def handle_delete(self, v): + + for k in v.keys: + if k.type not in self or k.key not in self[k.type]: + return ConfigResponse( + version = None, + values = None, + directory = None, + config = None, + error = Error( + type = "key-error", + message = f"Key error" + ) + ) + + for k in v.keys: + del self[k.type][k.key] + + 
self.version += 1 + + await self.push() + + return ConfigResponse( + version = None, + value = None, + directory = None, + values = None, + config = None, + error = None, + ) + + async def handle_put(self, v): + + for k in v.values: + self[k.type][k.key] = k.value + + self.version += 1 + + await self.push() + + return ConfigResponse( + version = None, + value = None, + directory = None, + values = None, + error = None, + ) + + async def handle_config(self, v): + + return ConfigResponse( + version = self.version, + value = None, + directory = None, + values = None, + config = self, + error = None, + ) + + async def handle(self, msg): + + print("Handle message ", msg.operation) + + if msg.operation == "get": + + resp = await self.handle_get(msg) + + elif msg.operation == "list": + + resp = await self.handle_list(msg) + + elif msg.operation == "getvalues": + + resp = await self.handle_getvalues(msg) + + elif msg.operation == "delete": + + resp = await self.handle_delete(msg) + + elif msg.operation == "put": + + resp = await self.handle_put(msg) + + elif msg.operation == "config": + + resp = await self.handle_config(msg) + + else: + + resp = ConfigResponse( + value=None, + directory=None, + values=None, + error=Error( + type = "bad-operation", + message = "Bad operation" + ) + ) + + return resp diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index ee0c960e..47bac828 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -1,287 +1,118 @@ """ -Config service. Fetchs an extract from the Wikipedia page -using the API. +Config service. 
Manages system global configuration state """ from pulsar.schema import JsonSchema from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush -from trustgraph.schema import ConfigValue, Error +from trustgraph.schema import Error from trustgraph.schema import config_request_queue, config_response_queue from trustgraph.schema import config_push_queue from trustgraph.log_level import LogLevel -from trustgraph.base import ConsumerProducer +from trustgraph.base import AsyncProcessor, Consumer, Producer -module = ".".join(__name__.split(".")[1:-1]) +from . config import Configuration +from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics +from ... base import Consumer, Producer -default_input_queue = config_request_queue -default_output_queue = config_response_queue +default_ident = "config-svc" + +default_request_queue = config_request_queue +default_response_queue = config_response_queue default_push_queue = config_push_queue -default_subscriber = module -# This behaves just like a dict, should be easier to add persistent storage -# later - -class ConfigurationItems(dict): - pass - -class Configuration(dict): - - def __getitem__(self, key): - if key not in self: - self[key] = ConfigurationItems() - return dict.__getitem__(self, key) - -class Processor(ConsumerProducer): +class Processor(AsyncProcessor): def __init__(self, **params): - - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) + + request_queue = params.get("request_queue", default_request_queue) + response_queue = params.get("response_queue", default_response_queue) push_queue = params.get("push_queue", default_push_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id") + + request_schema = ConfigRequest + response_schema = ConfigResponse + push_schema = ConfigResponse super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": 
output_queue, - "push_queue": output_queue, - "subscriber": subscriber, - "input_schema": ConfigRequest, - "output_schema": ConfigResponse, - "push_schema": ConfigPush, + "request_schema": request_schema.__name__, + "response_schema": response_schema.__name__, + "push_schema": push_schema.__name__, } ) - self.push_prod = self.client.create_producer( - topic=push_queue, - schema=JsonSchema(ConfigPush), + request_metrics = ConsumerMetrics(id + "-request") + response_metrics = ProducerMetrics(id + "-response") + push_metrics = ProducerMetrics(id + "-push") + + self.push_pub = Producer( + client = self.client, + topic = push_queue, + schema = ConfigPush, + metrics = push_metrics, ) - # FIXME: The state is held internally. This only works if there's - # one config service. Should be more than one, and use a - # back-end state store. - self.config = Configuration() + self.response_pub = Producer( + client = self.client, + topic = response_queue, + schema = ConfigResponse, + metrics = response_metrics, + ) - # Version counter - self.version = 0 + self.subs = Consumer( + taskgroup = self.taskgroup, + client = self.client, + flow = None, + topic = request_queue, + subscriber = id, + schema = request_schema, + handler = self.on_message, + metrics = request_metrics, + ) + + self.config = Configuration(self.push) + + print("Service initialised.") async def start(self): + await self.push() + await self.subs.start() - async def handle_get(self, v, id): - - for k in v.keys: - if k.type not in self.config or k.key not in self.config[k.type]: - return ConfigResponse( - version = None, - values = None, - directory = None, - config = None, - error = Error( - type = "key-error", - message = f"Key error" - ) - ) - - values = [ - ConfigValue( - type = k.type, - key = k.key, - value = self.config[k.type][k.key] - ) - for k in v.keys - ] - - return ConfigResponse( - version = self.version, - values = values, - directory = None, - config = None, - error = None, - ) - - async def 
handle_list(self, v, id): - - if v.type not in self.config: - - return ConfigResponse( - version = None, - values = None, - directory = None, - config = None, - error = Error( - type = "key-error", - message = "No such type", - ), - ) - - return ConfigResponse( - version = self.version, - values = None, - directory = list(self.config[v.type].keys()), - config = None, - error = None, - ) - - async def handle_getvalues(self, v, id): - - if v.type not in self.config: - - return ConfigResponse( - version = None, - values = None, - directory = None, - config = None, - error = Error( - type = "key-error", - message = f"Key error" - ) - ) - - values = [ - ConfigValue( - type = v.type, - key = k, - value = self.config[v.type][k], - ) - for k in self.config[v.type] - ] - - return ConfigResponse( - version = self.version, - values = values, - directory = None, - config = None, - error = None, - ) - - async def handle_delete(self, v, id): - - for k in v.keys: - if k.type not in self.config or k.key not in self.config[k.type]: - return ConfigResponse( - version = None, - values = None, - directory = None, - config = None, - error = Error( - type = "key-error", - message = f"Key error" - ) - ) - - for k in v.keys: - del self.config[k.type][k.key] - - self.version += 1 - - await self.push() - - return ConfigResponse( - version = None, - value = None, - directory = None, - values = None, - config = None, - error = None, - ) - - async def handle_put(self, v, id): - - for k in v.values: - self.config[k.type][k.key] = k.value - - self.version += 1 - - await self.push() - - return ConfigResponse( - version = None, - value = None, - directory = None, - values = None, - error = None, - ) - - async def handle_config(self, v, id): - - return ConfigResponse( - version = self.version, - value = None, - directory = None, - values = None, - config = self.config, - error = None, - ) - async def push(self): resp = ConfigPush( - version = self.version, + version = self.config.version, value = 
None, directory = None, values = None, config = self.config, error = None, ) - self.push_prod.send(resp) - print("Pushed.") + + await self.push_pub.send(resp) + + print("Pushed version ", self.config.version) - async def handle(self, msg): - - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling {id}...", flush=True) + async def on_message(self, msg, consumer, flow): try: - if v.operation == "get": + v = msg.value() - resp = await self.handle_get(v, id) + # Sender-produced ID + id = msg.properties()["id"] - elif v.operation == "list": + print(f"Handling {id}...", flush=True) - resp = await self.handle_list(v, id) + resp = await self.config.handle(v) - elif v.operation == "getvalues": - - resp = await self.handle_getvalues(v, id) - - elif v.operation == "delete": - - resp = await self.handle_delete(v, id) - - elif v.operation == "put": - - resp = await self.handle_put(v, id) - - elif v.operation == "config": - - resp = await self.handle_config(v, id) - - else: - - resp = ConfigResponse( - value=None, - directory=None, - values=None, - error=Error( - type = "bad-operation", - message = "Bad operation" - ) - ) - - await self.send(resp, properties={"id": id}) - - self.consumer.acknowledge(msg) + await self.response_pub.send(resp, properties={"id": id}) except Exception as e: - + resp = ConfigResponse( error=Error( type = "unexpected-error", @@ -289,24 +120,33 @@ class Processor(ConsumerProducer): ), text=None, ) - await self.send(resp, properties={"id": id}) - self.consumer.acknowledge(msg) + + await self.response_pub.send(resp, properties={"id": id}) @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, + AsyncProcessor.add_args(parser) + + parser.add_argument( + '-q', '--request-queue', + default=default_request_queue, + help=f'Request queue (default: {default_request_queue})' ) parser.add_argument( - '-q', '--push-queue', + '-r', 
'--response-queue', + default=default_response_queue, + help=f'Response queue {default_response_queue}', + ) + + parser.add_argument( + '--push-queue', default=default_push_queue, help=f'Config push queue (default: {default_push_queue})' ) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index f5100244..e42d1601 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -17,12 +17,10 @@ from mistralai.models import OCRResponse from ... schema import Document, TextDocument, Metadata from ... schema import document_ingest_queue, text_ingest_queue from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... base import InputOutputProcessor -module = ".".join(__name__.split(".")[1:-1]) +module = "ocr" -default_input_queue = document_ingest_queue -default_output_queue = text_ingest_queue default_subscriber = module default_api_key = os.getenv("MISTRAL_TOKEN") @@ -71,19 +69,17 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str: return "\n\n".join(markdowns) -class Processor(ConsumerProducer): +class Processor(InputOutputProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) + id = params.get("id") subscriber = params.get("subscriber", default_subscriber) api_key = params.get("api_key", default_api_key) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, + "id": id, "subscriber": subscriber, "input_schema": Document, "output_schema": TextDocument, @@ -151,7 +147,7 @@ class Processor(ConsumerProducer): return markdown - async def handle(self, msg): + async def on_message(self, msg, consumer): print("PDF message 
received") @@ -166,17 +162,14 @@ class Processor(ConsumerProducer): text=markdown.encode("utf-8"), ) - await self.send(r) + await consumer.q.output.send(r) print("Done.", flush=True) @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + InputOutputProcessor.add_args(parser, default_subscriber) parser.add_argument( '-k', '--api-key', diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 5e5e3612..d0669a59 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -9,39 +9,43 @@ import base64 from langchain_community.document_loaders import PyPDFLoader from ... schema import Document, TextDocument, Metadata -from ... schema import document_ingest_queue, text_ingest_queue from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... 
base import FlowProcessor, ConsumerSpec, ProducerSpec -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "pdf-decoder" -default_input_queue = document_ingest_queue -default_output_queue = text_ingest_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id", default_ident) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": Document, - "output_schema": TextDocument, + "id": id, } ) - print("PDF inited") + self.register_specification( + ConsumerSpec( + name = "input", + schema = Document, + handler = self.on_message, + ) + ) - async def handle(self, msg): + self.register_specification( + ProducerSpec( + name = "output", + schema = TextDocument, + ) + ) - print("PDF message received") + print("PDF inited", flush=True) + + async def on_message(self, msg, consumer, flow): + + print("PDF message received", flush=True) v = msg.value() @@ -59,24 +63,22 @@ class Processor(ConsumerProducer): for ix, page in enumerate(pages): + print("page", ix, flush=True) + r = TextDocument( metadata=v.metadata, text=page.page_content.encode("utf-8"), ) - await self.send(r) + await flow("output").send(r) print("Done.", flush=True) @staticmethod def add_args(parser): - - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + FlowProcessor.add_args(parser) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/document_rag.py b/trustgraph-flow/trustgraph/document_rag.py deleted file mode 100644 index 4fc4850a..00000000 --- a/trustgraph-flow/trustgraph/document_rag.py 
+++ /dev/null @@ -1,153 +0,0 @@ - -from . clients.document_embeddings_client import DocumentEmbeddingsClient -from . clients.triples_query_client import TriplesQueryClient -from . clients.embeddings_client import EmbeddingsClient -from . clients.prompt_client import PromptClient - -from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse -from . schema import TriplesQueryRequest, TriplesQueryResponse -from . schema import prompt_request_queue -from . schema import prompt_response_queue -from . schema import embeddings_request_queue -from . schema import embeddings_response_queue -from . schema import document_embeddings_request_queue -from . schema import document_embeddings_response_queue - -LABEL="http://www.w3.org/2000/01/rdf-schema#label" -DEFINITION="http://www.w3.org/2004/02/skos/core#definition" - -class Query: - - def __init__( - self, rag, user, collection, verbose, - doc_limit=20 - ): - self.rag = rag - self.user = user - self.collection = collection - self.verbose = verbose - self.doc_limit = doc_limit - - def get_vector(self, query): - - if self.verbose: - print("Compute embeddings...", flush=True) - - qembeds = self.rag.embeddings.request(query) - - if self.verbose: - print("Done.", flush=True) - - return qembeds - - def get_docs(self, query): - - vectors = self.get_vector(query) - - if self.verbose: - print("Get entities...", flush=True) - - docs = self.rag.de_client.request( - vectors, limit=self.doc_limit - ) - - if self.verbose: - print("Docs:", flush=True) - for doc in docs: - print(doc, flush=True) - - return docs - -class DocumentRag: - - def __init__( - self, - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key=None, - pr_request_queue=None, - pr_response_queue=None, - emb_request_queue=None, - emb_response_queue=None, - de_request_queue=None, - de_response_queue=None, - verbose=False, - module="test", - ): - - self.verbose=verbose - - if pr_request_queue is None: - pr_request_queue = prompt_request_queue - - if 
pr_response_queue is None: - pr_response_queue = prompt_response_queue - - if emb_request_queue is None: - emb_request_queue = embeddings_request_queue - - if emb_response_queue is None: - emb_response_queue = embeddings_response_queue - - if de_request_queue is None: - de_request_queue = document_embeddings_request_queue - - if de_response_queue is None: - de_response_queue = document_embeddings_response_queue - - if self.verbose: - print("Initialising...", flush=True) - - self.de_client = DocumentEmbeddingsClient( - pulsar_host=pulsar_host, - subscriber=module + "-de", - input_queue=de_request_queue, - output_queue=de_response_queue, - pulsar_api_key=pulsar_api_key, - ) - - self.embeddings = EmbeddingsClient( - pulsar_host=pulsar_host, - input_queue=emb_request_queue, - output_queue=emb_response_queue, - subscriber=module + "-emb", - pulsar_api_key=pulsar_api_key, - ) - - self.lang = PromptClient( - pulsar_host=pulsar_host, - input_queue=pr_request_queue, - output_queue=pr_response_queue, - subscriber=module + "-de-prompt", - pulsar_api_key=pulsar_api_key, - ) - - if self.verbose: - print("Initialised", flush=True) - - def query( - self, query, user="trustgraph", collection="default", - doc_limit=20, - ): - - if self.verbose: - print("Construct prompt...", flush=True) - - q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose, - doc_limit=doc_limit - ) - - docs = q.get_docs(query) - - if self.verbose: - print("Invoke LLM...", flush=True) - print(docs) - print(query) - - resp = self.lang.request_document_prompt(query, docs) - - if self.verbose: - print("Done", flush=True) - - return resp - diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index 70f53e07..95e5462d 100755 --- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -6,61 
+6,63 @@ Output is chunk plus embedding. """ from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings -from ... schema import chunk_ingest_queue -from ... schema import document_embeddings_store_queue -from ... schema import embeddings_request_queue, embeddings_response_queue -from ... clients.embeddings_client import EmbeddingsClient -from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... schema import EmbeddingsRequest, EmbeddingsResponse -module = ".".join(__name__.split(".")[1:-1]) +from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec +from ... base import ProducerSpec -default_input_queue = chunk_ingest_queue -default_output_queue = document_embeddings_store_queue -default_subscriber = module +default_ident = "document-embeddings" -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - emb_request_queue = params.get( - "embeddings_request_queue", embeddings_request_queue - ) - emb_response_queue = params.get( - "embeddings_response_queue", embeddings_response_queue - ) + id = params.get("id") super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "embeddings_request_queue": emb_request_queue, - "embeddings_response_queue": emb_response_queue, - "subscriber": subscriber, - "input_schema": Chunk, - "output_schema": DocumentEmbeddings, + "id": id, } ) - self.embeddings = EmbeddingsClient( - pulsar_host=self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - input_queue=emb_request_queue, - output_queue=emb_response_queue, - subscriber=module + "-emb", + self.register_specification( + ConsumerSpec( + name = "input", + schema = Chunk, + handler = self.on_message, + ) ) - async def handle(self, msg): + 
self.register_specification( + RequestResponseSpec( + request_name = "embeddings-request", + request_schema = EmbeddingsRequest, + response_name = "embeddings-response", + response_schema = EmbeddingsResponse, + ) + ) + + self.register_specification( + ProducerSpec( + name = "output", + schema = DocumentEmbeddings + ) + ) + + async def on_message(self, msg, consumer, flow): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) try: - vectors = self.embeddings.request(v.chunk) + resp = await flow("embeddings-request").request( + EmbeddingsRequest( + text = v.chunk + ) + ) + + vectors = resp.vectors embeds = [ ChunkEmbeddings( @@ -74,7 +76,7 @@ class Processor(ConsumerProducer): chunks=embeds, ) - await self.send(r) + await flow("output").send(r) except Exception as e: print("Exception:", e, flush=True) @@ -87,24 +89,9 @@ class Processor(ConsumerProducer): @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) - - parser.add_argument( - '--embeddings-request-queue', - default=embeddings_request_queue, - help=f'Embeddings request queue (default: {embeddings_request_queue})', - ) - - parser.add_argument( - '--embeddings-response-queue', - default=embeddings_response_queue, - help=f'Embeddings request queue (default: {embeddings_response_queue})', - ) + FlowProcessor.add_args(parser) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index bc164fa0..a4ae35dc 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -1,81 +1,43 @@ """ -Embeddings service, applies an embeddings model selected from HuggingFace. 
+Embeddings service, applies an embeddings model using fastembed Input is text, output is embeddings vector. """ -from ... schema import EmbeddingsRequest, EmbeddingsResponse -from ... schema import embeddings_request_queue, embeddings_response_queue -from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... base import EmbeddingsService + from fastembed import TextEmbedding -import os -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "embeddings" -default_input_queue = embeddings_request_queue -default_output_queue = embeddings_response_queue -default_subscriber = module default_model="sentence-transformers/all-MiniLM-L6-v2" -class Processor(ConsumerProducer): +class Processor(EmbeddingsService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - model = params.get("model", default_model) super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": EmbeddingsRequest, - "output_schema": EmbeddingsResponse, - "model": model, - } + **params | { "model": model } ) + print("Get model...", flush=True) self.embeddings = TextEmbedding(model_name = model) - async def handle(self, msg): + async def on_embeddings(self, text): - v = msg.value() - - # Sender-produced ID - - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - - text = v.text vecs = self.embeddings.embed([text]) - vecs = [ + return [ v.tolist() for v in vecs ] - print("Send response...", flush=True) - r = EmbeddingsResponse( - vectors=list(vecs), - error=None, - ) - - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) - @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + 
EmbeddingsService.add_args(parser) parser.add_argument( '-m', '--model', @@ -85,5 +47,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index 2cbe9907..043be3a7 100755 --- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -6,53 +6,48 @@ Output is entity plus embedding. """ from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings -from ... schema import entity_contexts_ingest_queue -from ... schema import graph_embeddings_store_queue -from ... schema import embeddings_request_queue, embeddings_response_queue -from ... clients.embeddings_client import EmbeddingsClient -from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... schema import EmbeddingsRequest, EmbeddingsResponse -module = ".".join(__name__.split(".")[1:-1]) +from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec +from ... 
base import ProducerSpec -default_input_queue = entity_contexts_ingest_queue -default_output_queue = graph_embeddings_store_queue -default_subscriber = module +default_ident = "graph-embeddings" -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - emb_request_queue = params.get( - "embeddings_request_queue", embeddings_request_queue - ) - emb_response_queue = params.get( - "embeddings_response_queue", embeddings_response_queue - ) + id = params.get("id") super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "embeddings_request_queue": emb_request_queue, - "embeddings_response_queue": emb_response_queue, - "subscriber": subscriber, - "input_schema": EntityContexts, - "output_schema": GraphEmbeddings, + "id": id, } ) - self.embeddings = EmbeddingsClient( - pulsar_host=self.pulsar_host, - input_queue=emb_request_queue, - output_queue=emb_response_queue, - subscriber=module + "-emb", + self.register_specification( + ConsumerSpec( + name = "input", + schema = EntityContexts, + handler = self.on_message, + ) ) - async def handle(self, msg): + self.register_specification( + EmbeddingsClientSpec( + request_name = "embeddings-request", + response_name = "embeddings-response", + ) + ) + + self.register_specification( + ProducerSpec( + name = "output", + schema = GraphEmbeddings + ) + ) + + async def on_message(self, msg, consumer, flow): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) @@ -63,7 +58,9 @@ class Processor(ConsumerProducer): for entity in v.entities: - vectors = self.embeddings.request(entity.context) + vectors = await flow("embeddings-request").embed( + text = entity.context + ) entities.append( EntityEmbeddings( @@ -77,7 +74,7 @@ class 
Processor(ConsumerProducer): entities=entities, ) - await self.send(r) + await flow("output").send(r) except Exception as e: print("Exception:", e, flush=True) @@ -90,24 +87,9 @@ class Processor(ConsumerProducer): @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) - - parser.add_argument( - '--embeddings-request-queue', - default=embeddings_request_queue, - help=f'Embeddings request queue (default: {embeddings_request_queue})', - ) - - parser.add_argument( - '--embeddings-response-queue', - default=embeddings_response_queue, - help=f'Embeddings request queue (default: {embeddings_response_queue})', - ) + FlowProcessor.add_args(parser) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index c441b9c6..86787316 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -11,7 +11,7 @@ from ... 
base import ConsumerProducer from ollama import Client import os -module = ".".join(__name__.split(".")[1:-1]) +module = "embeddings" default_input_queue = embeddings_request_queue default_output_queue = embeddings_response_queue diff --git a/trustgraph-flow/trustgraph/external/wikipedia/service.py b/trustgraph-flow/trustgraph/external/wikipedia/service.py index cc002765..f7de78da 100644 --- a/trustgraph-flow/trustgraph/external/wikipedia/service.py +++ b/trustgraph-flow/trustgraph/external/wikipedia/service.py @@ -11,7 +11,7 @@ from trustgraph.log_level import LogLevel from trustgraph.base import ConsumerProducer import requests -module = ".".join(__name__.split(".")[1:-1]) +module = "wikipedia" default_input_queue = encyclopedia_lookup_request_queue default_output_queue = encyclopedia_lookup_response_queue diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 47c99802..f95dadf9 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -5,84 +5,62 @@ get entity definitions which are output as graph edges along with entity/context definitions for embedding. """ +import json import urllib.parse -from pulsar.schema import JsonSchema from .... schema import Chunk, Triple, Triples, Metadata, Value from .... schema import EntityContext, EntityContexts -from .... schema import chunk_ingest_queue, triples_store_queue -from .... schema import entity_contexts_ingest_queue -from .... schema import prompt_request_queue -from .... schema import prompt_response_queue -from .... log_level import LogLevel -from .... clients.prompt_client import PromptClient +from .... schema import PromptRequest, PromptResponse from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF -from .... base import ConsumerProducer + +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... 
base import PromptClientSpec DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "kg-extract-definitions" -default_input_queue = chunk_ingest_queue -default_output_queue = triples_store_queue -default_entity_context_queue = entity_contexts_ingest_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - ec_queue = params.get( - "entity_context_queue", - default_entity_context_queue - ) - subscriber = params.get("subscriber", default_subscriber) - pr_request_queue = params.get( - "prompt_request_queue", prompt_request_queue - ) - pr_response_queue = params.get( - "prompt_response_queue", prompt_response_queue - ) + id = params.get("id") super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": Chunk, - "output_schema": Triples, - "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, + "id": id, } ) - self.ec_prod = self.client.create_producer( - topic=ec_queue, - schema=JsonSchema(EntityContexts), + self.register_specification( + ConsumerSpec( + name = "input", + schema = Chunk, + handler = self.on_message + ) ) - __class__.pubsub_metric.info({ - "input_queue": input_queue, - "output_queue": output_queue, - "entity_context_queue": ec_queue, - "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, - "subscriber": subscriber, - "input_schema": Chunk.__name__, - "output_schema": Triples.__name__, - "vector_schema": EntityContexts.__name__, - }) + self.register_specification( + PromptClientSpec( + request_name = "prompt-request", + 
response_name = "prompt-response", + ) + ) - self.prompt = PromptClient( - pulsar_host=self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - input_queue=pr_request_queue, - output_queue=pr_response_queue, - subscriber = module + "-prompt", + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples + ) + ) + + self.register_specification( + ProducerSpec( + name = "entity-contexts", + schema = EntityContexts + ) ) def to_uri(self, text): @@ -93,36 +71,47 @@ class Processor(ConsumerProducer): return uri - def get_definitions(self, chunk): - - return self.prompt.request_definitions(chunk) - - async def emit_edges(self, metadata, triples): + async def emit_triples(self, pub, metadata, triples): t = Triples( metadata=metadata, triples=triples, ) - await self.send(t) + await pub.send(t) - async def emit_ecs(self, metadata, entities): + async def emit_ecs(self, pub, metadata, entities): t = EntityContexts( metadata=metadata, entities=entities, ) - self.ec_prod.send(t) + await pub.send(t) - async def handle(self, msg): + async def on_message(self, msg, consumer, flow): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") + print(chunk, flush=True) + try: - defs = self.get_definitions(chunk) + try: + + defs = await flow("prompt-request").extract_definitions( + text = chunk + ) + + print("Response", defs, flush=True) + + if type(defs) != list: + raise RuntimeError("Expecting array in prompt response") + + except Exception as e: + print("Prompt exception:", e, flush=True) + raise e triples = [] entities = [] @@ -134,8 +123,8 @@ class Processor(ConsumerProducer): for defn in defs: - s = defn.name - o = defn.definition + s = defn["entity"] + o = defn["definition"] if s == "": continue if o == "": continue @@ -166,13 +155,13 @@ class Processor(ConsumerProducer): ec = EntityContext( entity=s_value, - context=defn.definition, + context=defn["definition"], ) entities.append(ec) - - await self.emit_edges( + 
await self.emit_triples( + flow("triples"), Metadata( id=v.metadata.id, metadata=[], @@ -183,6 +172,7 @@ class Processor(ConsumerProducer): ) await self.emit_ecs( + flow("entity-contexts"), Metadata( id=v.metadata.id, metadata=[], @@ -200,30 +190,9 @@ class Processor(ConsumerProducer): @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) - - parser.add_argument( - '-e', '--entity-context-queue', - default=default_entity_context_queue, - help=f'Entity context queue (default: {default_entity_context_queue})' - ) - - parser.add_argument( - '--prompt-request-queue', - default=prompt_request_queue, - help=f'Prompt request queue (default: {prompt_request_queue})', - ) - - parser.add_argument( - '--prompt-completion-response-queue', - default=prompt_response_queue, - help=f'Prompt response queue (default: {prompt_response_queue})', - ) + FlowProcessor.add_args(parser) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 2f293527..ac2929a3 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -5,59 +5,54 @@ relationship analysis to get entity relationship edges which are output as graph edges. """ +import json import urllib.parse from .... schema import Chunk, Triple, Triples from .... schema import Metadata, Value -from .... schema import chunk_ingest_queue, triples_store_queue -from .... schema import prompt_request_queue -from .... schema import prompt_response_queue -from .... log_level import LogLevel -from .... clients.prompt_client import PromptClient +from .... schema import PromptRequest, PromptResponse from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF -from .... 
base import ConsumerProducer + +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base import PromptClientSpec RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "kg-extract-relationships" -default_input_queue = chunk_ingest_queue -default_output_queue = triples_store_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - pr_request_queue = params.get( - "prompt_request_queue", prompt_request_queue - ) - pr_response_queue = params.get( - "prompt_response_queue", prompt_response_queue - ) + id = params.get("id") super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": Chunk, - "output_schema": Triples, - "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, + "id": id, } ) - self.prompt = PromptClient( - pulsar_host=self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - input_queue=pr_request_queue, - output_queue=pr_response_queue, - subscriber = module + "-prompt", + self.register_specification( + ConsumerSpec( + name = "input", + schema = Chunk, + handler = self.on_message + ) + ) + + self.register_specification( + PromptClientSpec( + request_name = "prompt-request", + response_name = "prompt-response", + ) + ) + + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples + ) ) def to_uri(self, text): @@ -68,28 +63,39 @@ class Processor(ConsumerProducer): return uri - def get_relationships(self, chunk): - - return self.prompt.request_relationships(chunk) - - async def emit_edges(self, metadata, 
triples): + async def emit_triples(self, pub, metadata, triples): t = Triples( metadata=metadata, triples=triples, ) - await self.send(t) + await pub.send(t) - async def handle(self, msg): + async def on_message(self, msg, consumer, flow): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") + print(chunk, flush=True) + try: - rels = self.get_relationships(chunk) + try: + + rels = await flow("prompt-request").extract_relationships( + text = chunk + ) + + print("Response", rels, flush=True) + + if type(rels) != list: + raise RuntimeError("Expecting array in prompt response") + + except Exception as e: + print("Prompt exception:", e, flush=True) + raise e triples = [] @@ -100,9 +106,9 @@ class Processor(ConsumerProducer): for rel in rels: - s = rel.s - p = rel.p - o = rel.o + s = rel["subject"] + p = rel["predicate"] + o = rel["object"] if s == "": continue if p == "": continue @@ -118,7 +124,7 @@ class Processor(ConsumerProducer): p_uri = self.to_uri(p) p_value = Value(value=str(p_uri), is_uri=True) - if rel.o_entity: + if rel["object-entity"]: o_uri = self.to_uri(o) o_value = Value(value=str(o_uri), is_uri=True) else: @@ -144,7 +150,7 @@ class Processor(ConsumerProducer): o=Value(value=str(p), is_uri=False) )) - if rel.o_entity: + if rel["object-entity"]: # Label for o triples.append(Triple( s=o_value, @@ -159,7 +165,7 @@ class Processor(ConsumerProducer): o=Value(value=v.metadata.id, is_uri=True) )) - if rel.o_entity: + if rel["object-entity"]: # 'Subject of' for o triples.append(Triple( s=o_value, @@ -168,6 +174,7 @@ class Processor(ConsumerProducer): )) await self.emit_edges( + flow("triples"), Metadata( id=v.metadata.id, metadata=[], @@ -185,24 +192,9 @@ class Processor(ConsumerProducer): @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) - - parser.add_argument( - '--prompt-request-queue', - default=prompt_request_queue, - 
help=f'Prompt request queue (default: {prompt_request_queue})', - ) - - parser.add_argument( - '--prompt-response-queue', - default=prompt_response_queue, - help=f'Prompt response queue (default: {prompt_response_queue})', - ) + FlowProcessor.add_args(parser) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py index 7424abe2..84ab6681 100755 --- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py @@ -18,7 +18,7 @@ from .... base import ConsumerProducer DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) -module = ".".join(__name__.split(".")[1:-1]) +module = "kg-extract-topics" default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue diff --git a/trustgraph-flow/trustgraph/gateway/agent.py b/trustgraph-flow/trustgraph/gateway/agent.py index 150b970e..5a54931b 100644 --- a/trustgraph-flow/trustgraph/gateway/agent.py +++ b/trustgraph-flow/trustgraph/gateway/agent.py @@ -39,4 +39,3 @@ class AgentRequestor(ServiceRequestor): # The 2nd boolean expression indicates whether we're done responding return resp, (message.answer is not None) - diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py index 6b4b4838..bbfb51a3 100644 --- a/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py +++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py @@ -1,6 +1,5 @@ import asyncio -from pulsar.schema import JsonSchema import uuid from aiohttp import WSMsgType @@ -26,12 +25,12 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint): self.publisher = Publisher( self.pulsar_client, document_embeddings_store_queue, - schema=JsonSchema(DocumentEmbeddings) + schema=DocumentEmbeddings ) async def start(self): - 
self.publisher.start() + await self.publisher.start() async def listener(self, ws, running): @@ -59,6 +58,6 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint): ], ) - self.publisher.send(None, elt) + await self.publisher.send(None, elt) running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py index 6d7db576..e59a0370 100644 --- a/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py +++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py @@ -1,7 +1,6 @@ import asyncio import queue -from pulsar.schema import JsonSchema import uuid from .. schema import DocumentEmbeddings @@ -27,7 +26,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint): self.subscriber = Subscriber( self.pulsar_client, document_embeddings_store_queue, "api-gateway", "api-gateway", - schema=JsonSchema(DocumentEmbeddings), + schema=DocumentEmbeddings, ) async def listener(self, ws, running): @@ -44,17 +43,17 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint): async def start(self): - self.subscriber.start() + await self.subscriber.start() async def async_thread(self, ws, running): id = str(uuid.uuid4()) - q = self.subscriber.subscribe_all(id) + q = await self.subscriber.subscribe_all(id) while running.get(): try: - resp = await asyncio.to_thread(q.get, timeout=0.5) + resp = await asyncio.wait_for(q.get(), timeout=0.5) await ws.send_json(serialize_document_embeddings(resp)) except TimeoutError: @@ -67,7 +66,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint): print(f"Exception: {str(e)}", flush=True) break - self.subscriber.unsubscribe_all(id) + await self.subscriber.unsubscribe_all(id) running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint.py index 5005463c..94980e8b 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint.py @@ 
-1,13 +1,9 @@ import asyncio -from pulsar.schema import JsonSchema from aiohttp import web import uuid import logging -from .. base import Publisher -from .. base import Subscriber - logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py index c1354ce5..27e92a30 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py +++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py @@ -1,6 +1,5 @@ import asyncio -from pulsar.schema import JsonSchema import uuid from aiohttp import WSMsgType @@ -26,12 +25,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint): self.publisher = Publisher( self.pulsar_client, graph_embeddings_store_queue, - schema=JsonSchema(GraphEmbeddings) + schema=GraphEmbeddings ) async def start(self): - self.publisher.start() + await self.publisher.start() async def listener(self, ws, running): @@ -60,6 +59,6 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint): ] ) - self.publisher.send(None, elt) + await self.publisher.send(None, elt) running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py index 385eb9f4..37edc2bb 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py +++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py @@ -1,7 +1,6 @@ import asyncio import queue -from pulsar.schema import JsonSchema import uuid from .. 
schema import GraphEmbeddings @@ -26,7 +25,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint): self.subscriber = Subscriber( self.pulsar_client, graph_embeddings_store_queue, "api-gateway", "api-gateway", - schema=JsonSchema(GraphEmbeddings) + schema=GraphEmbeddings ) async def listener(self, ws, running): @@ -41,17 +40,17 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint): async def start(self): - self.subscriber.start() + await self.subscriber.start() async def async_thread(self, ws, running): id = str(uuid.uuid4()) - q = self.subscriber.subscribe_all(id) + q = await self.subscriber.subscribe_all(id) while running.get(): try: - resp = await asyncio.to_thread(q.get, timeout=0.5) + resp = await asyncio.wait_for(q.get, timeout=0.5) await ws.send_json(serialize_graph_embeddings(resp)) except TimeoutError: @@ -64,7 +63,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint): print(f"Exception: {str(e)}", flush=True) break - self.subscriber.unsubscribe_all(id) + await self.subscriber.unsubscribe_all(id) running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/metrics.py b/trustgraph-flow/trustgraph/gateway/metrics.py index 33c1fe3a..d8a1ef62 100644 --- a/trustgraph-flow/trustgraph/gateway/metrics.py +++ b/trustgraph-flow/trustgraph/gateway/metrics.py @@ -7,7 +7,6 @@ import aiohttp from aiohttp import web import asyncio -from pulsar.schema import JsonSchema import uuid import logging diff --git a/trustgraph-flow/trustgraph/gateway/mux.py b/trustgraph-flow/trustgraph/gateway/mux.py index 23b693ab..8195c542 100644 --- a/trustgraph-flow/trustgraph/gateway/mux.py +++ b/trustgraph-flow/trustgraph/gateway/mux.py @@ -1,7 +1,6 @@ import asyncio import queue -from pulsar.schema import JsonSchema import uuid from aiohttp import web, WSMsgType diff --git a/trustgraph-flow/trustgraph/gateway/requestor.py b/trustgraph-flow/trustgraph/gateway/requestor.py index dc74667d..63395203 100644 --- a/trustgraph-flow/trustgraph/gateway/requestor.py +++ 
b/trustgraph-flow/trustgraph/gateway/requestor.py @@ -1,6 +1,5 @@ import asyncio -from pulsar.schema import JsonSchema import uuid import logging @@ -23,21 +22,21 @@ class ServiceRequestor: self.pub = Publisher( pulsar_client, request_queue, - schema=JsonSchema(request_schema), + schema=request_schema, ) self.sub = Subscriber( pulsar_client, response_queue, subscription, consumer_name, - JsonSchema(response_schema) + response_schema ) self.timeout = timeout async def start(self): - self.pub.start() - self.sub.start() + await self.pub.start() + await self.sub.start() def to_request(self, request): raise RuntimeError("Not defined") @@ -51,18 +50,15 @@ class ServiceRequestor: try: - q = self.sub.subscribe(id) + q = await self.sub.subscribe(id) - await asyncio.to_thread( - self.pub.send, id, self.to_request(request) - ) + await self.pub.send(id, self.to_request(request)) while True: try: - resp = await asyncio.to_thread( - q.get, - timeout=self.timeout + resp = await asyncio.wait_for( + q.get(), timeout=self.timeout ) except Exception as e: raise RuntimeError("Timeout") @@ -99,5 +95,5 @@ class ServiceRequestor: return err finally: - self.sub.unsubscribe(id) + await self.sub.unsubscribe(id) diff --git a/trustgraph-flow/trustgraph/gateway/sender.py b/trustgraph-flow/trustgraph/gateway/sender.py index 32c586b1..81b64e6d 100644 --- a/trustgraph-flow/trustgraph/gateway/sender.py +++ b/trustgraph-flow/trustgraph/gateway/sender.py @@ -2,7 +2,6 @@ # Like ServiceRequestor, but just fire-and-forget instead of request/response import asyncio -from pulsar.schema import JsonSchema import uuid import logging @@ -21,12 +20,12 @@ class ServiceSender: self.pub = Publisher( pulsar_client, request_queue, - schema=JsonSchema(request_schema), + schema=request_schema, ) async def start(self): - self.pub.start() + await self.pub.start() def to_request(self, request): raise RuntimeError("Not defined") @@ -35,9 +34,7 @@ class ServiceSender: try: - await asyncio.to_thread( - self.pub.send, 
None, self.to_request(request) - ) + await self.pub.send(None, self.to_request(request)) if responder: await responder({}, True) diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index e997f83e..29b31483 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -3,7 +3,7 @@ API gateway. Offers HTTP services which are translated to interaction on the Pulsar bus. """ -module = ".".join(__name__.split(".")[1:-1]) +module = "api-gateway" # FIXME: Subscribes to Pulsar unnecessarily, should only do it when there # are active listeners @@ -19,7 +19,6 @@ import os import base64 import pulsar -from pulsar.schema import JsonSchema from prometheus_client import start_http_server from .. log_level import LogLevel diff --git a/trustgraph-flow/trustgraph/gateway/triples_load.py b/trustgraph-flow/trustgraph/gateway/triples_load.py index bc69975e..81c8ea82 100644 --- a/trustgraph-flow/trustgraph/gateway/triples_load.py +++ b/trustgraph-flow/trustgraph/gateway/triples_load.py @@ -1,6 +1,5 @@ import asyncio -from pulsar.schema import JsonSchema import uuid from aiohttp import WSMsgType @@ -24,12 +23,12 @@ class TriplesLoadEndpoint(SocketEndpoint): self.publisher = Publisher( self.pulsar_client, triples_store_queue, - schema=JsonSchema(Triples) + schema=Triples ) async def start(self): - self.publisher.start() + await self.publisher.start() async def listener(self, ws, running): @@ -51,7 +50,7 @@ class TriplesLoadEndpoint(SocketEndpoint): triples=to_subgraph(data["triples"]), ) - self.publisher.send(None, elt) + await self.publisher.send(None, elt) running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/triples_stream.py b/trustgraph-flow/trustgraph/gateway/triples_stream.py index a5d5ad0a..a660591e 100644 --- a/trustgraph-flow/trustgraph/gateway/triples_stream.py +++ b/trustgraph-flow/trustgraph/gateway/triples_stream.py @@ -1,7 +1,6 @@ import asyncio import queue 
-from pulsar.schema import JsonSchema import uuid from .. schema import Triples @@ -24,7 +23,7 @@ class TriplesStreamEndpoint(SocketEndpoint): self.subscriber = Subscriber( self.pulsar_client, triples_store_queue, "api-gateway", "api-gateway", - schema=JsonSchema(Triples) + schema=Triples ) async def listener(self, ws, running): @@ -39,7 +38,7 @@ class TriplesStreamEndpoint(SocketEndpoint): async def start(self): - self.subscriber.start() + await self.subscriber.start() async def async_thread(self, ws, running): diff --git a/trustgraph-flow/trustgraph/graph_rag.py b/trustgraph-flow/trustgraph/graph_rag.py deleted file mode 100644 index 6a4e11c5..00000000 --- a/trustgraph-flow/trustgraph/graph_rag.py +++ /dev/null @@ -1,295 +0,0 @@ - -from . clients.graph_embeddings_client import GraphEmbeddingsClient -from . clients.triples_query_client import TriplesQueryClient -from . clients.embeddings_client import EmbeddingsClient -from . clients.prompt_client import PromptClient - -from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse -from . schema import TriplesQueryRequest, TriplesQueryResponse -from . schema import prompt_request_queue -from . schema import prompt_response_queue -from . schema import embeddings_request_queue -from . schema import embeddings_response_queue -from . schema import graph_embeddings_request_queue -from . schema import graph_embeddings_response_queue -from . schema import triples_request_queue -from . 
schema import triples_response_queue - -LABEL="http://www.w3.org/2000/01/rdf-schema#label" -DEFINITION="http://www.w3.org/2004/02/skos/core#definition" - -class Query: - - def __init__( - self, rag, user, collection, verbose, - entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, - ): - self.rag = rag - self.user = user - self.collection = collection - self.verbose = verbose - self.entity_limit = entity_limit - self.triple_limit = triple_limit - self.max_subgraph_size = max_subgraph_size - self.max_path_length = max_path_length - - def get_vector(self, query): - - if self.verbose: - print("Compute embeddings...", flush=True) - - qembeds = self.rag.embeddings.request(query) - - if self.verbose: - print("Done.", flush=True) - - return qembeds - - def get_entities(self, query): - - vectors = self.get_vector(query) - - if self.verbose: - print("Get entities...", flush=True) - - entities = self.rag.ge_client.request( - user=self.user, collection=self.collection, - vectors=vectors, limit=self.entity_limit, - ) - - entities = [ - e.value - for e in entities - ] - - if self.verbose: - print("Entities:", flush=True) - for ent in entities: - print(" ", ent, flush=True) - - return entities - - def maybe_label(self, e): - - if e in self.rag.label_cache: - return self.rag.label_cache[e] - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=e, p=LABEL, o=None, limit=1, - ) - - if len(res) == 0: - self.rag.label_cache[e] = e - return e - - self.rag.label_cache[e] = res[0].o.value - return self.rag.label_cache[e] - - def follow_edges(self, ent, subgraph, path_length): - - # Not needed? 
- if path_length <= 0: - return - - # Stop spanning around if the subgraph is already maxed out - if len(subgraph) >= self.max_subgraph_size: - return - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=ent, p=None, o=None, - limit=self.triple_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - if path_length > 1: - self.follow_edges(triple.o.value, subgraph, path_length-1) - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=None, p=ent, o=None, - limit=self.triple_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=None, p=None, o=ent, - limit=self.triple_limit, - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - if path_length > 1: - self.follow_edges(triple.s.value, subgraph, path_length-1) - - def get_subgraph(self, query): - - entities = self.get_entities(query) - - if self.verbose: - print("Get subgraph...", flush=True) - - subgraph = set() - - for ent in entities: - self.follow_edges(ent, subgraph, self.max_path_length) - - subgraph = list(subgraph) - - return subgraph - - def get_labelgraph(self, query): - - subgraph = self.get_subgraph(query) - - sg2 = [] - - for edge in subgraph: - - if edge[1] == LABEL: - continue - - s = self.maybe_label(edge[0]) - p = self.maybe_label(edge[1]) - o = self.maybe_label(edge[2]) - - sg2.append((s, p, o)) - - sg2 = sg2[0:self.max_subgraph_size] - - if self.verbose: - print("Subgraph:", flush=True) - for edge in sg2: - print(" ", str(edge), flush=True) - - if self.verbose: - print("Done.", flush=True) - - return sg2 - -class GraphRag: - - def __init__( - self, - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key=None, - pr_request_queue=None, - pr_response_queue=None, - emb_request_queue=None, - 
emb_response_queue=None, - ge_request_queue=None, - ge_response_queue=None, - tpl_request_queue=None, - tpl_response_queue=None, - verbose=False, - module="test", - ): - - self.verbose=verbose - - if pr_request_queue is None: - pr_request_queue = prompt_request_queue - - if pr_response_queue is None: - pr_response_queue = prompt_response_queue - - if emb_request_queue is None: - emb_request_queue = embeddings_request_queue - - if emb_response_queue is None: - emb_response_queue = embeddings_response_queue - - if ge_request_queue is None: - ge_request_queue = graph_embeddings_request_queue - - if ge_response_queue is None: - ge_response_queue = graph_embeddings_response_queue - - if tpl_request_queue is None: - tpl_request_queue = triples_request_queue - - if tpl_response_queue is None: - tpl_response_queue = triples_response_queue - - if self.verbose: - print("Initialising...", flush=True) - - self.ge_client = GraphEmbeddingsClient( - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, - subscriber=module + "-ge", - input_queue=ge_request_queue, - output_queue=ge_response_queue, - ) - - self.triples_client = TriplesQueryClient( - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, - subscriber=module + "-tpl", - input_queue=tpl_request_queue, - output_queue=tpl_response_queue - ) - - self.embeddings = EmbeddingsClient( - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, - input_queue=emb_request_queue, - output_queue=emb_response_queue, - subscriber=module + "-emb", - ) - - self.label_cache = {} - - self.prompt = PromptClient( - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, - input_queue=pr_request_queue, - output_queue=pr_response_queue, - subscriber=module + "-prompt", - ) - - if self.verbose: - print("Initialised", flush=True) - - def query( - self, query, user="trustgraph", collection="default", - entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, - ): - - if self.verbose: - print("Construct 
prompt...", flush=True) - - q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose, - entity_limit=entity_limit, triple_limit=triple_limit, - max_subgraph_size=max_subgraph_size, - max_path_length=max_path_length, - ) - - kg = q.get_labelgraph(query) - - if self.verbose: - print("Invoke LLM...", flush=True) - print(kg) - print(query) - - resp = self.prompt.request_kg_prompt(query, kg) - - if self.verbose: - print("Done", flush=True) - - return resp - diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index b42123a5..587dcbf3 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -35,7 +35,7 @@ from .. exceptions import RequestError from . librarian import Librarian -module = ".".join(__name__.split(".")[1:-1]) +module = "librarian" default_input_queue = librarian_request_queue default_output_queue = librarian_response_queue diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 68ddf441..c721c065 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -10,12 +10,11 @@ from .. schema import text_completion_response_queue from .. log_level import LogLevel from .. base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "metering" default_input_queue = text_completion_response_queue default_subscriber = module - class Processor(Consumer): def __init__(self, **params): diff --git a/trustgraph-flow/trustgraph/model/prompt/generic/service.py b/trustgraph-flow/trustgraph/model/prompt/generic/service.py index b143b759..b10da491 100755 --- a/trustgraph-flow/trustgraph/model/prompt/generic/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/generic/service.py @@ -27,7 +27,7 @@ from .... clients.llm_client import LlmClient from . 
prompts import to_definitions, to_relationships, to_topics from . prompts import to_kg_query, to_document_query, to_rows -module = ".".join(__name__.split(".")[1:-1]) +module = "prompt" default_input_queue = prompt_request_queue default_output_queue = prompt_response_queue diff --git a/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py b/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py index d8a032ca..c5c32395 100644 --- a/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py +++ b/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py @@ -4,8 +4,6 @@ import json from jsonschema import validate import re -from trustgraph.clients.llm_client import LlmClient - class PromptConfiguration: def __init__(self, system_template, global_terms={}, prompts={}): self.system_template = system_template @@ -21,8 +19,7 @@ class Prompt: class PromptManager: - def __init__(self, llm, config): - self.llm = llm + def __init__(self, config): self.config = config self.terms = config.global_terms @@ -54,7 +51,9 @@ class PromptManager: return json.loads(json_str) - def invoke(self, id, input): + async def invoke(self, id, input, llm): + + print("Invoke...", flush=True) if id not in self.prompts: raise RuntimeError("ID invalid") @@ -68,9 +67,7 @@ class PromptManager: "prompt": self.templates[id].render(terms) } - resp = self.llm.request(**prompt) - - print(resp, flush=True) + resp = await llm(**prompt) if resp_type == "text": return resp @@ -81,13 +78,13 @@ class PromptManager: try: obj = self.parse_json(resp) except: + print("Parse fail:", resp, flush=True) raise RuntimeError("JSON parse fail") - print(obj, flush=True) if self.prompts[id].schema: try: - print(self.prompts[id].schema) validate(instance=obj, schema=self.prompts[id].schema) + print("Validated", flush=True) except Exception as e: raise RuntimeError(f"Schema validation fail: {e}") diff --git a/trustgraph-flow/trustgraph/model/prompt/template/service.py 
b/trustgraph-flow/trustgraph/model/prompt/template/service.py index a1267114..67590c1c 100755 --- a/trustgraph-flow/trustgraph/model/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/template/service.py @@ -3,6 +3,7 @@ Language service abstracts prompt engineering from LLM. """ +import asyncio import json import re @@ -10,74 +11,59 @@ from .... schema import Definition, Relationship, Triple from .... schema import Topic from .... schema import PromptRequest, PromptResponse, Error from .... schema import TextCompletionRequest, TextCompletionResponse -from .... schema import text_completion_request_queue -from .... schema import text_completion_response_queue -from .... schema import prompt_request_queue, prompt_response_queue -from .... base import ConsumerProducer -from .... clients.llm_client import LlmClient + +from .... base import FlowProcessor +from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec from . prompt_manager import PromptConfiguration, Prompt, PromptManager -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "prompt" -default_input_queue = prompt_request_queue -default_output_queue = prompt_response_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - tc_request_queue = params.get( - "text_completion_request_queue", text_completion_request_queue - ) - tc_response_queue = params.get( - "text_completion_response_queue", text_completion_response_queue - ) + id = params.get("id") + # Config key for prompts self.config_key = params.get("config_type", "prompt") super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": PromptRequest, - 
"output_schema": PromptResponse, - "text_completion_request_queue": tc_request_queue, - "text_completion_response_queue": tc_response_queue, + "id": id, } ) - self.llm = LlmClient( - subscriber=subscriber, - input_queue=tc_request_queue, - output_queue=tc_response_queue, - pulsar_host = self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, + self.register_specification( + ConsumerSpec( + name = "request", + schema = PromptRequest, + handler = self.on_request + ) ) - # System prompt hack - class Llm: - def __init__(self, llm): - self.llm = llm - def request(self, system, prompt): - print(system) - print(prompt, flush=True) - return self.llm.request(system, prompt) + self.register_specification( + TextCompletionClientSpec( + request_name = "text-completion-request", + response_name = "text-completion-response", + ) + ) - self.llm = Llm(self.llm) + self.register_specification( + ProducerSpec( + name = "response", + schema = PromptResponse + ) + ) + + self.register_config_handler(self.on_prompt_config) # Null configuration, should reload quickly self.manager = PromptManager( - llm = self.llm, config = PromptConfiguration("", {}, {}) ) - async def on_config(self, version, config): + async def on_prompt_config(self, config, version): print("Loading configuration version", version) @@ -111,7 +97,6 @@ class Processor(ConsumerProducer): ) self.manager = PromptManager( - self.llm, PromptConfiguration( system, {}, @@ -126,7 +111,7 @@ class Processor(ConsumerProducer): print("Exception:", e, flush=True) print("Configuration reload failed", flush=True) - async def handle(self, msg): + async def on_request(self, msg, consumer, flow): v = msg.value() @@ -138,7 +123,7 @@ class Processor(ConsumerProducer): try: - print(v.terms) + print(v.terms, flush=True) input = { k: json.loads(v) @@ -146,14 +131,33 @@ class Processor(ConsumerProducer): } print(f"Handling kind {kind}...", flush=True) - print(input, flush=True) - resp = self.manager.invoke(kind, input) + async def llm(system, 
prompt): + + print(system, flush=True) + print(prompt, flush=True) + + resp = await flow("text-completion-request").text_completion( + system = system, prompt = prompt, + ) + + try: + return resp + except Exception as e: + print("LLM Exception:", e, flush=True) + return None + + try: + resp = await self.manager.invoke(kind, input, llm) + except Exception as e: + print("Invocation exception:", e, flush=True) + raise e + + print(resp, flush=True) if isinstance(resp, str): print("Send text response...", flush=True) - print(resp, flush=True) r = PromptResponse( text=resp, @@ -161,7 +165,7 @@ class Processor(ConsumerProducer): error=None, ) - await self.send(r, properties={"id": id}) + await flow("response").send(r, properties={"id": id}) return @@ -176,13 +180,13 @@ class Processor(ConsumerProducer): error=None, ) - await self.send(r, properties={"id": id}) + await flow("response").send(r, properties={"id": id}) return except Exception as e: - print(f"Exception: {e}") + print(f"Exception: {e}", flush=True) print("Send error response...", flush=True) @@ -194,11 +198,11 @@ class Processor(ConsumerProducer): response=None, ) - await self.send(r, properties={"id": id}) + await flow("response").send(r, properties={"id": id}) except Exception as e: - print(f"Exception: {e}") + print(f"Exception: {e}", flush=True) print("Send error response...", flush=True) @@ -215,22 +219,7 @@ class Processor(ConsumerProducer): @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) - - parser.add_argument( - '--text-completion-request-queue', - default=text_completion_request_queue, - help=f'Text completion request queue (default: {text_completion_request_queue})', - ) - - parser.add_argument( - '--text-completion-response-queue', - default=text_completion_response_queue, - help=f'Text completion response queue (default: {text_completion_response_queue})', - ) + FlowProcessor.add_args(parser) 
parser.add_argument( '--config-type', @@ -240,5 +229,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py index 33840378..79118cc8 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py @@ -16,7 +16,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py index 252d58ad..734b20c5 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py @@ -16,7 +16,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py index 195a39e4..f60b70d7 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py @@ -15,7 +15,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... 
exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py index d5dab142..df104ada 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py @@ -15,7 +15,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py index 98ecaf0e..9f382572 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py @@ -17,7 +17,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py index 483412a2..fd473564 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py @@ -14,7 +14,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... 
exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py index 16ff2df4..05ff18a6 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py @@ -15,7 +15,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py index 45f1311c..10257cdf 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py @@ -15,7 +15,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py index 6d825bac..91e627e3 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py @@ -15,7 +15,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... 
exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 590c2e3f..2479034d 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -15,7 +15,7 @@ from .... log_level import LogLevel from .... base import ConsumerProducer from .... exceptions import TooManyRequests -module = ".".join(__name__.split(".")[1:-1]) +module = "text-completion" default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index b16399e9..2fb416dd 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -11,7 +11,7 @@ from .... schema import document_embeddings_request_queue from .... schema import document_embeddings_response_queue from .... base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "de-query" default_input_queue = document_embeddings_request_queue default_output_queue = document_embeddings_response_queue diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 6a88671c..74c52055 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -16,7 +16,7 @@ from .... schema import document_embeddings_request_queue from .... schema import document_embeddings_response_queue from .... 
base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "de-query" default_input_queue = document_embeddings_request_queue default_output_queue = document_embeddings_response_queue diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 128203ad..c5543690 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -7,71 +7,51 @@ of chunks from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse from .... schema import Error, Value -from .... schema import document_embeddings_request_queue -from .... schema import document_embeddings_response_queue -from .... base import ConsumerProducer +from .... 
base import DocumentEmbeddingsQueryService -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "de-query" -default_input_queue = document_embeddings_request_queue -default_output_queue = document_embeddings_response_queue -default_subscriber = module default_store_uri = 'http://localhost:6333' -class Processor(ConsumerProducer): +class Processor(DocumentEmbeddingsQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) + #optional api key api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": DocumentEmbeddingsRequest, - "output_schema": DocumentEmbeddingsResponse, "store_uri": store_uri, "api_key": api_key, } ) - self.client = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - async def handle(self, msg): + async def query_document_embeddings(self, msg): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - chunks = [] - for vec in v.vectors: + for vec in msg.vectors: dim = len(vec) collection = ( - "d_" + v.user + "_" + v.collection + "_" + + "d_" + msg.user + "_" + msg.collection + "_" + str(dim) ) - search_result = self.client.query_points( + search_result = self.qdrant.query_points( collection_name=collection, query=vec, - limit=v.limit, + limit=msg.limit, with_payload=True, ).points @@ -79,37 +59,17 @@ class Processor(ConsumerProducer): ent = r.payload["doc"] chunks.append(ent) - print("Send response...", flush=True) - r = DocumentEmbeddingsResponse(documents=chunks, error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + 
return chunks except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = DocumentEmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - documents=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + DocumentEmbeddingsQueryService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -125,5 +85,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index 8dd8d04d..d2cec084 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -11,7 +11,7 @@ from .... schema import graph_embeddings_request_queue from .... schema import graph_embeddings_response_queue from .... base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "ge-query" default_input_queue = graph_embeddings_request_queue default_output_queue = graph_embeddings_response_queue diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 90cfc6de..942a1e69 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -16,7 +16,7 @@ from .... schema import graph_embeddings_request_queue from .... schema import graph_embeddings_response_queue from .... 
base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "ge-query" default_input_queue = graph_embeddings_request_queue default_output_queue = graph_embeddings_response_queue diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index dc3e28f3..32da00e5 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -7,44 +7,32 @@ entities from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse from .... schema import Error, Value -from .... schema import graph_embeddings_request_queue -from .... schema import graph_embeddings_response_queue -from .... base import ConsumerProducer +from .... 
base import GraphEmbeddingsQueryService -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "ge-query" -default_input_queue = graph_embeddings_request_queue -default_output_queue = graph_embeddings_response_queue -default_subscriber = module default_store_uri = 'http://localhost:6333' -class Processor(ConsumerProducer): +class Processor(GraphEmbeddingsQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) + + #optional api key api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": GraphEmbeddingsRequest, - "output_schema": GraphEmbeddingsResponse, "store_uri": store_uri, "api_key": api_key, } ) - self.client = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -52,34 +40,27 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_graph_embeddings(self, msg): try: - v = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - entity_set = set() entities = [] - for vec in v.vectors: + for vec in msg.vectors: dim = len(vec) collection = ( - "t_" + v.user + "_" + v.collection + "_" + + "t_" + msg.user + "_" + msg.collection + "_" + str(dim) ) # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) entities - search_result = self.client.query_points( + search_result = self.qdrant.query_points( collection_name=collection, query=vec, - limit=v.limit * 2, + limit=msg.limit * 2, with_payload=True, 
).points @@ -92,10 +73,10 @@ class Processor(ConsumerProducer): entities.append(ent) # Keep adding entities until limit - if len(entity_set) >= v.limit: break + if len(entity_set) >= msg.limit: break # Keep adding entities until limit - if len(entity_set) >= v.limit: break + if len(entity_set) >= msg.limit: break ents2 = [] @@ -105,36 +86,19 @@ class Processor(ConsumerProducer): entities = ents2 print("Send response...", flush=True) - r = GraphEmbeddingsResponse(entities=entities, error=None) - await self.send(r, properties={"id": id}) + return entities print("Done.", flush=True) except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = GraphEmbeddingsResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - entities=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + GraphEmbeddingsQueryService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -150,5 +114,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index e3687756..6fcf4a19 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -7,38 +7,24 @@ null. Output is a list of triples. from .... direct.cassandra import TrustGraph from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple -from .... schema import triples_request_queue -from .... schema import triples_response_queue -from .... base import ConsumerProducer +from ....
base import TriplesQueryService -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "triples-query" -default_input_queue = triples_request_queue -default_output_queue = triples_response_queue -default_subscriber = module default_graph_host='localhost' -class Processor(ConsumerProducer): +class Processor(TriplesQueryService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) graph_username = params.get("graph_username", None) graph_password = params.get("graph_password", None) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TriplesQueryRequest, - "output_schema": TriplesQueryResponse, "graph_host": graph_host, "graph_username": graph_username, - "graph_password": graph_password, } ) @@ -53,92 +39,85 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - async def handle(self, msg): + async def query_triples(self, query): try: - v = msg.value() - - table = (v.user, v.collection) + table = (query.user, query.collection) if table != self.table: if self.username and self.password: self.tg = TrustGraph( hosts=self.graph_host, - keyspace=v.user, table=v.collection, + keyspace=query.user, table=query.collection, username=self.username, password=self.password ) else: self.tg = TrustGraph( hosts=self.graph_host, - keyspace=v.user, table=v.collection, + keyspace=query.user, table=query.collection, ) self.table = table - # Sender-produced ID - id = msg.properties()["id"] - - print(f"Handling input {id}...", flush=True) - triples = [] - if v.s is not None: - if v.p is not None: - if v.o is not None: + if query.s is not None: + if query.p is not None: + if query.o is not None: resp = self.tg.get_spo( - v.s.value, 
v.p.value, v.o.value, - limit=v.limit + query.s.value, query.p.value, query.o.value, + limit=query.limit ) - triples.append((v.s.value, v.p.value, v.o.value)) + triples.append((query.s.value, query.p.value, query.o.value)) else: resp = self.tg.get_sp( - v.s.value, v.p.value, - limit=v.limit + query.s.value, query.p.value, + limit=query.limit ) for t in resp: - triples.append((v.s.value, v.p.value, t.o)) + triples.append((query.s.value, query.p.value, t.o)) else: - if v.o is not None: + if query.o is not None: resp = self.tg.get_os( - v.o.value, v.s.value, - limit=v.limit + query.o.value, query.s.value, + limit=query.limit ) for t in resp: - triples.append((v.s.value, t.p, v.o.value)) + triples.append((query.s.value, t.p, query.o.value)) else: resp = self.tg.get_s( - v.s.value, - limit=v.limit + query.s.value, + limit=query.limit ) for t in resp: - triples.append((v.s.value, t.p, t.o)) + triples.append((query.s.value, t.p, t.o)) else: - if v.p is not None: - if v.o is not None: + if query.p is not None: + if query.o is not None: resp = self.tg.get_po( - v.p.value, v.o.value, - limit=v.limit + query.p.value, query.o.value, + limit=query.limit ) for t in resp: - triples.append((t.s, v.p.value, v.o.value)) + triples.append((t.s, query.p.value, query.o.value)) else: resp = self.tg.get_p( - v.p.value, - limit=v.limit + query.p.value, + limit=query.limit ) for t in resp: - triples.append((t.s, v.p.value, t.o)) + triples.append((t.s, query.p.value, t.o)) else: - if v.o is not None: + if query.o is not None: resp = self.tg.get_o( - v.o.value, - limit=v.limit + query.o.value, + limit=query.limit ) for t in resp: - triples.append((t.s, t.p, v.o.value)) + triples.append((t.s, t.p, query.o.value)) else: resp = self.tg.get_all( - limit=v.limit + limit=query.limit ) for t in resp: triples.append((t.s, t.p, t.o)) @@ -152,37 +131,17 @@ class Processor(ConsumerProducer): for t in triples ] - print("Send response...", flush=True) - r = TriplesQueryResponse(triples=triples, 
error=None) - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) + return triples except Exception as e: print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = TriplesQueryResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + TriplesQueryService.add_args(parser) parser.add_argument( '-g', '--graph-host', @@ -205,5 +164,5 @@ class Processor(ConsumerProducer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index 56fed6d3..c62c28c1 100755 --- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -13,7 +13,7 @@ from .... schema import triples_request_queue from .... schema import triples_response_queue from .... base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "triples-query" default_input_queue = triples_request_queue default_output_queue = triples_response_queue diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index f442c4ef..594c9130 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -13,7 +13,7 @@ from .... schema import triples_request_queue from .... schema import triples_response_queue from ....
base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "triples-query" default_input_queue = triples_request_queue default_output_queue = triples_response_queue diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index 49ba0345..591361ce 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -13,7 +13,7 @@ from .... schema import triples_request_queue from .... schema import triples_response_queue from .... base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "triples-query" default_input_queue = triples_request_queue default_output_queue = triples_response_queue diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py new file mode 100644 index 00000000..5e3c9b41 --- /dev/null +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -0,0 +1,94 @@ + +import asyncio + +LABEL="http://www.w3.org/2000/01/rdf-schema#label" + +class Query: + + def __init__( + self, rag, user, collection, verbose, + doc_limit=20 + ): + self.rag = rag + self.user = user + self.collection = collection + self.verbose = verbose + self.doc_limit = doc_limit + + async def get_vector(self, query): + + if self.verbose: + print("Compute embeddings...", flush=True) + + qembeds = await self.rag.embeddings_client.embed(query) + + if self.verbose: + print("Done.", flush=True) + + return qembeds + + async def get_docs(self, query): + + vectors = await self.get_vector(query) + + if self.verbose: + print("Get docs...", flush=True) + + docs = await self.rag.doc_embeddings_client.query( + vectors, limit=self.doc_limit, + user=self.user, collection=self.collection, + ) + + if self.verbose: + print("Docs:", flush=True) + for doc in docs: + print(doc, flush=True) + + return docs + 
+class DocumentRag: + + def __init__( + self, prompt_client, embeddings_client, doc_embeddings_client, + verbose=False, + ): + + self.verbose = verbose + + self.prompt_client = prompt_client + self.embeddings_client = embeddings_client + self.doc_embeddings_client = doc_embeddings_client + + if self.verbose: + print("Initialised", flush=True) + + async def query( + self, query, user="trustgraph", collection="default", + doc_limit=20, + ): + + if self.verbose: + print("Construct prompt...", flush=True) + + q = Query( + rag=self, user=user, collection=collection, verbose=self.verbose, + doc_limit=doc_limit + ) + + docs = await q.get_docs(query) + + if self.verbose: + print("Invoke LLM...", flush=True) + print(docs) + print(query) + + resp = await self.prompt_client.document_prompt( + query = query, + documents = docs + ) + + if self.verbose: + print("Done", flush=True) + + return resp + diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index bb8b008e..8c478874 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -5,88 +5,77 @@ Input is query, output is response. """ from ... schema import DocumentRagQuery, DocumentRagResponse, Error -from ... schema import document_rag_request_queue, document_rag_response_queue -from ... schema import prompt_request_queue -from ... schema import prompt_response_queue -from ... schema import embeddings_request_queue -from ... schema import embeddings_response_queue -from ... schema import document_embeddings_request_queue -from ... schema import document_embeddings_response_queue -from ... log_level import LogLevel -from ... document_rag import DocumentRag -from ... base import ConsumerProducer +from . document_rag import DocumentRag +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import PromptClientSpec, EmbeddingsClientSpec +from ... 
base import DocumentEmbeddingsClientSpec -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "document-rag" -default_input_queue = document_rag_request_queue -default_output_queue = document_rag_response_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - pr_request_queue = params.get( - "prompt_request_queue", prompt_request_queue - ) - pr_response_queue = params.get( - "prompt_response_queue", prompt_response_queue - ) - emb_request_queue = params.get( - "embeddings_request_queue", embeddings_request_queue - ) - emb_response_queue = params.get( - "embeddings_response_queue", embeddings_response_queue - ) - de_request_queue = params.get( - "document_embeddings_request_queue", - document_embeddings_request_queue - ) - de_response_queue = params.get( - "document_embeddings_response_queue", - document_embeddings_response_queue - ) + id = params.get("id", default_ident) - doc_limit = params.get("doc_limit", 10) + doc_limit = params.get("doc_limit", 5) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": DocumentRagQuery, - "output_schema": DocumentRagResponse, - "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, - "embeddings_request_queue": emb_request_queue, - "embeddings_response_queue": emb_response_queue, - "document_embeddings_request_queue": de_request_queue, - "document_embeddings_response_queue": de_response_queue, + "id": id, + "doc_limit": doc_limit, } ) - self.rag = DocumentRag( - pulsar_host=self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - pr_request_queue=pr_request_queue, - pr_response_queue=pr_response_queue, - 
emb_request_queue=emb_request_queue, - emb_response_queue=emb_response_queue, - de_request_queue=de_request_queue, - de_response_queue=de_response_queue, - verbose=True, - module=module, - ) - self.doc_limit = doc_limit - async def handle(self, msg): + self.register_specification( + ConsumerSpec( + name = "request", + schema = DocumentRagQuery, + handler = self.on_request, + ) + ) + + self.register_specification( + EmbeddingsClientSpec( + request_name = "embeddings-request", + response_name = "embeddings-response", + ) + ) + + self.register_specification( + DocumentEmbeddingsClientSpec( + request_name = "document-embeddings-request", + response_name = "document-embeddings-response", + ) + ) + + self.register_specification( + PromptClientSpec( + request_name = "prompt-request", + response_name = "prompt-response", + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = DocumentRagResponse, + ) + ) + + async def on_request(self, msg, consumer, flow): try: + self.rag = DocumentRag( + embeddings_client = flow("embeddings-request"), + doc_embeddings_client = flow("document-embeddings-request"), + prompt_client = flow("prompt-request"), + verbose=True, + ) + v = msg.value() # Sender-produced ID @@ -99,11 +88,15 @@ class Processor(ConsumerProducer): else: doc_limit = self.doc_limit - response = self.rag.query(v.query, doc_limit=doc_limit) + response = await self.rag.query(v.query, doc_limit=doc_limit) - print("Send response...", flush=True) - r = DocumentRagResponse(response = response, error=None) - await self.send(r, properties={"id": id}) + await flow("response").send( + DocumentRagResponse( + response = response, + error = None + ), + properties = {"id": id} + ) print("Done.", flush=True) @@ -113,25 +106,21 @@ class Processor(ConsumerProducer): print("Send error response...", flush=True) - r = DocumentRagResponse( - error=Error( - type = "llm-error", - message = str(e), + await flow("response").send( + DocumentRagResponse( + response = 
None, + error = Error( + type = "document-rag-error", + message = str(e), + ), ), - response=None, + properties = {"id": id} ) - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) - @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + FlowProcessor.add_args(parser) parser.add_argument( '-d', '--doc-limit', @@ -140,43 +129,7 @@ class Processor(ConsumerProducer): help=f'Default document fetch limit (default: 10)' ) - parser.add_argument( - '--prompt-request-queue', - default=prompt_request_queue, - help=f'Prompt request queue (default: {prompt_request_queue})', - ) - - parser.add_argument( - '--prompt-response-queue', - default=prompt_response_queue, - help=f'Prompt response queue (default: {prompt_response_queue})', - ) - - parser.add_argument( - '--embeddings-request-queue', - default=embeddings_request_queue, - help=f'Embeddings request queue (default: {embeddings_request_queue})', - ) - - parser.add_argument( - '--embeddings-response-queue', - default=embeddings_response_queue, - help=f'Embeddings response queue (default: {embeddings_response_queue})', - ) - - parser.add_argument( - '--document-embeddings-request-queue', - default=document_embeddings_request_queue, - help=f'Document embeddings request queue (default: {document_embeddings_request_queue})', - ) - - parser.add_argument( - '--document-embeddings-response-queue', - default=document_embeddings_response_queue, - help=f'Document embeddings response queue (default: {document_embeddings_response_queue})', - ) - def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py new file mode 100644 index 00000000..6879023a --- /dev/null +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -0,0 +1,218 @@ + +import 
asyncio + +LABEL="http://www.w3.org/2000/01/rdf-schema#label" + +class Query: + + def __init__( + self, rag, user, collection, verbose, + entity_limit=50, triple_limit=30, max_subgraph_size=1000, + max_path_length=2, + ): + self.rag = rag + self.user = user + self.collection = collection + self.verbose = verbose + self.entity_limit = entity_limit + self.triple_limit = triple_limit + self.max_subgraph_size = max_subgraph_size + self.max_path_length = max_path_length + + async def get_vector(self, query): + + if self.verbose: + print("Compute embeddings...", flush=True) + + qembeds = await self.rag.embeddings_client.embed(query) + + if self.verbose: + print("Done.", flush=True) + + return qembeds + + async def get_entities(self, query): + + vectors = await self.get_vector(query) + + if self.verbose: + print("Get entities...", flush=True) + + entities = await self.rag.graph_embeddings_client.query( + vectors=vectors, limit=self.entity_limit, + user=self.user, collection=self.collection, + ) + + entities = [ + str(e) + for e in entities + ] + + if self.verbose: + print("Entities:", flush=True) + for ent in entities: + print(" ", ent, flush=True) + + return entities + + async def maybe_label(self, e): + + if e in self.rag.label_cache: + return self.rag.label_cache[e] + + res = await self.rag.triples_client.query( + s=e, p=LABEL, o=None, limit=1, + user=self.user, collection=self.collection, + ) + + if len(res) == 0: + self.rag.label_cache[e] = e + return e + + self.rag.label_cache[e] = str(res[0].o) + return self.rag.label_cache[e] + + async def follow_edges(self, ent, subgraph, path_length): + + # Not needed? 
+ if path_length <= 0: + return + + # Stop spanning around if the subgraph is already maxed out + if len(subgraph) >= self.max_subgraph_size: + return + + res = await self.rag.triples_client.query( + s=ent, p=None, o=None, + limit=self.triple_limit, + user=self.user, collection=self.collection, + ) + + for triple in res: + subgraph.add( + (str(triple.s), str(triple.p), str(triple.o)) + ) + if path_length > 1: + await self.follow_edges(str(triple.o), subgraph, path_length-1) + + res = await self.rag.triples_client.query( + s=None, p=ent, o=None, + limit=self.triple_limit, + user=self.user, collection=self.collection, + ) + + for triple in res: + subgraph.add( + (str(triple.s), str(triple.p), str(triple.o)) + ) + + res = await self.rag.triples_client.query( + s=None, p=None, o=ent, + limit=self.triple_limit, + user=self.user, collection=self.collection, + ) + + for triple in res: + subgraph.add( + (str(triple.s), str(triple.p), str(triple.o)) + ) + if path_length > 1: + await self.follow_edges( + str(triple.s), subgraph, path_length-1 + ) + + async def get_subgraph(self, query): + + entities = await self.get_entities(query) + + if self.verbose: + print("Get subgraph...", flush=True) + + subgraph = set() + + for ent in entities: + await self.follow_edges(ent, subgraph, self.max_path_length) + + subgraph = list(subgraph) + + return subgraph + + async def get_labelgraph(self, query): + + subgraph = await self.get_subgraph(query) + + sg2 = [] + + for edge in subgraph: + + if edge[1] == LABEL: + continue + + s = await self.maybe_label(edge[0]) + p = await self.maybe_label(edge[1]) + o = await self.maybe_label(edge[2]) + + sg2.append((s, p, o)) + + sg2 = sg2[0:self.max_subgraph_size] + + if self.verbose: + print("Subgraph:", flush=True) + for edge in sg2: + print(" ", str(edge), flush=True) + + if self.verbose: + print("Done.", flush=True) + + return sg2 + +class GraphRag: + + def __init__( + self, prompt_client, embeddings_client, graph_embeddings_client, + 
triples_client, verbose=False, + ): + + self.verbose = verbose + + self.prompt_client = prompt_client + self.embeddings_client = embeddings_client + self.graph_embeddings_client = graph_embeddings_client + self.triples_client = triples_client + + self.label_cache = {} + + if self.verbose: + print("Initialised", flush=True) + + async def query( + self, query, user = "trustgraph", collection = "default", + entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, + max_path_length = 2, + ): + + if self.verbose: + print("Construct prompt...", flush=True) + + q = Query( + rag = self, user = user, collection = collection, + verbose = self.verbose, entity_limit = entity_limit, + triple_limit = triple_limit, + max_subgraph_size = max_subgraph_size, + max_path_length = max_path_length, + ) + + kg = await q.get_labelgraph(query) + + if self.verbose: + print("Invoke LLM...", flush=True) + print(kg) + print(query) + + resp = await self.prompt_client.kg_prompt(query, kg) + + if self.verbose: + print("Done", flush=True) + + return resp + diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 2c45ecd4..5d3cc2f4 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -5,57 +5,18 @@ Input is query, output is response. """ from ... schema import GraphRagQuery, GraphRagResponse, Error -from ... schema import graph_rag_request_queue, graph_rag_response_queue -from ... schema import prompt_request_queue -from ... schema import prompt_response_queue -from ... schema import embeddings_request_queue -from ... schema import embeddings_response_queue -from ... schema import graph_embeddings_request_queue -from ... schema import graph_embeddings_response_queue -from ... schema import triples_request_queue -from ... schema import triples_response_queue -from ... log_level import LogLevel -from ... graph_rag import GraphRag -from ... 
base import ConsumerProducer +from . graph_rag import GraphRag +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import PromptClientSpec, EmbeddingsClientSpec +from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "graph-rag" -default_input_queue = graph_rag_request_queue -default_output_queue = graph_rag_response_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) - - pr_request_queue = params.get( - "prompt_request_queue", prompt_request_queue - ) - pr_response_queue = params.get( - "prompt_response_queue", prompt_response_queue - ) - emb_request_queue = params.get( - "embeddings_request_queue", embeddings_request_queue - ) - emb_response_queue = params.get( - "embeddings_response_queue", embeddings_response_queue - ) - ge_request_queue = params.get( - "graph_embeddings_request_queue", graph_embeddings_request_queue - ) - ge_response_queue = params.get( - "graph_embeddings_response_queue", graph_embeddings_response_queue - ) - tpl_request_queue = params.get( - "triples_request_queue", triples_request_queue - ) - tpl_response_queue = params.get( - "triples_response_queue", triples_response_queue - ) + id = params.get("id", default_ident) entity_limit = params.get("entity_limit", 50) triple_limit = params.get("triple_limit", 30) @@ -64,49 +25,74 @@ class Processor(ConsumerProducer): super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": GraphRagQuery, - "output_schema": GraphRagResponse, + "id": id, "entity_limit": entity_limit, "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, 
- "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, - "embeddings_request_queue": emb_request_queue, - "embeddings_response_queue": emb_response_queue, - "graph_embeddings_request_queue": ge_request_queue, - "graph_embeddings_response_queue": ge_response_queue, - "triples_request_queue": triples_request_queue, - "triples_response_queue": triples_response_queue, + "max_path_length": max_path_length, } ) - self.rag = GraphRag( - pulsar_host=self.pulsar_host, - pulsar_api_key=self.pulsar_api_key, - pr_request_queue=pr_request_queue, - pr_response_queue=pr_response_queue, - emb_request_queue=emb_request_queue, - emb_response_queue=emb_response_queue, - ge_request_queue=ge_request_queue, - ge_response_queue=ge_response_queue, - tpl_request_queue=triples_request_queue, - tpl_response_queue=triples_response_queue, - verbose=True, - module=module, - ) - self.default_entity_limit = entity_limit self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length - async def handle(self, msg): + self.register_specification( + ConsumerSpec( + name = "request", + schema = GraphRagQuery, + handler = self.on_request, + ) + ) + + self.register_specification( + EmbeddingsClientSpec( + request_name = "embeddings-request", + response_name = "embeddings-response", + ) + ) + + self.register_specification( + GraphEmbeddingsClientSpec( + request_name = "graph-embeddings-request", + response_name = "graph-embeddings-response", + ) + ) + + self.register_specification( + TriplesClientSpec( + request_name = "triples-request", + response_name = "triples-response", + ) + ) + + self.register_specification( + PromptClientSpec( + request_name = "prompt-request", + response_name = "prompt-response", + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = GraphRagResponse, + ) + ) + + async def on_request(self, msg, consumer, flow): try: + self.rag = 
GraphRag( + embeddings_client = flow("embeddings-request"), + graph_embeddings_client = flow("graph-embeddings-request"), + triples_client = flow("triples-request"), + prompt_client = flow("prompt-request"), + verbose=True, + ) + v = msg.value() # Sender-produced ID @@ -134,16 +120,20 @@ class Processor(ConsumerProducer): else: max_path_length = self.default_max_path_length - response = self.rag.query( - query=v.query, user=v.user, collection=v.collection, - entity_limit=entity_limit, triple_limit=triple_limit, - max_subgraph_size=max_subgraph_size, - max_path_length=max_path_length, + response = await self.rag.query( + query = v.query, user = v.user, collection = v.collection, + entity_limit = entity_limit, triple_limit = triple_limit, + max_subgraph_size = max_subgraph_size, + max_path_length = max_path_length, ) - print("Send response...", flush=True) - r = GraphRagResponse(response=response, error=None) - await self.send(r, properties={"id": id}) + await flow("response").send( + GraphRagResponse( + response = response, + error = None + ), + properties = {"id": id} + ) print("Done.", flush=True) @@ -153,25 +143,21 @@ class Processor(ConsumerProducer): print("Send error response...", flush=True) - r = GraphRagResponse( - error=Error( - type = "llm-error", - message = str(e), + await flow("response").send( + GraphRagResponse( + response = None, + error = Error( + type = "graph-rag-error", + message = str(e), + ), ), - response=None, + properties = {"id": id} ) - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) - @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + FlowProcessor.add_args(parser) parser.add_argument( '-e', '--entity-limit', @@ -201,55 +187,7 @@ class Processor(ConsumerProducer): help=f'Default max path length (default: 2)' ) - parser.add_argument( - '--prompt-request-queue', - default=prompt_request_queue, - help=f'Prompt request 
queue (default: {prompt_request_queue})', - ) - - parser.add_argument( - '--prompt-response-queue', - default=prompt_response_queue, - help=f'Prompt response queue (default: {prompt_response_queue})', - ) - - parser.add_argument( - '--embeddings-request-queue', - default=embeddings_request_queue, - help=f'Embeddings request queue (default: {embeddings_request_queue})', - ) - - parser.add_argument( - '--embeddings-response-queue', - default=embeddings_response_queue, - help=f'Embeddings response queue (default: {embeddings_response_queue})', - ) - - parser.add_argument( - '--graph-embeddings-request-queue', - default=graph_embeddings_request_queue, - help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})', - ) - - parser.add_argument( - '--graph-embeddings-response-queue', - default=graph_embeddings_response_queue, - help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})', - ) - - parser.add_argument( - '--triples-request-queue', - default=triples_request_queue, - help=f'Triples request queue (default: {triples_request_queue})', - ) - - parser.add_argument( - '--triples-response-queue', - default=triples_response_queue, - help=f'Triples response queue (default: {triples_response_queue})', - ) - def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index b4dbc486..2949263a 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -10,7 +10,7 @@ from .... schema import document_embeddings_store_queue from .... log_level import LogLevel from .... 
base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "de-write" default_input_queue = document_embeddings_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 9e91db9a..128323aa 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -16,7 +16,7 @@ from .... schema import document_embeddings_store_queue from .... log_level import LogLevel from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "de-write" default_input_queue = document_embeddings_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 810c1931..d65a75eb 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -8,31 +8,21 @@ from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams import uuid -from .... schema import DocumentEmbeddings -from .... schema import document_embeddings_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... 
base import DocumentEmbeddingsStoreService -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "de-write" -default_input_queue = document_embeddings_store_queue -default_subscriber = module default_store_uri = 'http://localhost:6333' -class Processor(Consumer): +class Processor(DocumentEmbeddingsStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": DocumentEmbeddings, "store_uri": store_uri, "api_key": api_key, } @@ -40,13 +30,11 @@ class Processor(Consumer): self.last_collection = None - self.client = QdrantClient(url=store_uri) + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - async def handle(self, msg): + async def store_document_embeddings(self, message): - v = msg.value() - - for emb in v.chunks: + for emb in message.chunks: chunk = emb.chunk.decode("utf-8") if chunk == "": return @@ -55,16 +43,17 @@ class Processor(Consumer): dim = len(vec) collection = ( - "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + + "d_" + message.metadata.user + "_" + + message.metadata.collection + "_" + str(dim) ) if collection != self.last_collection: - if not self.client.collection_exists(collection): + if not self.qdrant.collection_exists(collection): try: - self.client.create_collection( + self.qdrant.create_collection( collection_name=collection, vectors_config=VectorParams( size=dim, distance=Distance.COSINE @@ -76,7 +65,7 @@ class Processor(Consumer): self.last_collection = collection - self.client.upsert( + self.qdrant.upsert( collection_name=collection, points=[ PointStruct( @@ -92,9 +81,7 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, 
default_subscriber, - ) + DocumentEmbeddingsStoreService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -110,5 +97,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index b2d40306..8d8b68b0 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -9,7 +9,7 @@ from .... log_level import LogLevel from .... direct.milvus_graph_embeddings import EntityVectors from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "ge-write" default_input_queue = graph_embeddings_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 83861b54..400acf26 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -15,7 +15,7 @@ from .... schema import graph_embeddings_store_queue from .... log_level import LogLevel from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "ge-write" default_input_queue = graph_embeddings_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 6b0d7371..ecefee4f 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -8,31 +8,21 @@ from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams import uuid -from .... 
schema import GraphEmbeddings -from .... schema import graph_embeddings_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... base import GraphEmbeddingsStoreService -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "ge-write" -default_input_queue = graph_embeddings_store_queue -default_subscriber = module default_store_uri = 'http://localhost:6333' -class Processor(Consumer): +class Processor(GraphEmbeddingsStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": GraphEmbeddings, "store_uri": store_uri, "api_key": api_key, } @@ -40,7 +30,7 @@ class Processor(Consumer): self.last_collection = None - self.client = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) def get_collection(self, dim, user, collection): @@ -50,10 +40,10 @@ class Processor(Consumer): if cname != self.last_collection: - if not self.client.collection_exists(cname): + if not self.qdrant.collection_exists(cname): try: - self.client.create_collection( + self.qdrant.create_collection( collection_name=cname, vectors_config=VectorParams( size=dim, distance=Distance.COSINE @@ -67,11 +57,9 @@ class Processor(Consumer): return cname - async def handle(self, msg): + async def store_graph_embeddings(self, message): - v = msg.value() - - for entity in v.entities: + for entity in message.entities: if entity.entity.value == "" or entity.entity.value is None: return @@ -80,10 +68,10 @@ class Processor(Consumer): dim = len(vec) collection = self.get_collection( - dim, v.metadata.user, v.metadata.collection + dim, message.metadata.user, message.metadata.collection ) - 
self.client.upsert( + self.qdrant.upsert( collection_name=collection, points=[ PointStruct( @@ -99,9 +87,7 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + GraphEmbeddingsStoreService.add_args(parser) parser.add_argument( '-t', '--store-uri', @@ -117,5 +103,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py index 5490af97..d1ad139a 100755 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py @@ -9,7 +9,7 @@ from .... log_level import LogLevel from .... direct.milvus_object_embeddings import ObjectVectors from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "oe-write" default_input_queue = object_embeddings_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index e6536e6c..a84aefde 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -17,7 +17,7 @@ from .... schema import rows_store_queue from .... log_level import LogLevel from .... 
base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "rows-write" ssl_context = SSLContext(PROTOCOL_TLSv1_2) default_input_queue = rows_store_queue diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 17b5ae9a..f8396692 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -10,35 +10,26 @@ import argparse import time from .... direct.cassandra import TrustGraph -from .... schema import Triples -from .... schema import triples_store_queue -from .... log_level import LogLevel -from .... base import Consumer +from .... base import TriplesStoreService -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "triples-write" -default_input_queue = triples_store_queue -default_subscriber = module default_graph_host='localhost' -class Processor(Consumer): +class Processor(TriplesStoreService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id", default_ident) + graph_host = params.get("graph_host", default_graph_host) graph_username = params.get("graph_username", None) graph_password = params.get("graph_password", None) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": Triples, "graph_host": graph_host, - "graph_username": graph_username, - "graph_password": graph_password, + "graph_username": graph_username } ) @@ -47,11 +38,9 @@ class Processor(Consumer): self.password = graph_password self.table = None - async def handle(self, msg): + async def store_triples(self, message): - v = msg.value() - - table = (v.metadata.user, v.metadata.collection) + table = (message.metadata.user, message.metadata.collection) if self.table is None or self.table != table: @@ -61,13 +50,15 
@@ class Processor(Consumer): if self.username and self.password: self.tg = TrustGraph( hosts=self.graph_host, - keyspace=v.metadata.user, table=v.metadata.collection, + keyspace=message.metadata.user, + table=message.metadata.collection, username=self.username, password=self.password ) else: self.tg = TrustGraph( hosts=self.graph_host, - keyspace=v.metadata.user, table=v.metadata.collection, + keyspace=message.metadata.user, + table=message.metadata.collection, ) except Exception as e: print("Exception", e, flush=True) @@ -76,7 +67,7 @@ class Processor(Consumer): self.table = table - for t in v.triples: + for t in message.triples: self.tg.insert( t.s.value, t.p.value, @@ -86,9 +77,7 @@ class Processor(Consumer): @staticmethod def add_args(parser): - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + TriplesStoreService.add_args(parser) parser.add_argument( '-g', '--graph-host', @@ -110,5 +99,5 @@ class Processor(Consumer): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index 2d0ae38a..b3996b91 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -16,7 +16,7 @@ from .... schema import triples_store_queue from .... log_level import LogLevel from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "triples-write" default_input_queue = triples_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 620e669e..8c88ea8f 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -16,7 +16,7 @@ from .... 
schema import triples_store_queue from .... log_level import LogLevel from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "triples-write" default_input_queue = triples_store_queue default_subscriber = module diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index 3323f912..84a4d923 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -16,7 +16,7 @@ from .... schema import triples_store_queue from .... log_level import LogLevel from .... base import Consumer -module = ".".join(__name__.split(".")[1:-1]) +module = "triples-write" default_input_queue = triples_store_queue default_subscriber = module diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index f8926589..5fa436b8 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -14,7 +14,7 @@ from ... schema import document_ingest_queue, text_ingest_queue from ... log_level import LogLevel from ... base import ConsumerProducer -module = ".".join(__name__.split(".")[1:-1]) +module = "ocr" default_input_queue = document_ingest_queue default_output_queue = text_ingest_queue diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py index 4d38c8c0..3594b76d 100755 --- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py +++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py @@ -4,50 +4,30 @@ Simple LLM service, performs text prompt completion using VertexAI on Google Cloud. Input is prompt, output is response. 
""" -import vertexai -import time -from prometheus_client import Histogram -import os - from google.oauth2 import service_account import google +import vertexai from vertexai.preview.generative_models import ( - Content, - FunctionDeclaration, - GenerativeModel, - GenerationConfig, - HarmCategory, - HarmBlockThreshold, - Part, - Tool, + Content, FunctionDeclaration, GenerativeModel, GenerationConfig, + HarmCategory, HarmBlockThreshold, Part, Tool, ) -from .... schema import TextCompletionRequest, TextCompletionResponse, Error -from .... schema import text_completion_request_queue -from .... schema import text_completion_response_queue -from .... log_level import LogLevel -from .... base import ConsumerProducer from .... exceptions import TooManyRequests +from .... base import LlmService, LlmResult -module = ".".join(__name__.split(".")[1:-1]) +default_ident = "text-completion" -default_input_queue = text_completion_request_queue -default_output_queue = text_completion_response_queue -default_subscriber = module default_model = 'gemini-1.0-pro-001' default_region = 'us-central1' default_temperature = 0.0 default_max_output = 8192 default_private_key = "private.json" -class Processor(ConsumerProducer): +class Processor(LlmService): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) region = params.get("region", default_region) model = params.get("model", default_model) private_key = params.get("private_key", default_private_key) @@ -57,28 +37,7 @@ class Processor(ConsumerProducer): if private_key is None: raise RuntimeError("Private key file not specified") - super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": TextCompletionRequest, - "output_schema": TextCompletionResponse, - } - ) - - if not 
hasattr(__class__, "text_completion_metric"): - __class__.text_completion_metric = Histogram( - 'text_completion_duration', - 'Text completion duration (seconds)', - buckets=[ - 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, - 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, - 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, - 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, - 120.0 - ] - ) + super(Processor, self).__init__(**params) self.parameters = { "temperature": temperature, @@ -110,7 +69,11 @@ class Processor(ConsumerProducer): print("Initialise VertexAI...", flush=True) if private_key: - credentials = service_account.Credentials.from_service_account_file(private_key) + credentials = ( + service_account.Credentials.from_service_account_file( + private_key + ) + ) else: credentials = None @@ -131,50 +94,29 @@ class Processor(ConsumerProducer): print("Initialisation complete", flush=True) - async def handle(self, msg): + async def generate_content(self, system, prompt): try: - v = msg.value() + prompt = system + "\n\n" + prompt - # Sender-produced ID + response = self.llm.generate_content( + prompt, generation_config=self.generation_config, + safety_settings=self.safety_settings + ) - id = msg.properties()["id"] + resp = LlmResult() + resp.text = response.text + resp.in_token = response.usage_metadata.prompt_token_count + resp.out_token = response.usage_metadata.candidates_token_count + resp.model = self.model - print(f"Handling prompt {id}...", flush=True) - - prompt = v.system + "\n\n" + v.prompt - - with __class__.text_completion_metric.time(): - - response = self.llm.generate_content( - prompt, generation_config=self.generation_config, - safety_settings=self.safety_settings - ) - - resp = response.text - inputtokens = int(response.usage_metadata.prompt_token_count) - outputtokens = int(response.usage_metadata.candidates_token_count) - print(resp, flush=True) - print(f"Input Tokens: {inputtokens}", flush=True) - print(f"Output Tokens: 
{outputtokens}", flush=True) + print(f"Input Tokens: {resp.in_token}", flush=True) + print(f"Output Tokens: {resp.out_token}", flush=True) print("Send response...", flush=True) - r = TextCompletionResponse( - error=None, - response=resp, - in_token=inputtokens, - out_token=outputtokens, - model=self.model - ) - - await self.send(r, properties={"id": id}) - - print("Done.", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + return resp except google.api_core.exceptions.ResourceExhausted as e: @@ -186,40 +128,19 @@ class Processor(ConsumerProducer): except Exception as e: # Apart from rate limits, treat all exceptions as unrecoverable - print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - await self.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + raise e @staticmethod def add_args(parser): - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + LlmService.add_args(parser) parser.add_argument( '-m', '--model', default=default_model, help=f'LLM model (default: {default_model})' ) - # Also: text-bison-32k parser.add_argument( '-k', '--private-key', @@ -247,6 +168,5 @@ class Processor(ConsumerProducer): ) def run(): - - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__)