diff --git a/Makefile b/Makefile
index 1899e602..4f4de9d2 100644
--- a/Makefile
+++ b/Makefile
@@ -60,14 +60,6 @@ container: update-package-versions
${DOCKER} build -f containers/Containerfile.ocr \
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
-some-containers:
- ${DOCKER} build -f containers/Containerfile.base \
- -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.flow \
- -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.vertexai \
- -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
-
basic-containers: update-package-versions
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
diff --git a/README.md b/README.md
index ed3118c1..f5c6e186 100644
--- a/README.md
+++ b/README.md
@@ -2,41 +2,40 @@
-## Autonomous Operations Platform
+## Data-to-AI, Simplified.
[](https://pypi.org/project/trustgraph/) [](https://discord.gg/sQMwkRz5GX)
-ð [Docs](https://trustgraph.ai/docs/getstarted) ðš [YouTube](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) ð§ [Knowledge Cores](https://github.com/trustgraph-ai/catalog/tree/master/v3) âïļ [API Docs](docs/apis/README.md) ð§âðŧ [CLI Docs](https://trustgraph.ai/docs/running/cli) ðŽ [Discord](https://discord.gg/sQMwkRz5GX) ð [Blog](https://blog.trustgraph.ai/subscribe)
+ð [Getting Started](https://trustgraph.ai/docs/getstarted) ðš [YouTube](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) ð§ [Knowledge Cores](https://github.com/trustgraph-ai/catalog/tree/master/v3) âïļ [API Docs](docs/apis/README.md) ð§âðŧ [CLI Docs](https://trustgraph.ai/docs/running/cli) ðŽ [Discord](https://discord.gg/sQMwkRz5GX) ð [Blog](https://blog.trustgraph.ai/subscribe)
-**Transform AI agents from experimental concepts into a new paradigm of continuous operations.**
+## The AI App Problem: Everything in Between
-The **TrustGraph** platform provides a robust, scalable, and reliable AI infrastructure designed for complex environments, complete with a full observability and telemetry stack. **TrustGraph** automates the deployment of state-of-the-art RAG pipelines using both Knowledge Graphs and Vector Databases in local and cloud environments with a unified interface to all major LLM providers.
+Building enterprise AI applications is *hard*. You're not just connecting APIs with a protocol - you're wrangling a complex ecosystem:
----
+* **Data Silos:** Connecting to and managing data from various sources (databases, APIs, files) is a nightmare.
+* **LLM Integration:** Choosing, integrating, and managing different LLMs adds another layer of complexity.
+* **Deployment Headaches:** Deploying, scaling, and monitoring your AI application is a constant challenge.
+* **Knowledge Graph Construction:** Taking raw knowledge and structuring it so it can be efficiently retrieved.
+* **Vector Database Juggling:** Setting up and optimizing a vector database for efficient data retrieval is crucial but complex.
+* **Data Pipelines:** Building robust ETL pipelines to prepare and transform your data is time-consuming.
+* **Data Management:** As your app grows, so does the data, meaning storage and retrieval become much more complex.
+* **Prompt Engineering:** Building, testing, and deploying prompts for specific use cases.
+* **Reliability:** With every new connection, the complexity ramps up, meaning any simple error can bring the entire system crashing down.
-- âĻ [**Key Features**](#-key-features)
-- ðŊ [**Why TrustGraph?**](#-why-trustgraph)
-- ð [**Getting Started**](#-getting-started)
-- ð§ [**Configuration Builder**](#-configuration-builder)
-- ð§ [**Knowledge Cores**](#-knowledge-cores)
-- ð [**Architecture**](#-architecture)
-- ð§Đ [**Integrations**](#-integrations)
-- ð [**Observability & Telemetry**](#-observability--telemetry)
-- ðĪ [**Contributing**](#-contributing)
-- ð [**License**](#-license)
-- ð [**Support & Community**](#-support--community)
+## What is TrustGraph?
----
+**TrustGraph removes the biggest headache of building an AI app: connecting and managing all the data, deployments, and models.** As a full-stack platform, TrustGraph simplifies the development and deployment of data-driven AI applications. TrustGraph is a complete solution, handling everything from data ingestion to deployment, so you can focus on building innovative AI experiences.
-## âĻ Key Features
+
+
+## The Stack Layers
- ð **Data Ingest**: Bulk ingest documents such as `.pdf`,`.txt`, and `.md`
-- ð **OCR Pipelines**: OCR documents with PDF decode, Tesseract, or Mistral OCR services
- ðŠ **Adjustable Chunking**: Choose your chunking algorithm and parameters
-- ð **No-code LLM Integration**: **Anthropic**, **AWS Bedrock**, **AzureAI**, **AzureOpenAI**, **Cohere**, **Google AI Studio**, **Google VertexAI**, **Llamafiles**, **LM Studio**, **Mistral**, **Ollama**, and **OpenAI**
+- ð **No-code LLM Integration**: **Anthropic**, **AWS Bedrock**, **AzureAI**, **AzureOpenAI**, **Cohere**, **Google AI Studio**, **Google VertexAI**, **Llamafiles**, **Ollama**, and **OpenAI**
- ð **Automated Knowledge Graph Building**: No need for complex ontologies and manual graph building
- ðĒ **Knowledge Graph to Vector Embeddings Mappings**: Connect knowledge graph enhanced data directly to vector embeddings
- â**Natural Language Data Retrieval**: Automatically perform a semantic similiarity search and subgraph extraction for the context of LLM generative responses
@@ -45,38 +44,31 @@ The **TrustGraph** platform provides a robust, scalable, and reliable AI infrast
- ð **Multiple Knowledge Graph Options**: Full integration with **Memgraph**, **FalkorDB**, **Neo4j**, or **Cassandra**
- ð§Ū **Multiple VectorDB Options**: Full integration with **Qdrant**, **Pinecone**, or **Milvus**
- ðïļ **Production-Grade** Reliability, scalability, and accuracy
-- ð **Observability and Telemetry**: Get insights into system performance with **Prometheus** and **Grafana**
+- ð **Observability and Telemetry**: Get insights into system performance with **Prometheus** and **Grafana**
- ðŧ **Orchestration**: Fully containerized with **Docker** or **Kubernetes**
- ðĨ **Stack Manager**: Control and scale the stack with confidence with **Apache Pulsar**
-- âïļ **Cloud Deployments**: **AWS**, **Azure**, **Google Cloud**, and **Scaleway**
+- âïļ **Cloud Deployments**: **AWS** and **Google Cloud**
- ðŠī **Customizable and Extensible**: Tailor for your data and use cases
- ðĨïļ **Configuration Builder**: Build the `YAML` configuration with drop down menus and selectable parameters
- ðĩïļ **Test Suite**: A simple UI to fully test TrustGraph performance
-## ðŊ Why TrustGraph?
+## Why Use TrustGraph?
-Traditional operations involve manual intervention, siloed tools, and reactive problem-solving. While AI agents show promise, integrating them into reliable, continuous operations presents significant challenges:
+* **Accelerate Development:** TrustGraph instantly connects your data and app, keeping you laser focused on your users.
+* **Reduce Complexity:** Eliminate the pain of integrating disparate tools and technologies.
+* **Focus on Innovation:** Spend your time building your core AI logic, not managing infrastructure.
+* **Improve Data Relevance:** Ensure your LLM has access to the *right* data, at the *right* time.
+* **Scale with Confidence:** Deploy and scale your AI applications reliably and efficiently.
+* **Full RAG Solution:** Focus on optimizing your responses, not building RAG pipelines.
-1. **Scalability & Reliability:** Standalone agents don't scale or offer the robustness required for business-critical operations.
-2. **Contextual Understanding:** Agents need deep, relevant context (often locked in sensitive and protectec data) to perform complex tasks effectively. RAG is powerful but complex to deploy and manage.
-3. **Integration Nightmare:** Connecting agents to diverse systems, data sources, and various LLMs is difficult and time-consuming.
-4. **Lack of Oversight:** Monitoring, debugging, and understanding the behavior of multiple autonomous agents in production is critical but often overlooked.
-
-**TrustGraph addresses these challenges by providing:**
-
-* A **platform**, not just a library, for managing the lifecycle of autonomous operations.
-* **Automated, best-practice RAG deployments** that combine the strengths of semantic vector search and structured knowledge graph traversal.
-* A **standardized layer** for LLM interaction and enterprise system integration.
-* **Built-in observability** to ensure you can trust and manage your autonomous systems.
-
-## ð Getting Started
+## Quickstart Guide ð
- [Install the CLI](#install-the-trustgraph-cli)
-- [Configuration Builder](#-configuration-builder)
-- [Platform Restarts](#platform-restarts)
+- [Configuration Builder](#configuration-builder)
+- [System Restarts](#system-restarts)
- [Test Suite](#test-suite)
- [Example Notebooks](#example-trustgraph-notebooks)
-### Developer APIs and CLI
+## Developer APIs and CLI
- [**REST API**](docs/apis/README.md#rest-apis)
- [**Websocket API**](docs/apis/README.md#websocket-api)
@@ -87,7 +79,7 @@ See the [API Developer's Guide](#api-documentation) for more information.
For users, **TrustGraph** has the following interfaces:
-- [**Configuration Builder**](#-configuration-builder)
+- [**Configuration Builder**](#configuration-builder)
- [**Test Suite**](#test-suite)
The `TrustGraph CLI` installs the commands for interacting with TrustGraph while running along with the Python SDK. The `Configuration Builder` enables customization of TrustGraph deployments prior to launching. The **REST API** can be accessed through port `8088` of the TrustGraph host machine with JSON request and response bodies.
@@ -95,18 +87,18 @@ The `TrustGraph CLI` installs the commands for interacting with TrustGraph while
### Install the TrustGraph CLI
```
-pip3 install trustgraph-cli==0.21.17
+pip3 install trustgraph-cli==0.20.9
```
> [!NOTE]
> The `TrustGraph CLI` version must match the desired `TrustGraph` release version.
-## ð§ Configuration Builder
+## Configuration Builder
-TrustGraph is endlessly customizable by editing the `YAML` launch files. The `Configuration Builder` provides a quick and intuitive tool for building a custom configuration that deploys with Docker, Podman, Minikube, AWS, Azure, Google Cloud, or Scaleway. There is a `Configuration Builder` for the both the lastest and stable `TrustGraph` releases.
+TrustGraph is endlessly customizable by editing the `YAML` launch files. The `Configuration Builder` provides a quick and intuitive tool for building a custom configuration that deploys with Docker, Podman, Minikube, or Google Cloud. There is a `Configuration Builder` for both the latest and stable `TrustGraph` releases.
-- [**Configuration Builder** (Stable 0.21.17) ð](https://config-ui.demo.trustgraph.ai/)
-- [**Configuration Builder** (Latest 0.22.5) ð](https://dev.config-ui.demo.trustgraph.ai/)
+- [**Configuration Builder** (Stable 0.20.9) ð](https://config-ui.demo.trustgraph.ai/)
+- [**Configuration Builder** (Latest 0.20.11) ð](https://dev.config-ui.demo.trustgraph.ai/)
The `Configuration Builder` has 4 important sections:
@@ -129,7 +121,7 @@ When finished, shutting down TrustGraph is as simple as:
docker compose down -v
```
-### Platform Restarts
+## System Restarts
The `-v` flag will destroy all data on shut down. To restart the system, it's necessary to keep the volumes. To keep the volumes, shut down without the `-v` flag:
```
@@ -143,7 +135,7 @@ docker compose up -d
All data previously in TrustGraph will be saved and usable on restart.
-### Test Suite
+## Test Suite
If added to the build in the `Configuration Builder`, the `Test Suite` will be available at port `8888`. The `Test Suite` has the following capabilities:
@@ -153,11 +145,20 @@ If added to the build in the `Configuration Builder`, the `Test Suite` will be a
- **Graph Visualizer** ð: Visualize semantic relationships in **3D**
- **Data Loader** ð: Directly load `.pdf`, `.txt`, or `.md` into the system with document metadata
-### Example TrustGraph Notebooks
+## Example TrustGraph Notebooks
- [**REST API Notebooks**](https://github.com/trustgraph-ai/example-notebooks/tree/master/api-examples)
- [**Python SDK Notebooks**](https://github.com/trustgraph-ai/example-notebooks/tree/master/api-library)
+## Prebuilt Configuration Files
+
+TrustGraph `YAML` files are available [here](https://github.com/trustgraph-ai/trustgraph/releases). Download `deploy.zip` for the desired release version.
+
+| Release Type | Release Version |
+| ------------ | --------------- |
+| Latest | [0.20.11](https://github.com/trustgraph-ai/trustgraph/releases/download/v0.20.11/deploy.zip) |
+| Stable | [0.20.9](https://github.com/trustgraph-ai/trustgraph/releases/download/v0.20.9/deploy.zip) |
+
TrustGraph is fully containerized and is launched with a `YAML` configuration file. Unzipping the `deploy.zip` will add the `deploy` directory with the following subdirectories:
- `docker-compose`
@@ -179,39 +180,12 @@ kubectl apply -f
TrustGraph is designed to be modular to support as many LLMs and environments as possible. A natural fit for a modular architecture is to decompose functions into a set of modules connected through a pub/sub backbone. [Apache Pulsar](https://github.com/apache/pulsar/) serves as this pub/sub backbone. Pulsar acts as the data broker managing data processing queues connected to procesing modules.
-## ð§ Knowledge Cores
+### Pulsar Workflows
-One of the biggest challenges currently facing RAG architectures is the ability to quickly reuse and integrate knowledge sets. **TrustGraph** solves this problem by storing the results of the document ingestion process in reusable Knowledge Cores. Being able to store and reuse the Knowledge Cores means the process has to be run only once for a set of documents. These reusable Knowledge Cores can be loaded back into **TrustGraph** and used for RAG.
-
-A Knowledge Core has two components:
-
-- Set of Graph Edges
-- Set of mapped Vector Embeddings
-
-When a Knowledge Core is loaded into TrustGraph, the corresponding graph edges and vector embeddings are queued and loaded into the chosen graph and vector stores.
-
-## ð Architecture
-
-As a full-stack platform, TrustGraph provides all the stack layers needed to connect the data layer to the app layer for autonomous operations.
-
-
-
-## ð§Đ Integrations
-TrustGraph seamlessly integrates API services, data stores, observability, telemetry, and control flow for a unified platform experience.
-
-- LLM Providers: **Anthropic**, **AWS Bedrock**, **AzureAI**, **AzureOpenAI**, **Cohere**, **Google AI Studio**, **Google VertexAI**, **Llamafiles**, **LM Studio**, **Mistral**, **Ollama**, and **OpenAI**
-- Vector Databases: **Qdrant**, **Pinecone**, and **Milvus**
-- Knowledge Graphs: **Memgraph**, **Neo4j**, and **FalkorDB**
-- Data Stores: **Apache Cassandra**
-- Observability: **Prometheus** and **Grafana**
-- Control Flow: **Apache Pulsar**
-
-### Pulsar Control Flows
-
-- For control flows, Pulsar accepts the output of a processing module and queues it for input to the next subscribed module.
+- For processing flows, Pulsar accepts the output of a processing module and queues it for input to the next subscribed module.
- For services such as LLMs and embeddings, Pulsar provides a client/server model. A Pulsar queue is used as the input to the service. When processed, the output is then delivered to a separate queue where a client subscriber can request that output.
-### Document Extraction Agents
+## Data Extraction Agents
TrustGraph extracts knowledge documents to an ultra-dense knowledge graph using 3 automonous data extraction agents. These agents focus on individual elements needed to build the knowledge graph. The agents are:
@@ -231,7 +205,7 @@ Text or Markdown file:
tg-load-text
```
-### Graph RAG Queries
+## Graph RAG Queries
Once the knowledge graph and embeddings have been built or a cognitive core has been loaded, RAG queries are launched with a single line:
@@ -239,7 +213,7 @@ Once the knowledge graph and embeddings have been built or a cognitive core has
tg-invoke-graph-rag -q "What are the top 3 takeaways from the document?"
```
-### Agent Flow
+## Agent Flow
Invoking the Agent Flow will use a ReAct style approach the combines Graph RAG and text completion requests to think through a problem solution.
@@ -250,44 +224,14 @@ tg-invoke-agent -v -q "Write a blog post on the top 3 takeaways from the documen
> [!TIP]
> Adding `-v` to the agent request will return all of the agent manager's thoughts and observations that led to the final response.
-## ð Observability & Telemetry
+## API Documentation
-Once the platform is running, access the Grafana dashboard at:
+[Developing on TrustGraph using APIs](docs/apis/README.md)
-```
-http://localhost:3000
-```
+## Deploy and Manage TrustGraph
-Default credentials are:
+[ðð Full Deployment Guide ðð](https://trustgraph.ai/docs/getstarted)
-```
-user: admin
-password: admin
-```
-
-The default Grafana dashboard tracks the following:
-
-- LLM Latency
-- Error Rate
-- Service Request Rates
-- Queue Backlogs
-- Chunking Histogram
-- Error Source by Service
-- Rate Limit Events
-- CPU usage by Service
-- Memory usage by Service
-- Models Deployed
-- Token Throughput (Tokens/second)
-- Cost Throughput (Cost/second)
-
-## ðĪ Contributing
+## TrustGraph Developer's Guide
[Developing for TrustGraph](docs/README.development.md)
-
-## ð License
-**TrustGraph** is licensed under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html).
-
-## ð Support & Community
-- Bug Reports & Feature Requests: [Discord](https://discord.gg/sQMwkRz5GX)
-- Discussions & Questions: [Discord](https://discord.gg/sQMwkRz5GX)
-- Documentation: [Docs](https://trustgraph.ai/docs/getstarted)
diff --git a/test-api/test-library-add-doc b/test-api/test-library-add-doc
deleted file mode 100755
index bd927367..00000000
--- a/test-api/test-library-add-doc
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-import json
-import sys
-import base64
-
-url = "http://localhost:8088/api/v1/"
-
-############################################################################
-
-id = "http://trustgraph.ai/doc/12345678"
-
-with open("docs/README.cats") as f:
- doc = base64.b64encode(f.read().encode("utf-8")).decode("utf-8")
-
-input = {
- "operation": "add",
- "document": {
- "id": id,
- "metadata": [
- {
- "s": {
- "v": id,
- "e": True,
- },
- "p": {
- "v": "http://www.w3.org/2000/01/rdf-schema#label",
- "e": True,
- },
- "o": {
- "v": "Mark's pets", "e": False,
- },
- },
- {
- "s": {
- "v": id,
- "e": True,
- },
- "p": {
- "v": 'https://schema.org/keywords',
- "e": True,
- },
- "o": {
- "v": "cats", "e": False,
- },
- },
- ],
- "document": doc,
- "kind": "text/plain",
- "user": "trustgraph",
- "collection": "default",
- "title": "Mark's cats",
- "comments": "Test doc taken from the TrustGraph repo",
- }
-}
-
-resp = requests.post(
- f"{url}librarian",
- json=input,
-)
-
-print(resp.text)
-resp = resp.json()
-
-print(resp)
-
-if "error" in resp:
- print(f"Error: {resp['error']}")
- sys.exit(1)
-
-# print(resp["response"])
-print(resp)
-
-sys.exit(0)
-
-############################################################################
-
diff --git a/test-api/test-library-add-doc2 b/test-api/test-library-add-doc2
deleted file mode 100755
index 0c0856f9..00000000
--- a/test-api/test-library-add-doc2
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-import json
-import sys
-import base64
-
-url = "http://localhost:8088/api/v1/"
-
-############################################################################
-
-id = "http://trustgraph.ai/doc/12345678"
-
-source = "../sources/20160001634.pdf"
-
-with open(source, "rb") as f:
- doc = base64.b64encode(f.read()).decode("utf-8")
-
-input = {
- "operation": "add",
- "id": id,
- "document": {
- "metadata": [
- {
- "s": {
- "v": id,
- "e": True,
- },
- "p": {
- "v": "http://www.w3.org/2000/01/rdf-schema#label",
- "e": True,
- },
- "o": {
- "v": "Challenger report volume 1", "e": False,
- },
- },
- {
- "s": {
- "v": id,
- "e": True,
- },
- "p": {
- "v": 'https://schema.org/keywords',
- "e": True,
- },
- "o": {
- "v": "space shuttle", "e": False,
- },
- },
- {
- "s": {
- "v": id,
- "e": True,
- },
- "p": {
- "v": 'https://schema.org/keywords',
- "e": True,
- },
- "o": {
- "v": "nasa", "e": False,
- },
- },
- ],
- "document": doc,
- "kind": "application/pdf",
- "user": "trustgraph",
- "collection": "default",
- }
-}
-
-resp = requests.post(
- f"{url}librarian",
- json=input,
-)
-
-print(resp.text)
-resp = resp.json()
-
-print(resp)
-
-if "error" in resp:
- print(f"Error: {resp['error']}")
- sys.exit(1)
-
-print(resp)
-
-sys.exit(0)
-
-############################################################################
-
diff --git a/test-api/test-library-list b/test-api/test-library-list
deleted file mode 100755
index 72ea4478..00000000
--- a/test-api/test-library-list
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-import json
-import sys
-import base64
-
-url = "http://localhost:8088/api/v1/"
-
-############################################################################
-
-user = "trustgraph"
-
-input = {
- "operation": "list",
- "user": user,
-}
-
-resp = requests.post(
- f"{url}librarian",
- json=input,
-)
-
-print(resp.text)
-resp = resp.json()
-
-print(resp)
-
-if "error" in resp:
- print(f"Error: {resp['error']}")
- sys.exit(1)
-
-# print(resp["response"])
-print(resp)
-
-sys.exit(0)
-
-############################################################################
-
diff --git a/tests/test-agent b/tests/test-agent
index b1420098..4782bbae 100755
--- a/tests/test-agent
+++ b/tests/test-agent
@@ -20,11 +20,7 @@ def output(text, prefix="> ", width=78):
)
print(out)
-p = AgentClient(
- pulsar_host="pulsar://pulsar:6650",
- input_queue = "non-persistent://tg/request/agent:0000",
- output_queue = "non-persistent://tg/response/agent:0000",
-)
+p = AgentClient(pulsar_host="pulsar://localhost:6650")
q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese."
diff --git a/tests/test-config b/tests/test-config
deleted file mode 100644
index 63f77b6b..00000000
--- a/tests/test-config
+++ /dev/null
@@ -1,2 +0,0 @@
-#!/usr/bin/env python3
-
diff --git a/tests/test-doc-rag b/tests/test-doc-rag
index b7382bf5..718157b6 100755
--- a/tests/test-doc-rag
+++ b/tests/test-doc-rag
@@ -3,12 +3,7 @@
import pulsar
from trustgraph.clients.document_rag_client import DocumentRagClient
-rag = DocumentRagClient(
- pulsar_host="pulsar://localhost:6650",
- subscriber="test1",
- input_queue = "non-persistent://tg/request/document-rag:default",
- output_queue = "non-persistent://tg/response/document-rag:default",
-)
+rag = DocumentRagClient(pulsar_host="pulsar://localhost:6650")
query="""
What was the cause of the space shuttle disaster?"""
diff --git a/tests/test-embeddings b/tests/test-embeddings
index 5fcd31e6..3855fcf0 100755
--- a/tests/test-embeddings
+++ b/tests/test-embeddings
@@ -3,12 +3,7 @@
import pulsar
from trustgraph.clients.embeddings_client import EmbeddingsClient
-embed = EmbeddingsClient(
- pulsar_host="pulsar://pulsar:6650",
- input_queue="non-persistent://tg/request/embeddings:default",
- output_queue="non-persistent://tg/response/embeddings:default",
- subscriber="test1",
-)
+embed = EmbeddingsClient(pulsar_host="pulsar://localhost:6650")
prompt="Write a funny limerick about a llama"
@@ -16,3 +11,5 @@ resp = embed.request(prompt)
print(resp)
+
+
diff --git a/tests/test-flow b/tests/test-flow
deleted file mode 100755
index 87a349af..00000000
--- a/tests/test-flow
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-
-url = "http://localhost:8088/"
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "list-classes",
- }
-)
-
-print(resp)
-print(resp.text)
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "get-class",
- "class-name": "default",
- }
-)
-
-print(resp)
-print(resp.text)
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "put-class",
- "class-name": "bunch",
- "class-definition": "{}",
- }
-)
-
-print(resp)
-print(resp.text)
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "get-class",
- "class-name": "bunch",
- }
-)
-
-print(resp)
-print(resp.text)
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "list-classes",
- }
-)
-
-print(resp)
-print(resp.text)
-
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "delete-class",
- "class-name": "bunch",
- }
-)
-
-print(resp)
-print(resp.text)
-
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "list-classes",
- }
-)
-
-print(resp)
-print(resp.text)
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "list-flows",
- }
-)
-
-print(resp)
-print(resp.text)
diff --git a/tests/test-flow-get-class b/tests/test-flow-get-class
deleted file mode 100755
index 20707b51..00000000
--- a/tests/test-flow-get-class
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-
-url = "http://localhost:8088/"
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "get-class",
- "class-name": "default",
- }
-)
-
-resp = resp.json()
-
-print(resp["class-definition"])
-
-
diff --git a/tests/test-flow-put-class b/tests/test-flow-put-class
deleted file mode 100755
index 8fd4d9f2..00000000
--- a/tests/test-flow-put-class
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-import json
-
-url = "http://localhost:8088/"
-
-defn = {"class": {"de-query:{class}": {"request": "non-persistent://tg/request/document-embeddings:{class}", "response": "non-persistent://tg/response/document-embeddings:{class}"}, "document-rag:{class}": {"document-embeddings-request": "non-persistent://tg/request/document-embeddings:{class}", "document-embeddings-response": "non-persistent://tg/response/document-embeddings:{class}", "embeddings-request": "non-persistent://tg/request/embeddings:{class}", "embeddings-response": "non-persistent://tg/response/embeddings:{class}", "prompt-request": "non-persistent://tg/request/prompt-rag:{class}", "prompt-response": "non-persistent://tg/response/prompt-rag:{class}", "request": "non-persistent://tg/request/document-rag:{class}", "response": "non-persistent://tg/response/document-rag:{class}"}, "embeddings:{class}": {"request": "non-persistent://tg/request/embeddings:{class}", "response": "non-persistent://tg/response/embeddings:{class}"}, "ge-query:{class}": {"request": "non-persistent://tg/request/graph-embeddings:{class}", "response": "non-persistent://tg/response/graph-embeddings:{class}"}, "graph-rag:{class}": {"embeddings-request": "non-persistent://tg/request/embeddings:{class}", "embeddings-response": "non-persistent://tg/response/embeddings:{class}", "graph-embeddings-request": "non-persistent://tg/request/graph-embeddings:{class}", "graph-embeddings-response": "non-persistent://tg/response/graph-embeddings:{class}", "prompt-request": "non-persistent://tg/request/prompt-rag:{class}", "prompt-response": "non-persistent://tg/response/prompt-rag:{class}", "request": "non-persistent://tg/request/graph-rag:{class}", "response": "non-persistent://tg/response/graph-rag:{class}", "triples-request": "non-persistent://tg/request/triples:{class}", "triples-response": "non-persistent://tg/response/triples:{class}"}, "metering-rag:{class}": {"input": "non-persistent://tg/response/text-completion-rag:{class}"}, "metering:{class}": {"input": 
"non-persistent://tg/response/text-completion:{class}"}, "prompt-rag:{class}": {"request": "non-persistent://tg/request/prompt-rag:{class}", "response": "non-persistent://tg/response/prompt-rag:{class}", "text-completion-request": "non-persistent://tg/request/text-completion-rag:{class}", "text-completion-response": "non-persistent://tg/response/text-completion-rag:{class}"}, "prompt:{class}": {"request": "non-persistent://tg/request/prompt:{class}", "response": "non-persistent://tg/response/prompt:{class}", "text-completion-request": "non-persistent://tg/request/text-completion:{class}", "text-completion-response": "non-persistent://tg/response/text-completion:{class}"}, "text-completion-rag:{class}": {"request": "non-persistent://tg/request/text-completion-rag:{class}", "response": "non-persistent://tg/response/text-completion-rag:{class}"}, "text-completion:{class}": {"request": "non-persistent://tg/request/text-completion:{class}", "response": "non-persistent://tg/response/text-completion:{class}"}, "triples-query:{class}": {"request": "non-persistent://tg/request/triples:{class}", "response": "non-persistent://tg/response/triples:{class}"}}, "description": "Default flow class, supports GraphRAG and document RAG", "flow": {"agent-manager:{id}": {"graph-rag-request": "non-persistent://tg/request/graph-rag:{class}", "graph-rag-response": "non-persistent://tg/response/graph-rag:{class}", "next": "non-persistent://tg/request/agent:{id}", "prompt-request": "non-persistent://tg/request/prompt:{class}", "prompt-response": "non-persistent://tg/response/prompt:{class}", "request": "non-persistent://tg/request/agent:{id}", "response": "non-persistent://tg/response/agent:{id}", "text-completion-request": "non-persistent://tg/request/text-completion:{class}", "text-completion-response": "non-persistent://tg/response/text-completion:{class}"}, "chunker:{id}": {"input": "persistent://tg/flow/text-document-load:{id}", "output": "persistent://tg/flow/chunk-load:{id}"}, 
"de-write:{id}": {"input": "persistent://tg/flow/document-embeddings-store:{id}"}, "document-embeddings:{id}": {"embeddings-request": "non-persistent://tg/request/embeddings:{class}", "embeddings-response": "non-persistent://tg/response/embeddings:{class}", "input": "persistent://tg/flow/chunk-load:{id}", "output": "persistent://tg/flow/document-embeddings-store:{id}"}, "ge-write:{id}": {"input": "persistent://tg/flow/graph-embeddings-store:{id}"}, "graph-embeddings:{id}": {"embeddings-request": "non-persistent://tg/request/embeddings:{class}", "embeddings-response": "non-persistent://tg/response/embeddings:{class}", "input": "persistent://tg/flow/entity-contexts-load:{id}", "output": "persistent://tg/flow/graph-embeddings-store:{id}"}, "kg-extract-definitions:{id}": {"entity-contexts": "persistent://tg/flow/entity-contexts-load:{id}", "input": "persistent://tg/flow/chunk-load:{id}", "prompt-request": "non-persistent://tg/request/prompt:{class}", "prompt-response": "non-persistent://tg/response/prompt:{class}", "triples": "persistent://tg/flow/triples-store:{id}"}, "kg-extract-relationships:{id}": {"input": "persistent://tg/flow/chunk-load:{id}", "prompt-request": "non-persistent://tg/request/prompt:{class}", "prompt-response": "non-persistent://tg/response/prompt:{class}", "triples": "persistent://tg/flow/triples-store:{id}"}, "pdf-decoder:{id}": {"input": "persistent://tg/flow/document-load:{id}", "output": "persistent://tg/flow/text-document-load:{id}"}, "triples-write:{id}": {"input": "persistent://tg/flow/triples-store:{id}"}}, "tags": ["document-rag", "graph-rag", "knowledge-extraction"]}
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "put-class",
- "class-name": "default",
- "class-definition": json.dumps(defn),
- }
-)
-
-resp = resp.json()
-
-print(resp)
-
diff --git a/tests/test-flow-start-flow b/tests/test-flow-start-flow
deleted file mode 100755
index 15a3c0cc..00000000
--- a/tests/test-flow-start-flow
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-import json
-
-url = "http://localhost:8088/"
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "start-flow",
- "flow-id": "0003",
- "class-name": "default",
- }
-)
-
-print(resp)
-print(resp.text)
-resp = resp.json()
-
-
-print(resp)
-
diff --git a/tests/test-flow-stop-flow b/tests/test-flow-stop-flow
deleted file mode 100755
index 62ea1aa9..00000000
--- a/tests/test-flow-stop-flow
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/usr/bin/env python3
-
-import requests
-import json
-
-url = "http://localhost:8088/"
-
-resp = requests.post(
- f"{url}/api/v1/flow",
- json={
- "operation": "stop-flow",
- "flow-id": "0003",
- }
-)
-
-print(resp)
-print(resp.text)
-resp = resp.json()
-
-
-print(resp)
-
diff --git a/tests/test-graph-rag b/tests/test-graph-rag
index b62f890c..036f73f4 100755
--- a/tests/test-graph-rag
+++ b/tests/test-graph-rag
@@ -3,18 +3,11 @@
import pulsar
from trustgraph.clients.graph_rag_client import GraphRagClient
-rag = GraphRagClient(
- pulsar_host="pulsar://localhost:6650",
- subscriber="test1",
- input_queue = "non-persistent://tg/request/graph-rag:default",
- output_queue = "non-persistent://tg/response/graph-rag:default",
-)
+rag = GraphRagClient(pulsar_host="pulsar://localhost:6650")
-#query="""
-#This knowledge graph describes the Space Shuttle disaster.
-#Present 20 facts which are present in the knowledge graph."""
-
-query = "How many cats does Mark have?"
+query="""
+This knowledge graph describes the Space Shuttle disaster.
+Present 20 facts which are present in the knowledge graph."""
resp = rag.request(query)
diff --git a/tests/test-llm b/tests/test-llm
index aaae30a6..4e86387a 100755
--- a/tests/test-llm
+++ b/tests/test-llm
@@ -3,17 +3,14 @@
import pulsar
from trustgraph.clients.llm_client import LlmClient
-llm = LlmClient(
- pulsar_host="pulsar://pulsar:6650",
- input_queue="non-persistent://tg/request/text-completion:default",
- output_queue="non-persistent://tg/response/text-completion:default",
- subscriber="test1",
-)
+llm = LlmClient(pulsar_host="pulsar://localhost:6650")
system = "You are a lovely assistant."
-prompt="what is 2 + 2 == 5"
+prompt="Write a funny limerick about a llama"
resp = llm.request(system, prompt)
print(resp)
+
+
diff --git a/tests/test-load-pdf b/tests/test-load-pdf
deleted file mode 100755
index 838a57ce..00000000
--- a/tests/test-load-pdf
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env python3
-
-import pulsar
-from pulsar.schema import JsonSchema
-import base64
-
-from trustgraph.schema import Document, Metadata
-
-client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
-
-prod = client.create_producer(
- topic="persistent://tg/flow/document-load:0000",
- schema=JsonSchema(Document),
- chunking_enabled=True,
-)
-
-path = "../sources/Challenger-Report-Vol1.pdf"
-
-with open(path, "rb") as f:
- blob = base64.b64encode(f.read()).decode("utf-8")
-
-message = Document(
- metadata = Metadata(
- id = "00001",
- metadata = [],
- user="trustgraph",
- collection="default",
- ),
- data=blob
-)
-
-prod.send(message)
-
-prod.close()
-client.close()
-
diff --git a/tests/test-load-text b/tests/test-load-text
deleted file mode 100755
index 754458aa..00000000
--- a/tests/test-load-text
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/usr/bin/env python3
-
-import pulsar
-from pulsar.schema import JsonSchema
-import base64
-
-from trustgraph.schema import TextDocument, Metadata
-
-client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
-
-prod = client.create_producer(
- topic="persistent://tg/flow/text-document-load:0000",
- schema=JsonSchema(TextDocument),
- chunking_enabled=True,
-)
-
-path = "docs/README.cats"
-
-with open(path, "r") as f:
-# blob = base64.b64encode(f.read()).decode("utf-8")
- blob = f.read()
-
-message = TextDocument(
- metadata = Metadata(
- id = "00001",
- metadata = [],
- user="trustgraph",
- collection="default",
- ),
- text=blob
-)
-
-prod.send(message)
-
-prod.close()
-client.close()
-
diff --git a/tests/test-prompt-extraction b/tests/test-prompt-extraction
index 20aaaf50..c73bd2e2 100755
--- a/tests/test-prompt-extraction
+++ b/tests/test-prompt-extraction
@@ -3,12 +3,7 @@
import json
from trustgraph.clients.prompt_client import PromptClient
-p = PromptClient(
- pulsar_host="pulsar://localhost:6650",
- input_queue="non-persistent://tg/request/prompt:default",
- output_queue="non-persistent://tg/response/prompt:default",
- subscriber="test1",
-)
+p = PromptClient(pulsar_host="pulsar://localhost:6650")
chunk="""
The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including:
@@ -36,8 +31,8 @@ The Space Shuttle's last mission was in 2011.
q = "Tell me some facts in the knowledge graph"
resp = p.request(
- id="extract-definitions",
- variables = {
+ id="extract-definition",
+ terms = {
"text": chunk,
}
)
@@ -45,7 +40,7 @@ resp = p.request(
print(resp)
for fact in resp:
- print(fact["entity"], "::")
+ print(fact["term"], "::")
print(fact["definition"])
print()
diff --git a/tests/test-prompt-question b/tests/test-prompt-question
index 78ba72aa..50660965 100755
--- a/tests/test-prompt-question
+++ b/tests/test-prompt-question
@@ -3,18 +3,13 @@
import pulsar
from trustgraph.clients.prompt_client import PromptClient
-p = PromptClient(
- pulsar_host="pulsar://localhost:6650",
- input_queue="non-persistent://tg/request/prompt:default",
- output_queue="non-persistent://tg/response/prompt:default",
- subscriber="test1",
-)
+p = PromptClient(pulsar_host="pulsar://localhost:6650")
question = """What is the square root of 16?"""
resp = p.request(
id="question",
- variables = {
+ terms = {
"question": question
}
)
diff --git a/tests/test-triples b/tests/test-triples
index e804d844..05263d0d 100755
--- a/tests/test-triples
+++ b/tests/test-triples
@@ -3,9 +3,7 @@
import pulsar
from trustgraph.clients.triples_query_client import TriplesQueryClient
-tq = TriplesQueryClient(
- pulsar_host="pulsar://localhost:6650",
-)
+tq = TriplesQueryClient(pulsar_host="pulsar://localhost:6650")
e = "http://trustgraph.ai/e/shuttle"
diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py
index fbff8d34..ddc6b2c3 100644
--- a/trustgraph-base/trustgraph/api/api.py
+++ b/trustgraph-base/trustgraph/api/api.py
@@ -562,233 +562,3 @@ class Api:
except:
raise ProtocolException(f"Response not formatted correctly")
- def flow_list_classes(self):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "list-classes",
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- try:
- return object["class-names"]
- except:
- raise ProtocolException(f"Response not formatted correctly")
-
- def flow_get_class(self, class_name):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "get-class",
- "class-name": class_name,
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- try:
- return json.loads(object["class-definition"])
- except Exception as e:
- print(e)
- raise ProtocolException(f"Response not formatted correctly")
-
- def flow_put_class(self, class_name, definition):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "put-class",
- "class-name": class_name,
- "class-definition": json.dumps(definition),
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- return
-
- def flow_delete_class(self, class_name):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "delete-class",
- "class-name": class_name,
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- return
-
- def flow_list(self):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "list-flows",
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- try:
- return object["flow-ids"]
- except:
- raise ProtocolException(f"Response not formatted correctly")
-
- def flow_get(self, id):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "get-flow",
- "flow-id": id,
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- try:
- return json.loads(object["flow"])
- except:
- raise ProtocolException(f"Response not formatted correctly")
-
- def flow_start(self, class_name, id, description):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "start-flow",
- "flow-id": id,
- "class-name": class_name,
- "description": description,
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- return
-
- def flow_stop(self, id):
-
- # The input consists of system and prompt strings
- input = {
- "operation": "stop-flow",
- "flow-id": id,
- }
-
- url = f"{self.url}flow"
-
- # Invoke the API, input is passed as JSON
- resp = requests.post(url, json=input)
-
- # Should be a 200 status code
- if resp.status_code != 200:
- raise ProtocolException(f"Status code {resp.status_code}")
-
- try:
- # Parse the response as JSON
- object = resp.json()
- except:
- raise ProtocolException(f"Expected JSON response")
-
- self.check_error(object)
-
- return
-
diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py
index 2accbb21..3a58d51e 100644
--- a/trustgraph-base/trustgraph/base/__init__.py
+++ b/trustgraph-base/trustgraph/base/__init__.py
@@ -1,31 +1,8 @@
-from . pubsub import PulsarClient
-from . async_processor import AsyncProcessor
+from . base_processor import BaseProcessor
from . consumer import Consumer
from . producer import Producer
+from . consumer_producer import ConsumerProducer
from . publisher import Publisher
from . subscriber import Subscriber
-from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
-from . flow_processor import FlowProcessor
-from . consumer_spec import ConsumerSpec
-from . setting_spec import SettingSpec
-from . producer_spec import ProducerSpec
-from . subscriber_spec import SubscriberSpec
-from . request_response_spec import RequestResponseSpec
-from . llm_service import LlmService, LlmResult
-from . embeddings_service import EmbeddingsService
-from . embeddings_client import EmbeddingsClientSpec
-from . text_completion_client import TextCompletionClientSpec
-from . prompt_client import PromptClientSpec
-from . triples_store_service import TriplesStoreService
-from . graph_embeddings_store_service import GraphEmbeddingsStoreService
-from . document_embeddings_store_service import DocumentEmbeddingsStoreService
-from . triples_query_service import TriplesQueryService
-from . graph_embeddings_query_service import GraphEmbeddingsQueryService
-from . document_embeddings_query_service import DocumentEmbeddingsQueryService
-from . graph_embeddings_client import GraphEmbeddingsClientSpec
-from . triples_client import TriplesClientSpec
-from . document_embeddings_client import DocumentEmbeddingsClientSpec
-from . agent_service import AgentService
-from . graph_rag_client import GraphRagClientSpec
diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py
deleted file mode 100644
index 76e1adff..00000000
--- a/trustgraph-base/trustgraph/base/agent_client.py
+++ /dev/null
@@ -1,39 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import AgentRequest, AgentResponse
-from .. knowledge import Uri, Literal
-
-class AgentClient(RequestResponse):
- async def request(self, recipient, question, plan=None, state=None,
- history=[], timeout=300):
-
- resp = await self.request(
- AgentRequest(
- question = question,
- plan = plan,
- state = state,
- history = history,
- ),
- recipient=recipient,
- timeout=timeout,
- )
-
- print(resp, flush=True)
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- return resp
-
-class GraphEmbeddingsClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(GraphEmbeddingsClientSpec, self).__init__(
- request_name = request_name,
- request_schema = GraphEmbeddingsRequest,
- response_name = response_name,
- response_schema = GraphEmbeddingsResponse,
- impl = GraphEmbeddingsClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/agent_service.py b/trustgraph-base/trustgraph/base/agent_service.py
deleted file mode 100644
index 0dbe728e..00000000
--- a/trustgraph-base/trustgraph/base/agent_service.py
+++ /dev/null
@@ -1,100 +0,0 @@
-
-"""
-Agent manager service completion base class
-"""
-
-import time
-from prometheus_client import Histogram
-
-from .. schema import AgentRequest, AgentResponse, Error
-from .. exceptions import TooManyRequests
-from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
-
-default_ident = "agent-manager"
-
-class AgentService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(AgentService, self).__init__(**params | { "id": id })
-
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = AgentRequest,
- handler = self.on_request
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "next",
- schema = AgentRequest
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = AgentResponse
- )
- )
-
- async def on_request(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- # Sender-produced ID
- id = msg.properties()["id"]
-
- async def respond(resp):
-
- await flow("response").send(
- resp,
- properties={"id": id}
- )
-
- async def next(resp):
-
- await flow("next").send(
- resp,
- properties={"id": id}
- )
-
- await self.agent_request(
- request = request, respond = respond, next = next,
- flow = flow
- )
-
- except TooManyRequests as e:
- raise e
-
- except Exception as e:
-
- # Apart from rate limits, treat all exceptions as unrecoverable
- print(f"on_request Exception: {e}")
-
- print("Send error response...", flush=True)
-
- await flow.producer["response"].send(
- AgentResponse(
- error=Error(
- type = "agent-error",
- message = str(e),
- ),
- thought = None,
- observation = None,
- answer = None,
- ),
- properties={"id": id}
- )
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py
deleted file mode 100644
index bdf9a0bb..00000000
--- a/trustgraph-base/trustgraph/base/async_processor.py
+++ /dev/null
@@ -1,257 +0,0 @@
-
-# Base class for processors. Implements:
-# - Pulsar client, subscribe and consume basic
-# - the async startup logic
-# - Initialising metrics
-
-import asyncio
-import argparse
-import _pulsar
-import time
-import uuid
-from prometheus_client import start_http_server, Info
-
-from .. schema import ConfigPush, config_push_queue
-from .. log_level import LogLevel
-from .. exceptions import TooManyRequests
-from . pubsub import PulsarClient
-from . producer import Producer
-from . consumer import Consumer
-from . metrics import ProcessorMetrics, ConsumerMetrics
-
-default_config_queue = config_push_queue
-
-# Async processor
-class AsyncProcessor:
-
- def __init__(self, **params):
-
- # Store the identity
- self.id = params.get("id")
-
- # Register a pulsar client
- self.pulsar_client_object = PulsarClient(**params)
-
- # Initialise metrics, records the parameters
- ProcessorMetrics(processor = self.id).info({
- k: str(params[k])
- for k in params
- if k != "id"
- })
-
- # The processor runs all activity in a taskgroup, it's mandatory
- # that this is provded
- self.taskgroup = params.get("taskgroup")
- if self.taskgroup is None:
- raise RuntimeError("Essential taskgroup missing")
-
- # Get the configuration topic
- self.config_push_queue = params.get(
- "config_push_queue", default_config_queue
- )
-
- # This records registered configuration handlers
- self.config_handlers = []
-
- # Create a random ID for this subscription to the configuration
- # service
- config_subscriber_id = str(uuid.uuid4())
-
- config_consumer_metrics = ConsumerMetrics(
- processor = self.id, flow = None, name = "config",
- )
-
- # Subscribe to config queue
- self.config_sub_task = Consumer(
-
- taskgroup = self.taskgroup,
- client = self.pulsar_client,
- subscriber = config_subscriber_id,
- flow = None,
-
- topic = self.config_push_queue,
- schema = ConfigPush,
-
- handler = self.on_config_change,
-
- metrics = config_consumer_metrics,
-
- # This causes new subscriptions to view the entire history of
- # configuration
- start_of_messages = True
- )
-
- self.running = True
-
- # This is called to start dynamic behaviour. An over-ride point for
- # extra functionality
- async def start(self):
- await self.config_sub_task.start()
-
- # This is called to stop all threads. An over-ride point for extra
- # functionality
- def stop(self):
- self.pulsar_client.close()
- self.running = False
-
- # Returns the pulsar host
- @property
- def pulsar_host(self): return self.pulsar_client_object.pulsar_host
-
- # Returns the pulsar client
- @property
- def pulsar_client(self): return self.pulsar_client_object.client
-
- # Register a new event handler for configuration change
- def register_config_handler(self, handler):
- self.config_handlers.append(handler)
-
- # Called when a new configuration message push occurs
- async def on_config_change(self, message, consumer, flow):
-
- # Get configuration data and version number
- config = message.value().config
- version = message.value().version
-
- # Invoke message handlers
- print("Config change event", config, version, flush=True)
- for ch in self.config_handlers:
- await ch(config, version)
-
- # This is the 'main' body of the handler. It is a point to override
- # if needed. By default does nothing. Processors are implemented
- # by adding consumer/producer functionality so maybe nothing is needed
- # in the run() body
- async def run(self):
- while self.running:
- await asyncio.sleep(2)
-
- # Startup fabric. This runs in 'async' mode, creates a taskgroup and
- # runs the producer.
- @classmethod
- async def launch_async(cls, args):
-
- try:
-
- # Create a taskgroup. This seems complicated, when an exception
- # occurs, unhandled it looks like it cancels all threads in the
- # taskgroup. Needs the exception to be caught in the right
- # place.
- async with asyncio.TaskGroup() as tg:
-
-
- # Create a processor instance, and include the taskgroup
- # as a paramter. A processor identity ident is used as
- # - subscriber name
- # - an identifier for flow configuration
- p = cls(**args | { "taskgroup": tg })
-
- # Start the processor
- await p.start()
-
- # Run the processor
- task = tg.create_task(p.run())
-
- # The taskgroup causes everything to wait until
- # all threads have stopped
-
- # This is here to output a debug message, shouldn't be needed.
- except Exception as e:
- print("Exception, closing taskgroup", flush=True)
- raise e
-
- # Startup fabric. launch calls launch_async in async mode.
- @classmethod
- def launch(cls, ident, doc):
-
- # Start assembling CLI arguments
- parser = argparse.ArgumentParser(
- prog=ident,
- description=doc
- )
-
- parser.add_argument(
- '--id',
- default=ident,
- help=f'Configuration identity (default: {ident})',
- )
-
- # Invoke the class-specific add_args, which manages adding all the
- # command-line arguments
- cls.add_args(parser)
-
- # Parse arguments
- args = parser.parse_args()
- args = vars(args)
-
- # Debug
- print(args, flush=True)
-
- # Start the Prometheus metrics service if needed
- if args["metrics"]:
- start_http_server(args["metrics_port"])
-
- # Loop forever, exception handler
- while True:
-
- print("Starting...", flush=True)
-
- try:
-
- # Launch the processor in an asyncio handler
- asyncio.run(cls.launch_async(
- args
- ))
-
- except KeyboardInterrupt:
- print("Keyboard interrupt.", flush=True)
- return
-
- except _pulsar.Interrupted:
- print("Pulsar Interrupted.", flush=True)
- return
-
- # Exceptions from a taskgroup come in as an exception group
- except ExceptionGroup as e:
-
- print("Exception group:", flush=True)
-
- for se in e.exceptions:
- print(" Type:", type(se), flush=True)
- print(f" Exception: {se}", flush=True)
-
- except Exception as e:
- print("Type:", type(e), flush=True)
- print("Exception:", e, flush=True)
-
- # Retry occurs here
- print("Will retry...", flush=True)
- time.sleep(4)
- print("Retrying...", flush=True)
-
- # The command-line arguments are built using a stack of add_args
- # invocations
- @staticmethod
- def add_args(parser):
-
- PulsarClient.add_args(parser)
-
- parser.add_argument(
- '--config-queue',
- default=default_config_queue,
- help=f'Config push queue {default_config_queue}',
- )
-
- parser.add_argument(
- '--metrics',
- action=argparse.BooleanOptionalAction,
- default=True,
- help=f'Metrics enabled (default: true)',
- )
-
- parser.add_argument(
- '-P', '--metrics-port',
- type=int,
- default=8000,
- help=f'Pulsar host (default: 8000)',
- )
diff --git a/trustgraph-base/trustgraph/base/base_processor.py b/trustgraph-base/trustgraph/base/base_processor.py
new file mode 100644
index 00000000..05cdb940
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/base_processor.py
@@ -0,0 +1,210 @@
+
+import asyncio
+import os
+import argparse
+import pulsar
+from pulsar.schema import JsonSchema
+import _pulsar
+import time
+import uuid
+from prometheus_client import start_http_server, Info
+
+from .. schema import ConfigPush, config_push_queue
+from .. log_level import LogLevel
+
+default_config_queue = config_push_queue
+config_subscriber_id = str(uuid.uuid4())
+
+class BaseProcessor:
+
+ default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
+ default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
+
+ def __init__(self, **params):
+
+ self.client = None
+
+ if not hasattr(__class__, "params_metric"):
+ __class__.params_metric = Info(
+ 'params', 'Parameters configuration'
+ )
+
+ # FIXME: Maybe outputs information it should not
+ __class__.params_metric.info({
+ k: str(params[k])
+ for k in params
+ })
+
+ pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
+ pulsar_listener = params.get("pulsar_listener", None)
+ pulsar_api_key = params.get("pulsar_api_key", None)
+ log_level = params.get("log_level", LogLevel.INFO)
+
+ self.config_push_queue = params.get(
+ "config_push_queue",
+ default_config_queue
+ )
+
+ self.pulsar_host = pulsar_host
+ self.pulsar_api_key = pulsar_api_key
+
+ if pulsar_api_key:
+ auth = pulsar.AuthenticationToken(pulsar_api_key)
+ self.client = pulsar.Client(
+ pulsar_host,
+ authentication=auth,
+ logger=pulsar.ConsoleLogger(log_level.to_pulsar())
+ )
+ else:
+ self.client = pulsar.Client(
+ pulsar_host,
+ listener_name=pulsar_listener,
+ logger=pulsar.ConsoleLogger(log_level.to_pulsar())
+ )
+
+ self.pulsar_listener = pulsar_listener
+
+ self.config_subscriber = self.client.subscribe(
+ self.config_push_queue, config_subscriber_id,
+ consumer_type=pulsar.ConsumerType.Shared,
+ initial_position=pulsar.InitialPosition.Earliest,
+ schema=JsonSchema(ConfigPush),
+ )
+
+ def __del__(self):
+
+ if hasattr(self, "client"):
+ if self.client:
+ self.client.close()
+
+ @staticmethod
+ def add_args(parser):
+
+ parser.add_argument(
+ '-p', '--pulsar-host',
+ default=__class__.default_pulsar_host,
+ help=f'Pulsar host (default: {__class__.default_pulsar_host})',
+ )
+
+ parser.add_argument(
+ '--pulsar-api-key',
+ default=__class__.default_pulsar_api_key,
+ help=f'Pulsar API key',
+ )
+
+ parser.add_argument(
+ '--config-push-queue',
+ default=default_config_queue,
+ help=f'Config push queue {default_config_queue}',
+ )
+
+ parser.add_argument(
+ '--pulsar-listener',
+ help=f'Pulsar listener (default: none)',
+ )
+
+ parser.add_argument(
+ '-l', '--log-level',
+ type=LogLevel,
+ default=LogLevel.INFO,
+ choices=list(LogLevel),
+ help=f'Output queue (default: info)'
+ )
+
+ parser.add_argument(
+ '--metrics',
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help=f'Metrics enabled (default: true)',
+ )
+
+ parser.add_argument(
+ '-P', '--metrics-port',
+ type=int,
+ default=8000,
+ help=f'Pulsar host (default: 8000)',
+ )
+
+ async def start(self):
+ pass
+
+ async def run_config_queue(self):
+
+ if self.module == "config.service":
+ print("I am config-svc, not looking at config queue", flush=True)
+ return
+
+ print("Config thread running", flush=True)
+
+ while True:
+
+ try:
+ msg = await asyncio.to_thread(
+ self.config_subscriber.receive, timeout_millis=2000
+ )
+ except pulsar.Timeout:
+ continue
+
+ v = msg.value()
+ print("Got config version", v.version, flush=True)
+
+ await self.on_config(v.version, v.config)
+
+ async def on_config(self, version, config):
+ pass
+
+ async def run(self):
+ raise RuntimeError("Something should have implemented the run method")
+
+ @classmethod
+ async def launch_async(cls, args, prog):
+ p = cls(**args)
+ p.module = prog
+ await p.start()
+
+ task1 = asyncio.create_task(p.run_config_queue())
+ task2 = asyncio.create_task(p.run())
+
+ await asyncio.gather(task1, task2)
+
+ @classmethod
+ def launch(cls, prog, doc):
+
+ parser = argparse.ArgumentParser(
+ prog=prog,
+ description=doc
+ )
+
+ cls.add_args(parser)
+
+ args = parser.parse_args()
+ args = vars(args)
+
+ print(args)
+
+ if args["metrics"]:
+ start_http_server(args["metrics_port"])
+
+ while True:
+
+ try:
+
+ asyncio.run(cls.launch_async(args, prog))
+
+ except KeyboardInterrupt:
+ print("Keyboard interrupt.")
+ return
+
+ except _pulsar.Interrupted:
+ print("Pulsar Interrupted.")
+ return
+
+ except Exception as e:
+
+ print(type(e))
+
+ print("Exception:", e, flush=True)
+ print("Will retry...", flush=True)
+
+ time.sleep(4)
+
diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py
index 8f262b83..fdbe5531 100644
--- a/trustgraph-base/trustgraph/base/consumer.py
+++ b/trustgraph-base/trustgraph/base/consumer.py
@@ -1,136 +1,93 @@
+import asyncio
from pulsar.schema import JsonSchema
import pulsar
-import _pulsar
-import asyncio
+from prometheus_client import Histogram, Info, Counter, Enum
import time
+from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests
-class Consumer:
+default_rate_limit_retry = 10
+default_rate_limit_timeout = 7200
- def __init__(
- self, taskgroup, flow, client, topic, subscriber, schema,
- handler,
- metrics = None,
- start_of_messages=False,
- rate_limit_retry_time = 10, rate_limit_timeout = 7200,
- reconnect_time = 5,
- ):
+class Consumer(BaseProcessor):
- self.taskgroup = taskgroup
- self.flow = flow
- self.client = client
- self.topic = topic
- self.subscriber = subscriber
- self.schema = schema
- self.handler = handler
+ def __init__(self, **params):
- self.rate_limit_retry_time = rate_limit_retry_time
- self.rate_limit_timeout = rate_limit_timeout
+ if not hasattr(__class__, "state_metric"):
+ __class__.state_metric = Enum(
+ 'processor_state', 'Processor state',
+ states=['starting', 'running', 'stopped']
+ )
+ __class__.state_metric.state('starting')
- self.reconnect_time = 5
+ __class__.state_metric.state('starting')
- self.start_of_messages = start_of_messages
+ super(Consumer, self).__init__(**params)
- self.running = True
- self.task = None
+ self.input_queue = params.get("input_queue")
+ self.subscriber = params.get("subscriber")
+ self.input_schema = params.get("input_schema")
- self.metrics = metrics
+ self.rate_limit_retry = params.get(
+ "rate_limit_retry", default_rate_limit_retry
+ )
+ self.rate_limit_timeout = params.get(
+ "rate_limit_timeout", default_rate_limit_timeout
+ )
- self.consumer = None
+ if self.input_schema == None:
+ raise RuntimeError("input_schema must be specified")
- def __del__(self):
- self.running = False
+ if not hasattr(__class__, "request_metric"):
+ __class__.request_metric = Histogram(
+ 'request_latency', 'Request latency (seconds)'
+ )
- if hasattr(self, "consumer"):
- if self.consumer:
- self.consumer.close()
+ if not hasattr(__class__, "pubsub_metric"):
+ __class__.pubsub_metric = Info(
+ 'pubsub', 'Pub/sub configuration'
+ )
- async def stop(self):
+ if not hasattr(__class__, "processing_metric"):
+ __class__.processing_metric = Counter(
+ 'processing_count', 'Processing count', ["status"]
+ )
- self.running = False
- await self.task
+ if not hasattr(__class__, "rate_limit_metric"):
+ __class__.rate_limit_metric = Counter(
+ 'rate_limit_count', 'Rate limit event count',
+ )
- async def start(self):
+ __class__.pubsub_metric.info({
+ "input_queue": self.input_queue,
+ "subscriber": self.subscriber,
+ "input_schema": self.input_schema.__name__,
+ "rate_limit_retry": str(self.rate_limit_retry),
+ "rate_limit_timeout": str(self.rate_limit_timeout),
+ })
- self.running = True
+ self.consumer = self.client.subscribe(
+ self.input_queue, self.subscriber,
+ consumer_type=pulsar.ConsumerType.Shared,
+ schema=JsonSchema(self.input_schema),
+ )
- # Puts it in the stopped state, the run thread should set running
- if self.metrics:
- self.metrics.state("stopped")
-
- self.task = self.taskgroup.create_task(self.run())
+ print("Initialised consumer.", flush=True)
async def run(self):
- while self.running:
+ __class__.state_metric.state('running')
- if self.metrics:
- self.metrics.state("stopped")
+ while True:
- try:
-
- print(self.topic, "subscribing...", flush=True)
-
- if self.start_of_messages:
- pos = pulsar.InitialPosition.Earliest
- else:
- pos = pulsar.InitialPosition.Latest
-
- self.consumer = await asyncio.to_thread(
- self.client.subscribe,
- topic = self.topic,
- subscription_name = self.subscriber,
- schema = JsonSchema(self.schema),
- initial_position = pos,
- consumer_type = pulsar.ConsumerType.Shared,
- )
-
- except Exception as e:
-
- print("consumer subs Exception:", e, flush=True)
- await asyncio.sleep(self.reconnect_time)
- continue
-
- print(self.topic, "subscribed", flush=True)
-
- if self.metrics:
- self.metrics.state("running")
-
- try:
-
- await self.consume()
-
- if self.metrics:
- self.metrics.state("stopped")
-
- except Exception as e:
-
- print("consumer loop exception:", e, flush=True)
- self.consumer.close()
- self.consumer = None
- await asyncio.sleep(self.reconnect_time)
- continue
-
- async def consume(self):
-
- while self.running:
-
- try:
- msg = await asyncio.to_thread(
- self.consumer.receive,
- timeout_millis=2000
- )
- except _pulsar.Timeout:
- continue
- except Exception as e:
- raise e
+ msg = await asyncio.to_thread(self.consumer.receive)
expiry = time.time() + self.rate_limit_timeout
# This loop is for retry on rate-limit / resource limits
- while self.running:
+ while True:
if time.time() > expiry:
@@ -140,31 +97,20 @@ class Consumer:
# be retried
self.consumer.negative_acknowledge(msg)
- if self.metrics:
- self.metrics.process("error")
+ __class__.processing_metric.labels(status="error").inc()
# Break out of retry loop, processes next message
break
try:
- print("Handle...", flush=True)
-
- if self.metrics:
-
- with self.metrics.record_time():
- await self.handler(msg, self, self.flow)
-
- else:
- await self.handler(msg, self, self.flow)
-
- print("Handled.", flush=True)
+ with __class__.request_metric.time():
+ await self.handle(msg)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
- if self.metrics:
- self.metrics.process("success")
+ __class__.processing_metric.labels(status="success").inc()
# Break out of retry loop
break
@@ -173,25 +119,55 @@ class Consumer:
print("TooManyRequests: will retry...", flush=True)
- if self.metrics:
- self.metrics.rate_limit()
+ __class__.rate_limit_metric.inc()
# Sleep
- await asyncio.sleep(self.rate_limit_retry_time)
+                    await asyncio.sleep(self.rate_limit_retry)
# Contine from retry loop, just causes a reprocessing
continue
-
+
except Exception as e:
- print("consume exception:", e, flush=True)
+ print("Exception:", e, flush=True)
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
- if self.metrics:
- self.metrics.process("error")
+ __class__.processing_metric.labels(status="error").inc()
# Break out of retry loop, processes next message
break
+
+ @staticmethod
+ def add_args(parser, default_input_queue, default_subscriber):
+
+ BaseProcessor.add_args(parser)
+
+ parser.add_argument(
+ '-i', '--input-queue',
+ default=default_input_queue,
+ help=f'Input queue (default: {default_input_queue})'
+ )
+
+ parser.add_argument(
+ '-s', '--subscriber',
+ default=default_subscriber,
+ help=f'Queue subscriber name (default: {default_subscriber})'
+ )
+
+ parser.add_argument(
+ '--rate-limit-retry',
+ type=int,
+ default=default_rate_limit_retry,
+ help=f'Rate limit retry (default: {default_rate_limit_retry})'
+ )
+
+ parser.add_argument(
+ '--rate-limit-timeout',
+ type=int,
+ default=default_rate_limit_timeout,
+ help=f'Rate limit timeout (default: {default_rate_limit_timeout})'
+ )
+
diff --git a/trustgraph-base/trustgraph/base/consumer_producer.py b/trustgraph-base/trustgraph/base/consumer_producer.py
new file mode 100644
index 00000000..1006f9b5
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/consumer_producer.py
@@ -0,0 +1,62 @@
+
+from pulsar.schema import JsonSchema
+import pulsar
+from prometheus_client import Histogram, Info, Counter, Enum
+import time
+
+from . consumer import Consumer
+from .. exceptions import TooManyRequests
+
+class ConsumerProducer(Consumer):
+    """Consumer which also publishes processed output to a Pulsar queue."""
+    def __init__(self, **params):
+
+        super(ConsumerProducer, self).__init__(**params)
+
+        self.output_queue = params.get("output_queue")
+        self.output_schema = params.get("output_schema")
+
+        if self.output_schema is None:  # must precede __name__ use below
+            raise RuntimeError("output_schema must be specified")
+
+        if not hasattr(__class__, "output_metric"):
+            __class__.output_metric = Counter(
+                'output_count', 'Output items created'
+            )
+
+        __class__.pubsub_metric.info({
+            "input_queue": self.input_queue,
+            "output_queue": self.output_queue,
+            "subscriber": self.subscriber,
+            "input_schema": self.input_schema.__name__,
+            "output_schema": self.output_schema.__name__,
+            "rate_limit_retry": str(self.rate_limit_retry),
+            "rate_limit_timeout": str(self.rate_limit_timeout),
+        })
+
+        self.producer = self.client.create_producer(
+            topic=self.output_queue,
+            schema=JsonSchema(self.output_schema),
+            chunking_enabled=True,
+        )
+
+        print("Initialised consumer/producer.", flush=True)
+
+    async def send(self, msg, properties=None):  # None: no mutable default
+        self.producer.send(msg, properties or {})
+        __class__.output_metric.inc()
+
+    @staticmethod
+    def add_args(
+        parser, default_input_queue, default_subscriber,
+        default_output_queue,
+    ):
+
+        Consumer.add_args(parser, default_input_queue, default_subscriber)
+
+        parser.add_argument(
+            '-o', '--output-queue',
+            default=default_output_queue,
+            help=f'Output queue (default: {default_output_queue})'
+        )
+
diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py
deleted file mode 100644
index 21497dc5..00000000
--- a/trustgraph-base/trustgraph/base/consumer_spec.py
+++ /dev/null
@@ -1,36 +0,0 @@
-
-from . metrics import ConsumerMetrics
-from . consumer import Consumer
-from . spec import Spec
-
-class ConsumerSpec(Spec):
- def __init__(self, name, schema, handler):
- self.name = name
- self.schema = schema
- self.handler = handler
-
- def add(self, flow, processor, definition):
-
- consumer_metrics = ConsumerMetrics(
- processor = flow.id, flow = flow.name, name = self.name,
- )
-
- consumer = Consumer(
- taskgroup = processor.taskgroup,
- flow = flow,
- client = processor.pulsar_client,
- topic = definition[self.name],
- subscriber = processor.id + "--" + self.name,
- schema = self.schema,
- handler = self.handler,
- metrics = consumer_metrics,
- )
-
- # Consumer handle gets access to producers and other
- # metadata
- consumer.id = flow.id
- consumer.name = self.name
- consumer.flow = flow
-
- flow.consumer[self.name] = consumer
-
diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py
deleted file mode 100644
index 86370c52..00000000
--- a/trustgraph-base/trustgraph/base/document_embeddings_client.py
+++ /dev/null
@@ -1,38 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
-from .. knowledge import Uri, Literal
-
-class DocumentEmbeddingsClient(RequestResponse):
- async def query(self, vectors, limit=20, user="trustgraph",
- collection="default", timeout=30):
-
- resp = await self.request(
- DocumentEmbeddingsRequest(
- vectors = vectors,
- limit = limit,
- user = user,
- collection = collection
- ),
- timeout=timeout
- )
-
- print(resp, flush=True)
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- return resp.documents
-
-class DocumentEmbeddingsClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(DocumentEmbeddingsClientSpec, self).__init__(
- request_name = request_name,
- request_schema = DocumentEmbeddingsRequest,
- response_name = response_name,
- response_schema = DocumentEmbeddingsResponse,
- impl = DocumentEmbeddingsClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py
deleted file mode 100644
index 0dee7001..00000000
--- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py
+++ /dev/null
@@ -1,84 +0,0 @@
-
-"""
-Document embeddings query service. Input is vectors. Output is list of
-embeddings.
-"""
-
-from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
-from .. schema import Error, Value
-
-from . flow_processor import FlowProcessor
-from . consumer_spec import ConsumerSpec
-from . producer_spec import ProducerSpec
-
-default_ident = "ge-query"
-
-class DocumentEmbeddingsQueryService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(DocumentEmbeddingsQueryService, self).__init__(
- **params | { "id": id }
- )
-
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = DocumentEmbeddingsRequest,
- handler = self.on_message
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = DocumentEmbeddingsResponse,
- )
- )
-
- async def on_message(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- # Sender-produced ID
- id = msg.properties()["id"]
-
- print(f"Handling input {id}...", flush=True)
-
- docs = await self.query_document_embeddings(request)
-
- print("Send response...", flush=True)
- r = DocumentEmbeddingsResponse(documents=docs, error=None)
- await flow("response").send(r, properties={"id": id})
-
- print("Done.", flush=True)
-
- except Exception as e:
-
- print(f"Exception: {e}")
-
- print("Send error response...", flush=True)
-
- r = DocumentEmbeddingsResponse(
- error=Error(
- type = "document-embeddings-query-error",
- message = str(e),
- ),
- response=None,
- )
-
- await flow("response").send(r, properties={"id": id})
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
-def run():
-
- Processor.launch(default_ident, __doc__)
-
diff --git a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py b/trustgraph-base/trustgraph/base/document_embeddings_store_service.py
deleted file mode 100644
index fbf58869..00000000
--- a/trustgraph-base/trustgraph/base/document_embeddings_store_service.py
+++ /dev/null
@@ -1,50 +0,0 @@
-
-"""
-Document embeddings store base class
-"""
-
-from .. schema import DocumentEmbeddings
-from .. base import FlowProcessor, ConsumerSpec
-from .. exceptions import TooManyRequests
-
-default_ident = "document-embeddings-write"
-
-class DocumentEmbeddingsStoreService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(DocumentEmbeddingsStoreService, self).__init__(
- **params | { "id": id }
- )
-
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = DocumentEmbeddings,
- handler = self.on_message
- )
- )
-
- async def on_message(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- await self.store_document_embeddings(request)
-
- except TooManyRequests as e:
- raise e
-
- except Exception as e:
-
- print(f"Exception: {e}")
- raise e
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
diff --git a/trustgraph-base/trustgraph/base/embeddings_client.py b/trustgraph-base/trustgraph/base/embeddings_client.py
deleted file mode 100644
index ceb08eb2..00000000
--- a/trustgraph-base/trustgraph/base/embeddings_client.py
+++ /dev/null
@@ -1,31 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import EmbeddingsRequest, EmbeddingsResponse
-
-class EmbeddingsClient(RequestResponse):
- async def embed(self, text, timeout=30):
-
- resp = await self.request(
- EmbeddingsRequest(
- text = text
- ),
- timeout=timeout
- )
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- return resp.vectors
-
-class EmbeddingsClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(EmbeddingsClientSpec, self).__init__(
- request_name = request_name,
- request_schema = EmbeddingsRequest,
- response_name = response_name,
- response_schema = EmbeddingsResponse,
- impl = EmbeddingsClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py
deleted file mode 100644
index c6befdb7..00000000
--- a/trustgraph-base/trustgraph/base/embeddings_service.py
+++ /dev/null
@@ -1,90 +0,0 @@
-
-"""
-Embeddings resolution base class
-"""
-
-import time
-from prometheus_client import Histogram
-
-from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error
-from .. exceptions import TooManyRequests
-from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
-
-default_ident = "embeddings"
-
-class EmbeddingsService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(EmbeddingsService, self).__init__(**params | { "id": id })
-
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = EmbeddingsRequest,
- handler = self.on_request
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = EmbeddingsResponse
- )
- )
-
- async def on_request(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- # Sender-produced ID
-
- id = msg.properties()["id"]
-
- print("Handling request", id, "...", flush=True)
-
- vectors = await self.on_embeddings(request.text)
-
- await flow("response").send(
- EmbeddingsResponse(
- error = None,
- vectors = vectors,
- ),
- properties={"id": id}
- )
-
- print("Handled.", flush=True)
-
- except TooManyRequests as e:
- raise e
-
- except Exception as e:
-
- # Apart from rate limits, treat all exceptions as unrecoverable
-
- print(f"Exception: {e}", flush=True)
-
- print("Send error response...", flush=True)
-
- await flow.producer["response"].send(
- EmbeddingsResponse(
- error=Error(
- type = "embeddings-error",
- message = str(e),
- ),
- vectors=None,
- ),
- properties={"id": id}
- )
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
-
-
diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py
deleted file mode 100644
index 9cda34a0..00000000
--- a/trustgraph-base/trustgraph/base/flow.py
+++ /dev/null
@@ -1,32 +0,0 @@
-
-import asyncio
-
-class Flow:
- def __init__(self, id, flow, processor, defn):
-
- self.id = id
- self.name = flow
-
- self.producer = {}
-
- # Consumers and publishers. Is this a bit untidy?
- self.consumer = {}
-
- self.setting = {}
-
- for spec in processor.specifications:
- spec.add(self, processor, defn)
-
- async def start(self):
- for c in self.consumer.values():
- await c.start()
-
- async def stop(self):
- for c in self.consumer.values():
- await c.stop()
-
- def __call__(self, key):
- if key in self.producer: return self.producer[key]
- if key in self.consumer: return self.consumer[key]
- if key in self.setting: return self.setting[key].value
- return None
diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py
deleted file mode 100644
index e6460fe3..00000000
--- a/trustgraph-base/trustgraph/base/flow_processor.py
+++ /dev/null
@@ -1,115 +0,0 @@
-
-# Base class for processor with management of flows in & out which are managed
-# by configuration. This is probably all processor types, except for the
-# configuration service which can't manage itself.
-
-import json
-
-from pulsar.schema import JsonSchema
-
-from .. schema import Error
-from .. schema import config_request_queue, config_response_queue
-from .. schema import config_push_queue
-from .. log_level import LogLevel
-from . async_processor import AsyncProcessor
-from . flow import Flow
-
-# Parent class for configurable processors, configured with flows by
-# the config service
-class FlowProcessor(AsyncProcessor):
-
- def __init__(self, **params):
-
- # Initialise base class
- super(FlowProcessor, self).__init__(**params)
-
- # Register configuration handler
- self.register_config_handler(self.on_configure_flows)
-
- # Initialise flow information state
- self.flows = {}
-
- # These can be overriden by a derived class:
-
- # Array of specifications: ConsumerSpec, ProducerSpec, SettingSpec
- self.specifications = []
-
- print("Service initialised.")
-
- # Register a configuration variable
- def register_specification(self, spec):
- self.specifications.append(spec)
-
- # Start processing for a new flow
- async def start_flow(self, flow, defn):
- self.flows[flow] = Flow(self.id, flow, self, defn)
- await self.flows[flow].start()
- print("Started flow: ", flow)
-
- # Stop processing for a new flow
- async def stop_flow(self, flow):
- if flow in self.flows:
- await self.flows[flow].stop()
- del self.flows[flow]
- print("Stopped flow: ", flow, flush=True)
-
- # Event handler - called for a configuration change
- async def on_configure_flows(self, config, version):
-
- print("Got config version", version, flush=True)
-
- # Skip over invalid data
- if "flows-active" not in config: return
-
- # Check there's configuration information for me
- if self.id in config["flows-active"]:
-
- # Get my flow config
- flow_config = json.loads(config["flows-active"][self.id])
-
- else:
-
- print("No configuration settings for me.", flush=True)
- flow_config = {}
-
- # Get list of flows which should be running and are currently
- # running
- wanted_flows = flow_config.keys()
- current_flows = self.flows.keys()
-
- # Start all the flows which arent currently running
- for flow in wanted_flows:
- if flow not in current_flows:
- await self.start_flow(flow, flow_config[flow])
-
- # Stop all the unwanted flows which are due to be stopped
- for flow in current_flows:
- if flow not in wanted_flows:
- await self.stop_flow(flow)
-
- print("Handled config update")
-
- # Start threads, just call parent
- async def start(self):
- await super(FlowProcessor, self).start()
-
- @staticmethod
- def add_args(parser):
-
- AsyncProcessor.add_args(parser)
-
- # parser.add_argument(
- # '--rate-limit-retry',
- # type=int,
- # default=default_rate_limit_retry,
- # help=f'Rate limit retry (default: {default_rate_limit_retry})'
- # )
-
- # parser.add_argument(
- # '--rate-limit-timeout',
- # type=int,
- # default=default_rate_limit_timeout,
- # help=f'Rate limit timeout (default: {default_rate_limit_timeout})'
- # )
-
-
diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py
deleted file mode 100644
index e89364f2..00000000
--- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py
+++ /dev/null
@@ -1,45 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
-from .. knowledge import Uri, Literal
-
-def to_value(x):
- if x.is_uri: return Uri(x.value)
- return Literal(x.value)
-
-class GraphEmbeddingsClient(RequestResponse):
- async def query(self, vectors, limit=20, user="trustgraph",
- collection="default", timeout=30):
-
- resp = await self.request(
- GraphEmbeddingsRequest(
- vectors = vectors,
- limit = limit,
- user = user,
- collection = collection
- ),
- timeout=timeout
- )
-
- print(resp, flush=True)
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- return [
- to_value(v)
- for v in resp.entities
- ]
-
-class GraphEmbeddingsClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(GraphEmbeddingsClientSpec, self).__init__(
- request_name = request_name,
- request_schema = GraphEmbeddingsRequest,
- response_name = response_name,
- response_schema = GraphEmbeddingsResponse,
- impl = GraphEmbeddingsClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py
deleted file mode 100644
index fb2e8dc5..00000000
--- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py
+++ /dev/null
@@ -1,84 +0,0 @@
-
-"""
-Graph embeddings query service. Input is vectors. Output is list of
-embeddings.
-"""
-
-from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
-from .. schema import Error, Value
-
-from . flow_processor import FlowProcessor
-from . consumer_spec import ConsumerSpec
-from . producer_spec import ProducerSpec
-
-default_ident = "ge-query"
-
-class GraphEmbeddingsQueryService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(GraphEmbeddingsQueryService, self).__init__(
- **params | { "id": id }
- )
-
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = GraphEmbeddingsRequest,
- handler = self.on_message
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = GraphEmbeddingsResponse,
- )
- )
-
- async def on_message(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- # Sender-produced ID
- id = msg.properties()["id"]
-
- print(f"Handling input {id}...", flush=True)
-
- entities = await self.query_graph_embeddings(request)
-
- print("Send response...", flush=True)
- r = GraphEmbeddingsResponse(entities=entities, error=None)
- await flow("response").send(r, properties={"id": id})
-
- print("Done.", flush=True)
-
- except Exception as e:
-
- print(f"Exception: {e}")
-
- print("Send error response...", flush=True)
-
- r = GraphEmbeddingsResponse(
- error=Error(
- type = "graph-embeddings-query-error",
- message = str(e),
- ),
- response=None,
- )
-
- await flow("response").send(r, properties={"id": id})
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
-def run():
-
- Processor.launch(default_ident, __doc__)
-
diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py
deleted file mode 100644
index 911b90c1..00000000
--- a/trustgraph-base/trustgraph/base/graph_embeddings_store_service.py
+++ /dev/null
@@ -1,50 +0,0 @@
-
-"""
-Graph embeddings store base class
-"""
-
-from .. schema import GraphEmbeddings
-from .. base import FlowProcessor, ConsumerSpec
-from .. exceptions import TooManyRequests
-
-default_ident = "graph-embeddings-write"
-
-class GraphEmbeddingsStoreService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(GraphEmbeddingsStoreService, self).__init__(
- **params | { "id": id }
- )
-
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = GraphEmbeddings,
- handler = self.on_message
- )
- )
-
- async def on_message(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- await self.store_graph_embeddings(request)
-
- except TooManyRequests as e:
- raise e
-
- except Exception as e:
-
- print(f"Exception: {e}")
- raise e
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py
deleted file mode 100644
index c4f3f7ab..00000000
--- a/trustgraph-base/trustgraph/base/graph_rag_client.py
+++ /dev/null
@@ -1,33 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import GraphRagQuery, GraphRagResponse
-
-class GraphRagClient(RequestResponse):
- async def rag(self, query, user="trustgraph", collection="default",
- timeout=600):
- resp = await self.request(
- GraphRagQuery(
- query = query,
- user = user,
- collection = collection,
- ),
- timeout=timeout
- )
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- return resp.response
-
-class GraphRagClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(GraphRagClientSpec, self).__init__(
- request_name = request_name,
- request_schema = GraphRagQuery,
- response_name = response_name,
- response_schema = GraphRagResponse,
- impl = GraphRagClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/llm_service.py b/trustgraph-base/trustgraph/base/llm_service.py
deleted file mode 100644
index 39323db7..00000000
--- a/trustgraph-base/trustgraph/base/llm_service.py
+++ /dev/null
@@ -1,114 +0,0 @@
-
-"""
-LLM text completion base class
-"""
-
-import time
-from prometheus_client import Histogram
-
-from .. schema import TextCompletionRequest, TextCompletionResponse, Error
-from .. exceptions import TooManyRequests
-from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
-
-default_ident = "text-completion"
-
-class LlmResult:
- __slots__ = ["text", "in_token", "out_token", "model"]
-
-class LlmService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(LlmService, self).__init__(**params | { "id": id })
-
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = TextCompletionRequest,
- handler = self.on_request
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = TextCompletionResponse
- )
- )
-
- if not hasattr(__class__, "text_completion_metric"):
- __class__.text_completion_metric = Histogram(
- 'text_completion_duration',
- 'Text completion duration (seconds)',
- ["id", "flow"],
- buckets=[
- 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
- 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
- 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
- 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
- 120.0
- ]
- )
-
- async def on_request(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- # Sender-produced ID
-
- id = msg.properties()["id"]
-
- with __class__.text_completion_metric.labels(
- id=self.id,
- flow=f"{flow.name}-{consumer.name}",
- ).time():
-
- response = await self.generate_content(
- request.system, request.prompt
- )
-
- await flow("response").send(
- TextCompletionResponse(
- error=None,
- response=response.text,
- in_token=response.in_token,
- out_token=response.out_token,
- model=response.model
- ),
- properties={"id": id}
- )
-
- except TooManyRequests as e:
- raise e
-
- except Exception as e:
-
- # Apart from rate limits, treat all exceptions as unrecoverable
-
- print(f"Exception: {e}")
-
- print("Send error response...", flush=True)
-
- await flow.producer["response"].send(
- TextCompletionResponse(
- error=Error(
- type = "llm-error",
- message = str(e),
- ),
- response=None,
- in_token=None,
- out_token=None,
- model=None,
- ),
- properties={"id": id}
- )
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
diff --git a/trustgraph-base/trustgraph/base/metrics.py b/trustgraph-base/trustgraph/base/metrics.py
deleted file mode 100644
index 4ffbac9c..00000000
--- a/trustgraph-base/trustgraph/base/metrics.py
+++ /dev/null
@@ -1,136 +0,0 @@
-
-from prometheus_client import start_http_server, Info, Enum, Histogram
-from prometheus_client import Counter
-
-class ConsumerMetrics:
-
- def __init__(self, processor, flow, name):
-
- self.processor = processor
- self.flow = flow
- self.name = name
-
- if not hasattr(__class__, "state_metric"):
- __class__.state_metric = Enum(
- 'consumer_state', 'Consumer state',
- ["processor", "flow", "name"],
- states=['stopped', 'running']
- )
-
- if not hasattr(__class__, "request_metric"):
- __class__.request_metric = Histogram(
- 'request_latency', 'Request latency (seconds)',
- ["processor", "flow", "name"],
- )
-
- if not hasattr(__class__, "processing_metric"):
- __class__.processing_metric = Counter(
- 'processing_count', 'Processing count',
- ["processor", "flow", "name", "status"],
- )
-
- if not hasattr(__class__, "rate_limit_metric"):
- __class__.rate_limit_metric = Counter(
- 'rate_limit_count', 'Rate limit event count',
- ["processor", "flow", "name"],
- )
-
- def process(self, status):
- __class__.processing_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- status=status
- ).inc()
-
- def rate_limit(self):
- __class__.rate_limit_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- ).inc()
-
- def state(self, state):
- __class__.state_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- ).state(state)
-
- def record_time(self):
- return __class__.request_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- ).time()
-
-class ProducerMetrics:
-
- def __init__(self, processor, flow, name):
-
- self.processor = processor
- self.flow = flow
- self.name = name
-
- if not hasattr(__class__, "producer_metric"):
- __class__.producer_metric = Counter(
- 'producer_count', 'Output items produced',
- ["processor", "flow", "name"],
- )
-
- def inc(self):
- __class__.producer_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name
- ).inc()
-
-class ProcessorMetrics:
- def __init__(self, processor):
-
- self.processor = processor
-
- if not hasattr(__class__, "processor_metric"):
- __class__.processor_metric = Info(
- 'processor', 'Processor configuration',
- ["processor"]
- )
-
- def info(self, info):
- __class__.processor_metric.labels(
- processor = self.processor
- ).info(info)
-
-class SubscriberMetrics:
-
- def __init__(self, processor, flow, name):
-
- self.processor = processor
- self.flow = flow
- self.name = name
-
- if not hasattr(__class__, "state_metric"):
- __class__.state_metric = Enum(
- 'subscriber_state', 'Subscriber state',
- ["processor", "flow", "name"],
- states=['stopped', 'running']
- )
-
- if not hasattr(__class__, "received_metric"):
- __class__.received_metric = Counter(
- 'received_count', 'Received count',
- ["processor", "flow", "name"],
- )
-
- if not hasattr(__class__, "dropped_metric"):
- __class__.dropped_metric = Counter(
- 'dropped_count', 'Dropped messages count',
- ["processor", "flow", "name"],
- )
-
- def received(self):
- __class__.received_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- ).inc()
-
- def state(self, state):
-
- __class__.state_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- ).state(state)
-
- def dropped(self, state):
- __class__.dropped_metric.labels(
- processor = self.processor, flow = self.flow, name = self.name,
- ).inc()
-
diff --git a/trustgraph-base/trustgraph/base/producer.py b/trustgraph-base/trustgraph/base/producer.py
index bb665924..bc2d7791 100644
--- a/trustgraph-base/trustgraph/base/producer.py
+++ b/trustgraph-base/trustgraph/base/producer.py
@@ -1,69 +1,56 @@
from pulsar.schema import JsonSchema
-import asyncio
+from prometheus_client import Info, Counter
-class Producer:
+from . base_processor import BaseProcessor
- def __init__(self, client, topic, schema, metrics=None):
- self.client = client
- self.topic = topic
- self.schema = schema
+class Producer(BaseProcessor):
- self.metrics = metrics
+ def __init__(self, **params):
- self.running = True
- self.producer = None
+ output_queue = params.get("output_queue")
+ output_schema = params.get("output_schema")
- def __del__(self):
+ if not hasattr(__class__, "output_metric"):
+ __class__.output_metric = Counter(
+ 'output_count', 'Output items created'
+ )
- self.running = False
+ if not hasattr(__class__, "pubsub_metric"):
+ __class__.pubsub_metric = Info(
+ 'pubsub', 'Pub/sub configuration'
+ )
- if hasattr(self, "producer"):
- if self.producer:
- self.producer.close()
+ __class__.pubsub_metric.info({
+ "output_queue": output_queue,
+ "output_schema": output_schema.__name__,
+ })
- async def start(self):
- self.running = True
+ super(Producer, self).__init__(**params)
- async def stop(self):
- self.running = False
+ if output_schema == None:
+ raise RuntimeError("output_schema must be specified")
+
+ self.producer = self.client.create_producer(
+ topic=output_queue,
+ schema=JsonSchema(output_schema),
+ chunking_enabled=True,
+ )
async def send(self, msg, properties={}):
+ self.producer.send(msg, properties)
+ __class__.output_metric.inc()
- if not self.running: return
+ @staticmethod
+ def add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ ):
- while self.running and self.producer is None:
-
- try:
- print("Connect publisher to", self.topic, "...", flush=True)
- self.producer = self.client.create_producer(
- topic = self.topic,
- schema = JsonSchema(self.schema)
- )
- print("Connected to", self.topic, flush=True)
- except Exception as e:
- print("Exception:", e, flush=True)
- await asyncio.sleep(2)
-
- if not self.running: break
-
- while self.running:
-
- try:
-
- await asyncio.to_thread(
- self.producer.send,
- msg, properties
- )
-
- if self.metrics:
- self.metrics.inc()
-
- # Delivery success, break out of loop
- break
-
- except Exception as e:
- print("Exception:", e, flush=True)
- self.producer.close()
- self.producer = None
+ BaseProcessor.add_args(parser)
+ parser.add_argument(
+ '-o', '--output-queue',
+ default=default_output_queue,
+ help=f'Output queue (default: {default_output_queue})'
+ )
diff --git a/trustgraph-base/trustgraph/base/producer_spec.py b/trustgraph-base/trustgraph/base/producer_spec.py
deleted file mode 100644
index 9c8bbc6a..00000000
--- a/trustgraph-base/trustgraph/base/producer_spec.py
+++ /dev/null
@@ -1,25 +0,0 @@
-
-from . producer import Producer
-from . metrics import ProducerMetrics
-from . spec import Spec
-
-class ProducerSpec(Spec):
- def __init__(self, name, schema):
- self.name = name
- self.schema = schema
-
- def add(self, flow, processor, definition):
-
- producer_metrics = ProducerMetrics(
- processor = flow.id, flow = flow.name, name = self.name
- )
-
- producer = Producer(
- client = processor.pulsar_client,
- topic = definition[self.name],
- schema = self.schema,
- metrics = producer_metrics,
- )
-
- flow.producer[self.name] = producer
-
diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py
deleted file mode 100644
index 9e8ab033..00000000
--- a/trustgraph-base/trustgraph/base/prompt_client.py
+++ /dev/null
@@ -1,93 +0,0 @@
-
-import json
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import PromptRequest, PromptResponse
-
-class PromptClient(RequestResponse):
-
- async def prompt(self, id, variables, timeout=600):
-
- resp = await self.request(
- PromptRequest(
- id = id,
- terms = {
- k: json.dumps(v)
- for k, v in variables.items()
- }
- ),
- timeout=timeout
- )
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- if resp.text: return resp.text
-
- return json.loads(resp.object)
-
- async def extract_definitions(self, text, timeout=600):
- return await self.prompt(
- id = "extract-definitions",
- variables = { "text": text },
- timeout = timeout,
- )
-
- async def extract_relationships(self, text, timeout=600):
- return await self.prompt(
- id = "extract-relationships",
- variables = { "text": text },
- timeout = timeout,
- )
-
- async def kg_prompt(self, query, kg, timeout=600):
- return await self.prompt(
- id = "kg-prompt",
- variables = {
- "query": query,
- "knowledge": [
- { "s": v[0], "p": v[1], "o": v[2] }
- for v in kg
- ]
- },
- timeout = timeout,
- )
-
- async def document_prompt(self, query, documents, timeout=600):
- return await self.prompt(
- id = "document-prompt",
- variables = {
- "query": query,
- "documents": documents,
- },
- timeout = timeout,
- )
-
- async def agent_react(self, variables, timeout=600):
- return await self.prompt(
- id = "agent-react",
- variables = variables,
- timeout = timeout,
- )
-
- async def question(self, question, timeout=600):
- return await self.prompt(
- id = "question",
- variables = {
- "question": question,
- },
- timeout = timeout,
- )
-
-class PromptClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(PromptClientSpec, self).__init__(
- request_name = request_name,
- request_schema = PromptRequest,
- response_name = response_name,
- response_schema = PromptResponse,
- impl = PromptClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/publisher.py b/trustgraph-base/trustgraph/base/publisher.py
index ce9e364e..2da63331 100644
--- a/trustgraph-base/trustgraph/base/publisher.py
+++ b/trustgraph-base/trustgraph/base/publisher.py
@@ -1,52 +1,47 @@
-from pulsar.schema import JsonSchema
-
-import asyncio
+import queue
import time
import pulsar
+import threading
class Publisher:
- def __init__(self, client, topic, schema=None, max_size=10,
+ def __init__(self, pulsar_client, topic, schema=None, max_size=10,
chunking_enabled=True):
- self.client = client
+ self.client = pulsar_client
self.topic = topic
self.schema = schema
- self.q = asyncio.Queue(maxsize=max_size)
+ self.q = queue.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
self.running = True
- async def start(self):
- self.task = asyncio.create_task(self.run())
+ def start(self):
+ self.task = threading.Thread(target=self.run)
+ self.task.start()
- async def stop(self):
+ def stop(self):
self.running = False
- async def join(self):
- await self.stop()
- await self.task
+ def join(self):
+ self.stop()
+ self.task.join()
- async def run(self):
+ def run(self):
while self.running:
try:
producer = self.client.create_producer(
topic=self.topic,
- schema=JsonSchema(self.schema),
+ schema=self.schema,
chunking_enabled=self.chunking_enabled,
)
while self.running:
try:
- id, item = await asyncio.wait_for(
- self.q.get(),
- timeout=0.5
- )
- except asyncio.TimeoutError:
- continue
- except asyncio.QueueEmpty:
+ id, item = self.q.get(timeout=0.5)
+ except queue.Empty:
continue
if id:
@@ -60,6 +55,7 @@ class Publisher:
# If handler drops out, sleep a retry
time.sleep(2)
- async def send(self, id, item):
- await self.q.put((id, item))
+ def send(self, id, msg):
+ self.q.put((id, msg))
+
diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py
deleted file mode 100644
index b9f233d4..00000000
--- a/trustgraph-base/trustgraph/base/pubsub.py
+++ /dev/null
@@ -1,80 +0,0 @@
-
-import os
-import pulsar
-import uuid
-from pulsar.schema import JsonSchema
-
-from .. log_level import LogLevel
-
-class PulsarClient:
-
- default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
- default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
-
- def __init__(self, **params):
-
- self.client = None
-
- pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
- pulsar_listener = params.get("pulsar_listener", None)
- pulsar_api_key = params.get(
- "pulsar_api_key",
- self.default_pulsar_api_key
- )
- log_level = params.get("log_level", LogLevel.INFO)
-
- self.pulsar_host = pulsar_host
- self.pulsar_api_key = pulsar_api_key
-
- if pulsar_api_key:
- auth = pulsar.AuthenticationToken(pulsar_api_key)
- self.client = pulsar.Client(
- pulsar_host,
- authentication=auth,
- logger=pulsar.ConsoleLogger(log_level.to_pulsar())
- )
- else:
- self.client = pulsar.Client(
- pulsar_host,
- listener_name=pulsar_listener,
- logger=pulsar.ConsoleLogger(log_level.to_pulsar())
- )
-
- self.pulsar_listener = pulsar_listener
-
- def close(self):
- self.client.close()
-
- def __del__(self):
-
- if hasattr(self, "client"):
- if self.client:
- self.client.close()
-
- @staticmethod
- def add_args(parser):
-
- parser.add_argument(
- '-p', '--pulsar-host',
- default=__class__.default_pulsar_host,
- help=f'Pulsar host (default: {__class__.default_pulsar_host})',
- )
-
- parser.add_argument(
- '--pulsar-api-key',
- default=__class__.default_pulsar_api_key,
- help=f'Pulsar API key',
- )
-
- parser.add_argument(
- '--pulsar-listener',
- help=f'Pulsar listener (default: none)',
- )
-
- parser.add_argument(
- '-l', '--log-level',
- type=LogLevel,
- default=LogLevel.INFO,
- choices=list(LogLevel),
- help=f'Output queue (default: info)'
- )
diff --git a/trustgraph-base/trustgraph/base/request_response_spec.py b/trustgraph-base/trustgraph/base/request_response_spec.py
deleted file mode 100644
index 88ee4563..00000000
--- a/trustgraph-base/trustgraph/base/request_response_spec.py
+++ /dev/null
@@ -1,141 +0,0 @@
-
-import uuid
-import asyncio
-
-from . subscriber import Subscriber
-from . producer import Producer
-from . spec import Spec
-from . metrics import ConsumerMetrics, ProducerMetrics, SubscriberMetrics
-
-class RequestResponse(Subscriber):
-
- def __init__(
- self, client, subscription, consumer_name,
- request_topic, request_schema,
- request_metrics,
- response_topic, response_schema,
- response_metrics,
- ):
-
- super(RequestResponse, self).__init__(
- client = client,
- subscription = subscription,
- consumer_name = consumer_name,
- topic = response_topic,
- schema = response_schema,
- metrics = response_metrics,
- )
-
- self.producer = Producer(
- client = client,
- topic = request_topic,
- schema = request_schema,
- metrics = request_metrics,
- )
-
- async def start(self):
- await self.producer.start()
- await super(RequestResponse, self).start()
-
- async def stop(self):
- await self.producer.stop()
- await super(RequestResponse, self).stop()
-
- async def request(self, req, timeout=300, recipient=None):
-
- id = str(uuid.uuid4())
-
- print("Request", id, "...", flush=True)
-
- q = await self.subscribe(id)
-
- try:
-
- await self.producer.send(
- req,
- properties={"id": id}
- )
-
- except Exception as e:
-
- print("Exception:", e)
- raise e
-
-
- try:
-
- while True:
-
- resp = await asyncio.wait_for(
- q.get(),
- timeout=timeout
- )
-
- print("Got response.", flush=True)
-
- if recipient is None:
-
- # If no recipient handler, just return the first
- # response we get
- return resp
- else:
-
- # Recipient handler gets to decide when we're done b
- # returning a boolean
- fin = await recipient(resp)
-
- # If done, return the last result otherwise loop round for
- # next response
- if fin:
- return resp
- else:
- continue
-
- except Exception as e:
-
- print("Exception:", e)
- raise e
-
- finally:
-
- await self.unsubscribe(id)
-
-# This deals with the request/response case. The caller needs to
-# use another service in request/response mode. Uses two topics:
-# - we send on the request topic as a producer
-# - we receive on the response topic as a subscriber
-class RequestResponseSpec(Spec):
- def __init__(
- self, request_name, request_schema, response_name,
- response_schema, impl=RequestResponse
- ):
- self.request_name = request_name
- self.request_schema = request_schema
- self.response_name = response_name
- self.response_schema = response_schema
- self.impl = impl
-
- def add(self, flow, processor, definition):
-
- request_metrics = ProducerMetrics(
- processor = flow.id, flow = flow.name, name = self.request_name
- )
-
- response_metrics = SubscriberMetrics(
- processor = flow.id, flow = flow.name, name = self.request_name
- )
-
- rr = self.impl(
- client = processor.pulsar_client,
- subscription = flow.id,
- consumer_name = flow.id,
- request_topic = definition[self.request_name],
- request_schema = self.request_schema,
- request_metrics = request_metrics,
- response_topic = definition[self.response_name],
- response_schema = self.response_schema,
- response_metrics = response_metrics,
- )
-
- flow.consumer[self.request_name] = rr
-
diff --git a/trustgraph-base/trustgraph/base/setting_spec.py b/trustgraph-base/trustgraph/base/setting_spec.py
deleted file mode 100644
index 5c5152b2..00000000
--- a/trustgraph-base/trustgraph/base/setting_spec.py
+++ /dev/null
@@ -1,19 +0,0 @@
-
-from . spec import Spec
-
-class Setting:
- def __init__(self, value):
- self.value = value
- async def start():
- pass
- async def stop():
- pass
-
-class SettingSpec(Spec):
- def __init__(self, name):
- self.name = name
-
- def add(self, flow, processor, definition):
-
- flow.config[self.name] = Setting(definition[self.name])
-
diff --git a/trustgraph-base/trustgraph/base/spec.py b/trustgraph-base/trustgraph/base/spec.py
deleted file mode 100644
index 4d0d937b..00000000
--- a/trustgraph-base/trustgraph/base/spec.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-class Spec:
- pass
-
diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py
index 1cf263d4..30ade3ee 100644
--- a/trustgraph-base/trustgraph/base/subscriber.py
+++ b/trustgraph-base/trustgraph/base/subscriber.py
@@ -1,14 +1,14 @@
-from pulsar.schema import JsonSchema
-import asyncio
-import _pulsar
+import queue
+import pulsar
+import threading
import time
class Subscriber:
- def __init__(self, client, topic, subscription, consumer_name,
- schema=None, max_size=100, metrics=None):
- self.client = client
+ def __init__(self, pulsar_client, topic, subscription, consumer_name,
+ schema=None, max_size=100):
+ self.client = pulsar_client
self.topic = topic
self.subscription = subscription
self.consumer_name = consumer_name
@@ -16,60 +16,35 @@ class Subscriber:
self.q = {}
self.full = {}
self.max_size = max_size
- self.lock = asyncio.Lock()
+ self.lock = threading.Lock()
self.running = True
- self.metrics = metrics
- async def __del__(self):
+ def start(self):
+ self.task = threading.Thread(target=self.run)
+ self.task.start()
+
+ def stop(self):
self.running = False
- async def start(self):
- self.task = asyncio.create_task(self.run())
+ def join(self):
+ self.task.join()
- async def stop(self):
- self.running = False
-
- async def join(self):
- await self.stop()
- await self.task
-
- async def run(self):
+ def run(self):
while self.running:
- if self.metrics:
- self.metrics.state("stopped")
-
try:
consumer = self.client.subscribe(
- topic = self.topic,
- subscription_name = self.subscription,
- consumer_name = self.consumer_name,
- schema = JsonSchema(self.schema),
+ topic=self.topic,
+ subscription_name=self.subscription,
+ consumer_name=self.consumer_name,
+ schema=self.schema,
)
- if self.metrics:
- self.metrics.state("running")
-
- print("Subscriber running...", flush=True)
-
while self.running:
- try:
- msg = await asyncio.to_thread(
- consumer.receive,
- timeout_millis=2000
- )
- except _pulsar.Timeout:
- continue
- except Exception as e:
- print("Exception:", e, flush=True)
- print(type(e))
- raise e
-
- if self.metrics:
- self.metrics.received()
+ msg = consumer.receive()
# Acknowledge successful reception of the message
consumer.acknowledge(msg)
@@ -81,74 +56,57 @@ class Subscriber:
value = msg.value()
- async with self.lock:
-
- # FIXME: Hard-coded timeouts
+ with self.lock:
if id in self.q:
-
try:
# FIXME: Timeout means data goes missing
- await asyncio.wait_for(
- self.q[id].put(value),
- timeout=2
- )
-
- except Exception as e:
- self.metrics.dropped()
- print("Q Put:", e, flush=True)
+ self.q[id].put(value, timeout=0.5)
+                        except queue.Full:
+                            pass
for q in self.full.values():
try:
# FIXME: Timeout means data goes missing
- await asyncio.wait_for(
- q.put(value),
- timeout=2
- )
- except Exception as e:
- self.metrics.dropped()
- print("Q Put:", e, flush=True)
+ q.put(value, timeout=0.5)
+                        except queue.Full:
+                            pass
except Exception as e:
- print("Subscriber exception:", e, flush=True)
-
- consumer.close()
+ print("Exception:", e, flush=True)
- if self.metrics:
- self.metrics.state("stopped")
-
# If handler drops out, sleep a retry
time.sleep(2)
- async def subscribe(self, id):
+ def subscribe(self, id):
- async with self.lock:
+ with self.lock:
- q = asyncio.Queue(maxsize=self.max_size)
+ q = queue.Queue(maxsize=self.max_size)
self.q[id] = q
return q
- async def unsubscribe(self, id):
+ def unsubscribe(self, id):
- async with self.lock:
+ with self.lock:
if id in self.q:
# self.q[id].shutdown(immediate=True)
del self.q[id]
- async def subscribe_all(self, id):
+ def subscribe_all(self, id):
- async with self.lock:
+ with self.lock:
- q = asyncio.Queue(maxsize=self.max_size)
+ q = queue.Queue(maxsize=self.max_size)
self.full[id] = q
return q
- async def unsubscribe_all(self, id):
+ def unsubscribe_all(self, id):
- async with self.lock:
+ with self.lock:
if id in self.full:
# self.full[id].shutdown(immediate=True)
diff --git a/trustgraph-base/trustgraph/base/subscriber_spec.py b/trustgraph-base/trustgraph/base/subscriber_spec.py
deleted file mode 100644
index 7dca09db..00000000
--- a/trustgraph-base/trustgraph/base/subscriber_spec.py
+++ /dev/null
@@ -1,30 +0,0 @@
-
-from . metrics import SubscriberMetrics
-from . subscriber import Subscriber
-from . spec import Spec
-
-class SubscriberSpec(Spec):
-
- def __init__(self, name, schema):
- self.name = name
- self.schema = schema
-
- def add(self, flow, processor, definition):
-
- subscriber_metrics = SubscriberMetrics(
- processor = flow.id, flow = flow.name, name = self.name
- )
-
- subscriber = Subscriber(
- client = processor.pulsar_client,
- topic = definition[self.name],
- subscription = flow.id,
- consumer_name = flow.id,
- schema = self.schema,
- metrics = subscriber_metrics,
- )
-
- # Put it in the consumer map, does that work?
- # It means it gets start/stop call.
- flow.consumer[self.name] = subscriber
-
diff --git a/trustgraph-base/trustgraph/base/text_completion_client.py b/trustgraph-base/trustgraph/base/text_completion_client.py
deleted file mode 100644
index aba2fada..00000000
--- a/trustgraph-base/trustgraph/base/text_completion_client.py
+++ /dev/null
@@ -1,30 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import TextCompletionRequest, TextCompletionResponse
-
-class TextCompletionClient(RequestResponse):
- async def text_completion(self, system, prompt, timeout=600):
- resp = await self.request(
- TextCompletionRequest(
- system = system, prompt = prompt
- ),
- timeout=timeout
- )
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- return resp.response
-
-class TextCompletionClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(TextCompletionClientSpec, self).__init__(
- request_name = request_name,
- request_schema = TextCompletionRequest,
- response_name = response_name,
- response_schema = TextCompletionResponse,
- impl = TextCompletionClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py
deleted file mode 100644
index c9f747b5..00000000
--- a/trustgraph-base/trustgraph/base/triples_client.py
+++ /dev/null
@@ -1,61 +0,0 @@
-
-from . request_response_spec import RequestResponse, RequestResponseSpec
-from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
-from .. knowledge import Uri, Literal
-
-class Triple:
- def __init__(self, s, p, o):
- self.s = s
- self.p = p
- self.o = o
-
-def to_value(x):
- if x.is_uri: return Uri(x.value)
- return Literal(x.value)
-
-def from_value(x):
- if x is None: return None
- if isinstance(x, Uri):
- return Value(value=str(x), is_uri=True)
- else:
- return Value(value=str(x), is_uri=False)
-
-class TriplesClient(RequestResponse):
- async def query(self, s=None, p=None, o=None, limit=20,
- user="trustgraph", collection="default",
- timeout=30):
-
- resp = await self.request(
- TriplesQueryRequest(
- s = from_value(s),
- p = from_value(p),
- o = from_value(o),
- limit = limit,
- user = user,
- collection = collection,
- ),
- timeout=timeout
- )
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- triples = [
- Triple(to_value(v.s), to_value(v.p), to_value(v.o))
- for v in resp.triples
- ]
-
- return triples
-
-class TriplesClientSpec(RequestResponseSpec):
- def __init__(
- self, request_name, response_name,
- ):
- super(TriplesClientSpec, self).__init__(
- request_name = request_name,
- request_schema = TriplesQueryRequest,
- response_name = response_name,
- response_schema = TriplesQueryResponse,
- impl = TriplesClient,
- )
-
diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py
deleted file mode 100644
index 37acc622..00000000
--- a/trustgraph-base/trustgraph/base/triples_query_service.py
+++ /dev/null
@@ -1,82 +0,0 @@
-
-"""
-Triples query service. Input is a (s, p, o) triple, some values may be
-null. Output is a list of triples.
-"""
-
-from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error
-from .. schema import Value, Triple
-
-from . flow_processor import FlowProcessor
-from . consumer_spec import ConsumerSpec
-from . producer_spec import ProducerSpec
-
-default_ident = "triples-query"
-
-class TriplesQueryService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(TriplesQueryService, self).__init__(**params | { "id": id })
-
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = TriplesQueryRequest,
- handler = self.on_message
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = TriplesQueryResponse,
- )
- )
-
- async def on_message(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- # Sender-produced ID
- id = msg.properties()["id"]
-
- print(f"Handling input {id}...", flush=True)
-
- triples = await self.query_triples(request)
-
- print("Send response...", flush=True)
- r = TriplesQueryResponse(triples=triples, error=None)
- await flow("response").send(r, properties={"id": id})
-
- print("Done.", flush=True)
-
- except Exception as e:
-
- print(f"Exception: {e}")
-
- print("Send error response...", flush=True)
-
- r = TriplesQueryResponse(
- error = Error(
- type = "triples-query-error",
- message = str(e),
- ),
- triples = None,
- )
-
- await flow("response").send(r, properties={"id": id})
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
-def run():
-
- Processor.launch(default_ident, __doc__)
-
diff --git a/trustgraph-base/trustgraph/base/triples_store_service.py b/trustgraph-base/trustgraph/base/triples_store_service.py
deleted file mode 100644
index 74f95f57..00000000
--- a/trustgraph-base/trustgraph/base/triples_store_service.py
+++ /dev/null
@@ -1,47 +0,0 @@
-
-"""
-Triples store base class
-"""
-
-from .. schema import Triples
-from .. base import FlowProcessor, ConsumerSpec
-
-default_ident = "triples-write"
-
-class TriplesStoreService(FlowProcessor):
-
- def __init__(self, **params):
-
- id = params.get("id")
-
- super(TriplesStoreService, self).__init__(**params | { "id": id })
-
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = Triples,
- handler = self.on_message
- )
- )
-
- async def on_message(self, msg, consumer, flow):
-
- try:
-
- request = msg.value()
-
- await self.store_triples(request)
-
- except TooManyRequests as e:
- raise e
-
- except Exception as e:
-
- print(f"Exception: {e}")
- raise e
-
- @staticmethod
- def add_args(parser):
-
- FlowProcessor.add_args(parser)
-
diff --git a/trustgraph-base/trustgraph/schema/__init__.py b/trustgraph-base/trustgraph/schema/__init__.py
index a9bb30a6..28e1a879 100644
--- a/trustgraph-base/trustgraph/schema/__init__.py
+++ b/trustgraph-base/trustgraph/schema/__init__.py
@@ -12,5 +12,5 @@ from . agent import *
from . lookup import *
from . library import *
from . config import *
-from . flows import *
+
diff --git a/trustgraph-base/trustgraph/schema/agent.py b/trustgraph-base/trustgraph/schema/agent.py
index ee20a9aa..9bcdde51 100644
--- a/trustgraph-base/trustgraph/schema/agent.py
+++ b/trustgraph-base/trustgraph/schema/agent.py
@@ -26,5 +26,12 @@ class AgentResponse(Record):
thought = String()
observation = String()
+agent_request_queue = topic(
+ 'agent', kind='non-persistent', namespace='request'
+)
+agent_response_queue = topic(
+ 'agent', kind='non-persistent', namespace='response'
+)
+
############################################################################
diff --git a/trustgraph-base/trustgraph/schema/config.py b/trustgraph-base/trustgraph/schema/config.py
index 3be63aa3..efe49182 100644
--- a/trustgraph-base/trustgraph/schema/config.py
+++ b/trustgraph-base/trustgraph/schema/config.py
@@ -2,7 +2,7 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
from . topic import topic
-from . types import Error
+from . types import Error, RowSchema
############################################################################
diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py
index e479371d..fd0049ee 100644
--- a/trustgraph-base/trustgraph/schema/documents.py
+++ b/trustgraph-base/trustgraph/schema/documents.py
@@ -11,6 +11,8 @@ class Document(Record):
metadata = Metadata()
data = Bytes()
+document_ingest_queue = topic('document-load')
+
############################################################################
# Text documents / text from PDF
@@ -19,6 +21,8 @@ class TextDocument(Record):
metadata = Metadata()
text = Bytes()
+text_ingest_queue = topic('text-document-load')
+
############################################################################
# Chunks of text
@@ -27,6 +31,8 @@ class Chunk(Record):
metadata = Metadata()
chunk = Bytes()
+chunk_ingest_queue = topic('chunk-load')
+
############################################################################
# Document embeddings are embeddings associated with a chunk
@@ -40,6 +46,8 @@ class DocumentEmbeddings(Record):
metadata = Metadata()
chunks = Array(ChunkEmbeddings())
+document_embeddings_store_queue = topic('document-embeddings-store')
+
############################################################################
# Doc embeddings query
@@ -54,3 +62,10 @@ class DocumentEmbeddingsResponse(Record):
error = Error()
documents = Array(Bytes())
+document_embeddings_request_queue = topic(
+ 'doc-embeddings', kind='non-persistent', namespace='request'
+)
+document_embeddings_response_queue = topic(
+ 'doc-embeddings', kind='non-persistent', namespace='response',
+)
+
diff --git a/trustgraph-base/trustgraph/schema/flows.py b/trustgraph-base/trustgraph/schema/flows.py
deleted file mode 100644
index 28b90f5d..00000000
--- a/trustgraph-base/trustgraph/schema/flows.py
+++ /dev/null
@@ -1,66 +0,0 @@
-
-from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
-
-from . topic import topic
-from . types import Error
-
-############################################################################
-
-# Flow service:
-# list_classes() -> (classname[])
-# get_class(classname) -> (class)
-# put_class(class) -> (class)
-# delete_class(classname) -> ()
-#
-# list_flows() -> (flowid[])
-# get_flow(flowid) -> (flow)
-# start_flow(flowid, classname) -> ()
-# stop_flow(flowid) -> ()
-
-# Prompt services, abstract the prompt generation
-class FlowRequest(Record):
-
- operation = String() # list-classes, get-class, put-class, delete-class
- # list-flows, get-flow, start-flow, stop-flow
-
- # get_class, put_class, delete_class, start_flow
- class_name = String()
-
- # put_class
- class_definition = String()
-
- # start_flow
- description = String()
-
- # get_flow, start_flow, stop_flow
- flow_id = String()
-
-class FlowResponse(Record):
-
- # list_classes
- class_names = Array(String())
-
- # list_flows
- flow_ids = Array(String())
-
- # get_class
- class_definition = String()
-
- # get_flow
- flow = String()
-
- # get_flow
- description = String()
-
- # Everything
- error = Error()
-
-flow_request_queue = topic(
- 'flow', kind='non-persistent', namespace='request'
-)
-flow_response_queue = topic(
- 'flow', kind='non-persistent', namespace='response'
-)
-
-############################################################################
-
diff --git a/trustgraph-base/trustgraph/schema/graph.py b/trustgraph-base/trustgraph/schema/graph.py
index 97a99fbd..7c304e1d 100644
--- a/trustgraph-base/trustgraph/schema/graph.py
+++ b/trustgraph-base/trustgraph/schema/graph.py
@@ -18,6 +18,8 @@ class EntityContexts(Record):
metadata = Metadata()
entities = Array(EntityContext())
+entity_contexts_ingest_queue = topic('entity-contexts-load')
+
############################################################################
# Graph embeddings are embeddings associated with a graph entity
@@ -31,6 +33,8 @@ class GraphEmbeddings(Record):
metadata = Metadata()
entities = Array(EntityEmbeddings())
+graph_embeddings_store_queue = topic('graph-embeddings-store')
+
############################################################################
# Graph embeddings query
@@ -45,6 +49,13 @@ class GraphEmbeddingsResponse(Record):
error = Error()
entities = Array(Value())
+graph_embeddings_request_queue = topic(
+ 'graph-embeddings', kind='non-persistent', namespace='request'
+)
+graph_embeddings_response_queue = topic(
+ 'graph-embeddings', kind='non-persistent', namespace='response'
+)
+
############################################################################
# Graph triples
@@ -53,6 +64,8 @@ class Triples(Record):
metadata = Metadata()
triples = Array(Triple())
+triples_store_queue = topic('triples-store')
+
############################################################################
# Triples query
@@ -69,3 +82,9 @@ class TriplesQueryResponse(Record):
error = Error()
triples = Array(Triple())
+triples_request_queue = topic(
+ 'triples', kind='non-persistent', namespace='request'
+)
+triples_response_queue = topic(
+ 'triples', kind='non-persistent', namespace='response'
+)
diff --git a/trustgraph-base/trustgraph/schema/lookup.py b/trustgraph-base/trustgraph/schema/lookup.py
index a88d188e..d0a0517c 100644
--- a/trustgraph-base/trustgraph/schema/lookup.py
+++ b/trustgraph-base/trustgraph/schema/lookup.py
@@ -17,5 +17,26 @@ class LookupResponse(Record):
text = String()
error = Error()
+encyclopedia_lookup_request_queue = topic(
+ 'encyclopedia', kind='non-persistent', namespace='request'
+)
+encyclopedia_lookup_response_queue = topic(
+ 'encyclopedia', kind='non-persistent', namespace='response',
+)
+
+dbpedia_lookup_request_queue = topic(
+ 'dbpedia', kind='non-persistent', namespace='request'
+)
+dbpedia_lookup_response_queue = topic(
+ 'dbpedia', kind='non-persistent', namespace='response',
+)
+
+internet_search_request_queue = topic(
+ 'internet-search', kind='non-persistent', namespace='request'
+)
+internet_search_response_queue = topic(
+ 'internet-search', kind='non-persistent', namespace='response',
+)
+
############################################################################
diff --git a/trustgraph-base/trustgraph/schema/models.py b/trustgraph-base/trustgraph/schema/models.py
index ea3b9128..a634e1c4 100644
--- a/trustgraph-base/trustgraph/schema/models.py
+++ b/trustgraph-base/trustgraph/schema/models.py
@@ -19,6 +19,13 @@ class TextCompletionResponse(Record):
out_token = Integer()
model = String()
+text_completion_request_queue = topic(
+ 'text-completion', kind='non-persistent', namespace='request'
+)
+text_completion_response_queue = topic(
+ 'text-completion', kind='non-persistent', namespace='response'
+)
+
############################################################################
# Embeddings
@@ -30,3 +37,9 @@ class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double()))
+embeddings_request_queue = topic(
+ 'embeddings', kind='non-persistent', namespace='request'
+)
+embeddings_response_queue = topic(
+ 'embeddings', kind='non-persistent', namespace='response'
+)
diff --git a/trustgraph-base/trustgraph/schema/object.py b/trustgraph-base/trustgraph/schema/object.py
index 6667fdf3..60c2bdc3 100644
--- a/trustgraph-base/trustgraph/schema/object.py
+++ b/trustgraph-base/trustgraph/schema/object.py
@@ -18,6 +18,8 @@ class ObjectEmbeddings(Record):
key_name = String()
id = String()
+object_embeddings_store_queue = topic('object-embeddings-store')
+
############################################################################
# Stores rows of information
@@ -27,5 +29,5 @@ class Rows(Record):
row_schema = RowSchema()
rows = Array(Map(String()))
-
+rows_store_queue = topic('rows-store')
diff --git a/trustgraph-base/trustgraph/schema/prompt.py b/trustgraph-base/trustgraph/schema/prompt.py
index 369ace53..15eddea8 100644
--- a/trustgraph-base/trustgraph/schema/prompt.py
+++ b/trustgraph-base/trustgraph/schema/prompt.py
@@ -55,5 +55,12 @@ class PromptResponse(Record):
# JSON encoded
object = String()
+prompt_request_queue = topic(
+ 'prompt', kind='non-persistent', namespace='request'
+)
+prompt_response_queue = topic(
+ 'prompt', kind='non-persistent', namespace='response'
+)
+
############################################################################
diff --git a/trustgraph-base/trustgraph/schema/retrieval.py b/trustgraph-base/trustgraph/schema/retrieval.py
index 1077e4f9..caeb8e67 100644
--- a/trustgraph-base/trustgraph/schema/retrieval.py
+++ b/trustgraph-base/trustgraph/schema/retrieval.py
@@ -20,6 +20,13 @@ class GraphRagResponse(Record):
error = Error()
response = String()
+graph_rag_request_queue = topic(
+ 'graph-rag', kind='non-persistent', namespace='request'
+)
+graph_rag_response_queue = topic(
+ 'graph-rag', kind='non-persistent', namespace='response'
+)
+
############################################################################
# Document RAG text retrieval
@@ -34,3 +41,9 @@ class DocumentRagResponse(Record):
error = Error()
response = String()
+document_rag_request_queue = topic(
+ 'doc-rag', kind='non-persistent', namespace='request'
+)
+document_rag_response_queue = topic(
+ 'doc-rag', kind='non-persistent', namespace='response'
+)
diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
index 572e01b7..9b8818a2 100755
--- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
+++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
@@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-cli/scripts/tg-delete-flow-class b/trustgraph-cli/scripts/tg-delete-flow-class
deleted file mode 100755
index 345fe00f..00000000
--- a/trustgraph-cli/scripts/tg-delete-flow-class
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def delete_flow_class(url, class_name):
-
- api = Api(url)
-
- class_names = api.flow_delete_class(class_name)
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-delete-flow-class',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- parser.add_argument(
- '-n', '--class-name',
- help=f'Flow class name',
- )
-
- args = parser.parse_args()
-
- try:
-
- delete_flow_class(
- url=args.api_url,
- class_name=args.class_name,
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/scripts/tg-get-flow-class b/trustgraph-cli/scripts/tg-get-flow-class
deleted file mode 100755
index 450f1df7..00000000
--- a/trustgraph-cli/scripts/tg-get-flow-class
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-Dumps out the current configuration
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def get_flow_class(url, class_name):
-
- api = Api(url)
-
- cls = api.flow_get_class(class_name)
-
- print(json.dumps(cls, indent=4))
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-get-flow-class',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- parser.add_argument(
- '-n', '--class-name',
- required=True,
- help=f'Flow class name',
- )
-
- args = parser.parse_args()
-
- try:
-
- get_flow_class(
- url=args.api_url,
- class_name=args.class_name,
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/scripts/tg-put-flow-class b/trustgraph-cli/scripts/tg-put-flow-class
deleted file mode 100755
index ca048e1f..00000000
--- a/trustgraph-cli/scripts/tg-put-flow-class
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-Dumps out the current configuration
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def put_flow_class(url, class_name, config):
-
- api = Api(url)
-
- class_names = api.flow_put_class(class_name, config)
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-put-flow-class',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- parser.add_argument(
- '-n', '--class-name',
- help=f'Flow class name',
- )
-
- parser.add_argument(
- '-c', '--config',
- help=f'Initial configuration to load',
- )
-
- args = parser.parse_args()
-
- try:
-
- put_flow_class(
- url=args.api_url,
- class_name=args.class_name,
- config=json.loads(args.config),
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/scripts/tg-show-flow-classes b/trustgraph-cli/scripts/tg-show-flow-classes
deleted file mode 100755
index a3671184..00000000
--- a/trustgraph-cli/scripts/tg-show-flow-classes
+++ /dev/null
@@ -1,68 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def show_flow_classes(url):
-
- api = Api(url)
-
- class_names = api.flow_list_classes()
-
- if len(class_names) == 0:
- print("No flows.")
- return
-
- classes = []
-
- for class_name in class_names:
- cls = api.flow_get_class(class_name)
- classes.append((
- class_name,
- cls.get("description", ""),
- ", ".join(cls.get("tags", [])),
- ))
-
- print(tabulate.tabulate(
- classes,
- tablefmt="pretty",
- maxcolwidths=[None, 40, 20],
- stralign="left",
- headers = ["flow class", "description", "tags"],
- ))
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-show-flow-classes',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- args = parser.parse_args()
-
- try:
-
- show_flow_classes(
- url=args.api_url,
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/scripts/tg-show-flows b/trustgraph-cli/scripts/tg-show-flows
deleted file mode 100755
index 2a090013..00000000
--- a/trustgraph-cli/scripts/tg-show-flows
+++ /dev/null
@@ -1,111 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api, ConfigKey
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def get_interface(api, i):
-
- key = ConfigKey("interface-descriptions", i)
-
- value = api.config_get([key])[0].value
-
- return json.loads(value)
-
-def describe_interfaces(intdefs, flow):
-
- intfs = flow.get("interfaces", {})
-
- lst = []
-
- for k, v in intdefs.items():
-
- if intdefs[k].get("visible", False):
-
- label = intdefs[k].get("description", k)
- kind = intdefs[k].get("kind", None)
-
- if kind == "request-response":
- req = intfs[k]["request"]
- resp = intfs[k]["request"]
-
- lst.append(f"{k} request: {req}")
- lst.append(f"{k} response: {resp}")
-
- if kind == "send":
- q = intfs[k]
-
- lst.append(f"{k}: {q}")
-
- return "\n".join(lst)
-
-def show_flows(url):
-
- api = Api(url)
-
- interface_names = api.config_list("interface-descriptions")
-
- interface_defs = {
- i: get_interface(api, i)
- for i in interface_names
- }
-
- flow_ids = api.flow_list()
-
- if len(flow_ids) == 0:
- print("No flows.")
- return
-
- flows = []
-
- for id in flow_ids:
-
- flow = api.flow_get(id)
-
- table = []
- table.append(("id", id))
- table.append(("class", flow.get("class-name", "")))
- table.append(("desc", flow.get("description", "")))
- table.append(("queue", describe_interfaces(interface_defs, flow)))
-
- print(tabulate.tabulate(
- table,
- tablefmt="pretty",
- stralign="left",
- ))
- print()
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-show-flows',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- args = parser.parse_args()
-
- try:
-
- show_flows(
- url=args.api_url,
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/scripts/tg-start-flow b/trustgraph-cli/scripts/tg-start-flow
deleted file mode 100755
index 377b7963..00000000
--- a/trustgraph-cli/scripts/tg-start-flow
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def start_flow(url, class_name, flow_id, description):
-
- api = Api(url)
-
- api.flow_start(
- class_name = class_name,
- id = flow_id,
- description = description,
- )
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-start-flow',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- parser.add_argument(
- '-n', '--class-name',
- required=True,
- help=f'Flow class name',
- )
-
- parser.add_argument(
- '-i', '--flow-id',
- required=True,
- help=f'Flow ID',
- )
-
- parser.add_argument(
- '-d', '--description',
- required=True,
- help=f'Flow description',
- )
-
- args = parser.parse_args()
-
- try:
-
- start_flow(
- url = args.api_url,
- class_name = args.class_name,
- flow_id = args.flow_id,
- description = args.description,
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/scripts/tg-stop-flow b/trustgraph-cli/scripts/tg-stop-flow
deleted file mode 100755
index cdbaf6ee..00000000
--- a/trustgraph-cli/scripts/tg-stop-flow
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-"""
-
-import argparse
-import os
-import tabulate
-from trustgraph.api import Api
-import json
-
-default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
-
-def stop_flow(url, flow_id):
-
- api = Api(url)
-
- api.flow_stop(id = flow_id)
-
-def main():
-
- parser = argparse.ArgumentParser(
- prog='tg-stop-flow',
- description=__doc__,
- )
-
- parser.add_argument(
- '-u', '--api-url',
- default=default_url,
- help=f'API URL (default: {default_url})',
- )
-
- parser.add_argument(
- '-i', '--flow-id',
- required=True,
- help=f'Flow ID',
- )
-
- args = parser.parse_args()
-
- try:
-
- stop_flow(
- url=args.api_url,
- flow_id=args.flow_id,
- )
-
- except Exception as e:
-
- print("Exception:", e, flush=True)
-
-main()
-
diff --git a/trustgraph-cli/setup.py b/trustgraph-cli/setup.py
index a16bd732..dda401ec 100644
--- a/trustgraph-cli/setup.py
+++ b/trustgraph-cli/setup.py
@@ -63,13 +63,6 @@ setuptools.setup(
"scripts/tg-save-kg-core",
"scripts/tg-save-doc-embeds",
"scripts/tg-show-config",
- "scripts/tg-show-flows",
- "scripts/tg-show-flow-classes",
- "scripts/tg-get-flow-class",
- "scripts/tg-start-flow",
- "scripts/tg-stop-flow",
- "scripts/tg-delete-flow-class",
- "scripts/tg-put-flow-class",
"scripts/tg-set-prompt",
"scripts/tg-show-tools",
"scripts/tg-show-prompts",
diff --git a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py
index 0ab3cef9..2e44821e 100755
--- a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py
+++ b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py
@@ -4,37 +4,89 @@ Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector.
"""
-from ... base import EmbeddingsService
-
from langchain_huggingface import HuggingFaceEmbeddings
-default_ident = "embeddings"
+from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
+from trustgraph.schema import embeddings_request_queue
+from trustgraph.schema import embeddings_response_queue
+from trustgraph.log_level import LogLevel
+from trustgraph.base import ConsumerProducer
+module = ".".join(__name__.split(".")[1:-1])
+
+default_input_queue = embeddings_request_queue
+default_output_queue = embeddings_response_queue
+default_subscriber = module
default_model="all-MiniLM-L6-v2"
-class Processor(EmbeddingsService):
+class Processor(ConsumerProducer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
super(Processor, self).__init__(
- **params | { "model": model }
+ **params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": EmbeddingsRequest,
+ "output_schema": EmbeddingsResponse,
+ }
)
- print("Get model...", flush=True)
self.embeddings = HuggingFaceEmbeddings(model_name=model)
- async def on_embeddings(self, text):
+ async def handle(self, msg):
- embeds = self.embeddings.embed_documents([text])
- print("Done.", flush=True)
- return embeds
+ v = msg.value()
+
+ # Sender-produced ID
+ id = msg.properties()["id"]
+
+ print(f"Handling input {id}...", flush=True)
+
+ try:
+
+ text = v.text
+ embeds = self.embeddings.embed_documents([text])
+
+ print("Send response...", flush=True)
+ r = EmbeddingsResponse(vectors=embeds, error=None)
+ await self.send(r, properties={"id": id})
+
+ print("Done.", flush=True)
+
+
+ except Exception as e:
+
+ print(f"Exception: {e}")
+
+ print("Send error response...", flush=True)
+
+ r = EmbeddingsResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
+ ),
+                vectors=None,
+ )
+
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
+
@staticmethod
def add_args(parser):
- EmbeddingsService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-m', '--model',
@@ -44,5 +96,5 @@ class Processor(EmbeddingsService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py
index d20b86f7..a195bd80 100644
--- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py
+++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py
@@ -8,11 +8,12 @@ logger = logging.getLogger(__name__)
class AgentManager:
- def __init__(self, tools, additional_context=None):
+ def __init__(self, context, tools, additional_context=None):
+ self.context = context
self.tools = tools
self.additional_context = additional_context
- async def reason(self, question, history, context):
+ def reason(self, question, history):
tools = self.tools
@@ -55,7 +56,10 @@ class AgentManager:
logger.info(f"prompt: {variables}")
- obj = await context("prompt-request").agent_react(variables)
+ obj = self.context.prompt.request(
+ "agent-react",
+ variables
+ )
print(json.dumps(obj, indent=4), flush=True)
@@ -81,13 +85,9 @@ class AgentManager:
return a
- async def react(self, question, history, think, observe, context):
+ async def react(self, question, history, think, observe):
- act = await self.reason(
- question = question,
- history = history,
- context = context,
- )
+ act = self.reason(question, history)
logger.info(f"act: {act}")
if isinstance(act, Final):
@@ -104,12 +104,7 @@ class AgentManager:
else:
raise RuntimeError(f"No action for {act.name}!")
- print("TOOL>>>", act)
- resp = await action.implementation(context).invoke(
- **act.arguments
- )
-
- print("RSETUL", resp)
+ resp = action.implementation.invoke(**act.arguments)
resp = resp.strip()
diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py
index beb17fd4..224efe3c 100755
--- a/trustgraph-flow/trustgraph/agent/react/service.py
+++ b/trustgraph-flow/trustgraph/agent/react/service.py
@@ -6,68 +6,103 @@ import json
import re
import sys
-from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
-from ... base import GraphRagClientSpec
+from pulsar.schema import JsonSchema
-from ... schema import AgentRequest, AgentResponse, AgentStep, Error
+from ... base import ConsumerProducer
+from ... schema import Error
+from ... schema import AgentRequest, AgentResponse, AgentStep
+from ... schema import agent_request_queue, agent_response_queue
+from ... schema import prompt_request_queue as pr_request_queue
+from ... schema import prompt_response_queue as pr_response_queue
+from ... schema import graph_rag_request_queue as gr_request_queue
+from ... schema import graph_rag_response_queue as gr_response_queue
+from ... clients.prompt_client import PromptClient
+from ... clients.llm_client import LlmClient
+from ... clients.graph_rag_client import GraphRagClient
from . tools import KnowledgeQueryImpl, TextCompletionImpl
from . agent_manager import AgentManager
from . types import Final, Action, Tool, Argument
-default_ident = "agent-manager"
-default_max_iterations = 10
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(AgentService):
+default_input_queue = agent_request_queue
+default_output_queue = agent_response_queue
+default_subscriber = module
+default_max_iterations = 15
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
-
self.max_iterations = int(
params.get("max_iterations", default_max_iterations)
)
+ tools = {}
+
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+ prompt_request_queue = params.get(
+ "prompt_request_queue", pr_request_queue
+ )
+ prompt_response_queue = params.get(
+ "prompt_response_queue", pr_response_queue
+ )
+ graph_rag_request_queue = params.get(
+ "graph_rag_request_queue", gr_request_queue
+ )
+ graph_rag_response_queue = params.get(
+ "graph_rag_response_queue", gr_response_queue
+ )
+
self.config_key = params.get("config_type", "agent")
super(Processor, self).__init__(
**params | {
- "id": id,
- "max_iterations": self.max_iterations,
- "config_type": self.config_key,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": AgentRequest,
+ "output_schema": AgentResponse,
+ "prompt_request_queue": prompt_request_queue,
+ "prompt_response_queue": prompt_response_queue,
+ "graph_rag_request_queue": gr_request_queue,
+ "graph_rag_response_queue": gr_response_queue,
}
)
+ self.prompt = PromptClient(
+ subscriber=subscriber,
+ input_queue=prompt_request_queue,
+ output_queue=prompt_response_queue,
+ pulsar_host = self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ )
+
+ self.graph_rag = GraphRagClient(
+ subscriber=subscriber,
+ input_queue=graph_rag_request_queue,
+ output_queue=graph_rag_response_queue,
+ pulsar_host = self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ )
+
+ # Need to be able to feed requests to myself
+ self.recursive_input = self.client.create_producer(
+ topic=input_queue,
+ schema=JsonSchema(AgentRequest),
+ )
+
self.agent = AgentManager(
+ context=self,
tools=[],
additional_context="",
)
- self.config_handlers.append(self.on_tools_config)
-
- self.register_specification(
- TextCompletionClientSpec(
- request_name = "text-completion-request",
- response_name = "text-completion-response",
- )
- )
-
- self.register_specification(
- GraphRagClientSpec(
- request_name = "graph-rag-request",
- response_name = "graph-rag-response",
- )
- )
-
- self.register_specification(
- PromptClientSpec(
- request_name = "prompt-request",
- response_name = "prompt-response",
- )
- )
-
- async def on_tools_config(self, config, version):
+ async def on_config(self, version, config):
print("Loading configuration version", version)
@@ -103,9 +138,9 @@ class Processor(AgentService):
impl_id = data.get("type")
if impl_id == "knowledge-query":
- impl = KnowledgeQueryImpl
+ impl = KnowledgeQueryImpl(self)
elif impl_id == "text-completion":
- impl = TextCompletionImpl
+ impl = TextCompletionImpl(self)
else:
raise RuntimeError(
f"Tool-kind {impl_id} not known"
@@ -120,6 +155,7 @@ class Processor(AgentService):
)
self.agent = AgentManager(
+ context=self,
tools=tools,
additional_context=additional
)
@@ -128,14 +164,19 @@ class Processor(AgentService):
except Exception as e:
- print("on_tools_config Exception:", e, flush=True)
+ print("Exception:", e, flush=True)
print("Configuration reload failed", flush=True)
- async def agent_request(self, request, respond, next, flow):
+ async def handle(self, msg):
try:
- if request.history:
+ v = msg.value()
+
+ # Sender-produced ID
+ id = msg.properties()["id"]
+
+ if v.history:
history = [
Action(
thought=h.thought,
@@ -143,12 +184,12 @@ class Processor(AgentService):
arguments=h.arguments,
observation=h.observation
)
- for h in request.history
+ for h in v.history
]
else:
history = []
- print(f"Question: {request.question}", flush=True)
+ print(f"Question: {v.question}", flush=True)
if len(history) >= self.max_iterations:
raise RuntimeError("Too many agent iterations")
@@ -166,7 +207,7 @@ class Processor(AgentService):
observation=None,
)
- await respond(r)
+ await self.send(r, properties={"id": id})
async def observe(x):
@@ -179,21 +220,15 @@ class Processor(AgentService):
observation=x,
)
- await respond(r)
+ await self.send(r, properties={"id": id})
- act = await self.agent.react(
- question = request.question,
- history = history,
- think = think,
- observe = observe,
- context = flow,
- )
+ act = await self.agent.react(v.question, history, think, observe)
print(f"Action: {act}", flush=True)
- if isinstance(act, Final):
+ print("Send response...", flush=True)
- print("Send final response...", flush=True)
+        if isinstance(act, Final):
r = AgentResponse(
answer=act.final,
@@ -201,20 +236,18 @@ class Processor(AgentService):
thought=None,
)
- await respond(r)
+ await self.send(r, properties={"id": id})
print("Done.", flush=True)
return
- print("Send next...", flush=True)
-
history.append(act)
r = AgentRequest(
- question=request.question,
- plan=request.plan,
- state=request.state,
+ question=v.question,
+ plan=v.plan,
+ state=v.state,
history=[
AgentStep(
thought=h.thought,
@@ -226,7 +259,7 @@ class Processor(AgentService):
]
)
- await next(r)
+ self.recursive_input.send(r, properties={"id": id})
print("Done.", flush=True)
@@ -234,7 +267,7 @@ class Processor(AgentService):
except Exception as e:
- print(f"agent_request Exception: {e}")
+ print(f"Exception: {e}")
print("Send error response...", flush=True)
@@ -246,12 +279,39 @@ class Processor(AgentService):
response=None,
)
- await respond(r)
+ await self.send(r, properties={"id": id})
@staticmethod
def add_args(parser):
- AgentService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
+
+ parser.add_argument(
+ '--prompt-request-queue',
+ default=pr_request_queue,
+ help=f'Prompt request queue (default: {pr_request_queue})',
+ )
+
+ parser.add_argument(
+ '--prompt-response-queue',
+ default=pr_response_queue,
+ help=f'Prompt response queue (default: {pr_response_queue})',
+ )
+
+ parser.add_argument(
+ '--graph-rag-request-queue',
+ default=gr_request_queue,
+ help=f'Graph RAG request queue (default: {gr_request_queue})',
+ )
+
+ parser.add_argument(
+ '--graph-rag-response-queue',
+ default=gr_response_queue,
+ help=f'Graph RAG response queue (default: {gr_response_queue})',
+ )
parser.add_argument(
'--max-iterations',
@@ -267,5 +327,5 @@ class Processor(AgentService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py
index 31568b25..023abc02 100644
--- a/trustgraph-flow/trustgraph/agent/react/tools.py
+++ b/trustgraph-flow/trustgraph/agent/react/tools.py
@@ -4,22 +4,16 @@
class KnowledgeQueryImpl:
def __init__(self, context):
self.context = context
- async def invoke(self, **arguments):
- client = self.context("graph-rag-request")
- print("Graph RAG question...", flush=True)
- return await client.rag(
- arguments.get("question")
- )
+ def invoke(self, **arguments):
+ return self.context.graph_rag.request(arguments.get("question"))
# This tool implementation knows how to do text completion. This uses
# the prompt service, rather than talking to TextCompletion directly.
class TextCompletionImpl:
def __init__(self, context):
self.context = context
- async def invoke(self, **arguments):
- client = self.context("prompt-request")
- print("Prompt question...", flush=True)
- return await client.question(
- arguments.get("question")
+ def invoke(self, **arguments):
+ return self.context.prompt.request(
+ "question", { "question": arguments.get("question") }
)
diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py
index aa48cc57..82f333b5 100755
--- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py
+++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py
@@ -7,27 +7,40 @@ as text as separate output objects.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
-from ... schema import TextDocument, Chunk
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
+from ... schema import TextDocument, Chunk, Metadata
+from ... schema import text_ingest_queue, chunk_ingest_queue
+from ... log_level import LogLevel
+from ... base import ConsumerProducer
-default_ident = "chunker"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = text_ingest_queue
+default_output_queue = chunk_ingest_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id", default_ident)
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
- **params | { "id": id }
+ **params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": TextDocument,
+ "output_schema": Chunk,
+ }
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
- ["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
@@ -39,24 +52,7 @@ class Processor(FlowProcessor):
is_separator_regex=False,
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = TextDocument,
- handler = self.on_message,
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "output",
- schema = Chunk,
- )
- )
-
- print("Chunker initialised", flush=True)
-
- async def on_message(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True)
@@ -67,25 +63,24 @@ class Processor(FlowProcessor):
for ix, chunk in enumerate(texts):
- print("Chunk", len(chunk.page_content), flush=True)
-
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
)
- __class__.chunk_metric.labels(
- id=consumer.id, flow=consumer.flow
- ).observe(len(chunk.page_content))
+ __class__.chunk_metric.observe(len(chunk.page_content))
- await flow("output").send(r)
+ await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-z', '--chunk-size',
@@ -103,5 +98,5 @@ class Processor(FlowProcessor):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py
index ff217350..c625b48c 100755
--- a/trustgraph-flow/trustgraph/chunking/token/chunker.py
+++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py
@@ -7,27 +7,40 @@ as text as separate output objects.
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
-from ... schema import TextDocument, Chunk
-from ... base import FlowProcessor
+from ... schema import TextDocument, Chunk, Metadata
+from ... schema import text_ingest_queue, chunk_ingest_queue
+from ... log_level import LogLevel
+from ... base import ConsumerProducer
-default_ident = "chunker"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = text_ingest_queue
+default_output_queue = chunk_ingest_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__(
- **params | { "id": id }
+ **params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": TextDocument,
+ "output_schema": Chunk,
+ }
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
- ["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
@@ -38,24 +51,7 @@ class Processor(FlowProcessor):
chunk_overlap=chunk_overlap,
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = TextDocument,
- handler = self.on_message,
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "output",
- schema = Chunk,
- )
- )
-
- print("Chunker initialised", flush=True)
-
- async def on_message(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True)
@@ -66,25 +62,24 @@ class Processor(FlowProcessor):
for ix, chunk in enumerate(texts):
- print("Chunk", len(chunk.page_content), flush=True)
-
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
)
- __class__.chunk_metric.labels(
- id=consumer.id, flow=consumer.flow
- ).observe(len(chunk.page_content))
+ __class__.chunk_metric.observe(len(chunk.page_content))
- await flow("output").send(r)
+ await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-z', '--chunk-size',
@@ -102,5 +97,5 @@ class Processor(FlowProcessor):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py
deleted file mode 100644
index 46ade4c3..00000000
--- a/trustgraph-flow/trustgraph/config/service/config.py
+++ /dev/null
@@ -1,215 +0,0 @@
-
-from trustgraph.schema import ConfigResponse
-from trustgraph.schema import ConfigValue, Error
-
-# This behaves just like a dict, should be easier to add persistent storage
-# later
-class ConfigurationItems(dict):
- pass
-
-class Configuration(dict):
-
- # FIXME: The state is held internally. This only works if there's
- # one config service. Should be more than one, and use a
- # back-end state store.
-
- def __init__(self, push):
-
- # Version counter
- self.version = 0
-
- # External function to respond to update
- self.push = push
-
- def __getitem__(self, key):
- if key not in self:
- self[key] = ConfigurationItems()
- return dict.__getitem__(self, key)
-
- async def handle_get(self, v):
-
- for k in v.keys:
- if k.type not in self or k.key not in self[k.type]:
- return ConfigResponse(
- version = None,
- values = None,
- directory = None,
- config = None,
- error = Error(
- type = "key-error",
- message = f"Key error"
- )
- )
-
- values = [
- ConfigValue(
- type = k.type,
- key = k.key,
- value = self[k.type][k.key]
- )
- for k in v.keys
- ]
-
- return ConfigResponse(
- version = self.version,
- values = values,
- directory = None,
- config = None,
- error = None,
- )
-
- async def handle_list(self, v):
-
- if v.type not in self:
-
- return ConfigResponse(
- version = None,
- values = None,
- directory = None,
- config = None,
- error = Error(
- type = "key-error",
- message = "No such type",
- ),
- )
-
- return ConfigResponse(
- version = self.version,
- values = None,
- directory = list(self[v.type].keys()),
- config = None,
- error = None,
- )
-
- async def handle_getvalues(self, v):
-
- if v.type not in self:
-
- return ConfigResponse(
- version = None,
- values = None,
- directory = None,
- config = None,
- error = Error(
- type = "key-error",
- message = f"Key error"
- )
- )
-
- values = [
- ConfigValue(
- type = v.type,
- key = k,
- value = self[v.type][k],
- )
- for k in self[v.type]
- ]
-
- return ConfigResponse(
- version = self.version,
- values = values,
- directory = None,
- config = None,
- error = None,
- )
-
- async def handle_delete(self, v):
-
- for k in v.keys:
- if k.type not in self or k.key not in self[k.type]:
- return ConfigResponse(
- version = None,
- values = None,
- directory = None,
- config = None,
- error = Error(
- type = "key-error",
- message = f"Key error"
- )
- )
-
- for k in v.keys:
- del self[k.type][k.key]
-
- self.version += 1
-
- await self.push()
-
- return ConfigResponse(
- version = None,
- value = None,
- directory = None,
- values = None,
- config = None,
- error = None,
- )
-
- async def handle_put(self, v):
-
- for k in v.values:
- self[k.type][k.key] = k.value
-
- self.version += 1
-
- await self.push()
-
- return ConfigResponse(
- version = None,
- value = None,
- directory = None,
- values = None,
- error = None,
- )
-
- async def handle_config(self, v):
-
- return ConfigResponse(
- version = self.version,
- value = None,
- directory = None,
- values = None,
- config = self,
- error = None,
- )
-
- async def handle(self, msg):
-
- print("Handle message ", msg.operation)
-
- if msg.operation == "get":
-
- resp = await self.handle_get(msg)
-
- elif msg.operation == "list":
-
- resp = await self.handle_list(msg)
-
- elif msg.operation == "getvalues":
-
- resp = await self.handle_getvalues(msg)
-
- elif msg.operation == "delete":
-
- resp = await self.handle_delete(msg)
-
- elif msg.operation == "put":
-
- resp = await self.handle_put(msg)
-
- elif msg.operation == "config":
-
- resp = await self.handle_config(msg)
-
- else:
-
- resp = ConfigResponse(
- value=None,
- directory=None,
- values=None,
- error=Error(
- type = "bad-operation",
- message = "Bad operation"
- )
- )
-
- return resp
diff --git a/trustgraph-flow/trustgraph/config/service/flow.py b/trustgraph-flow/trustgraph/config/service/flow.py
deleted file mode 100644
index 3933e4aa..00000000
--- a/trustgraph-flow/trustgraph/config/service/flow.py
+++ /dev/null
@@ -1,228 +0,0 @@
-
-from trustgraph.schema import FlowResponse, Error
-import json
-
-class FlowConfig:
- def __init__(self, config):
-
- self.config = config
-
- async def handle_list_classes(self, msg):
-
- names = list(self.config["flow-classes"].keys())
-
- return FlowResponse(
- error = None,
- class_names = names,
- )
-
- async def handle_get_class(self, msg):
-
- return FlowResponse(
- error = None,
- class_definition = self.config["flow-classes"][msg.class_name],
- )
-
- async def handle_put_class(self, msg):
-
- self.config["flow-classes"][msg.class_name] = msg.class_definition
-
- await self.config.push()
-
- return FlowResponse(
- error = None,
- )
-
- async def handle_delete_class(self, msg):
-
- print(msg)
-
- del self.config["flow-classes"][msg.class_name]
-
- await self.config.push()
-
- return FlowResponse(
- error = None,
- )
-
- async def handle_list_flows(self, msg):
-
- names = list(self.config["flows"].keys())
-
- return FlowResponse(
- error = None,
- flow_ids = names,
- )
-
- async def handle_get_flow(self, msg):
-
- flow = self.config["flows"][msg.flow_id]
-
- return FlowResponse(
- error = None,
- flow = flow,
- )
-
- async def handle_start_flow(self, msg):
-
- if msg.class_name is None:
- raise RuntimeError("No class name")
-
- if msg.flow_id is None:
- raise RuntimeError("No flow ID")
-
- if msg.flow_id in self.config["flows"]:
- raise RuntimeError("Flow already exists")
-
- if msg.description is None:
- raise RuntimeError("No description")
-
- if msg.class_name not in self.config["flow-classes"]:
- raise RuntimeError("Class does not exist")
-
- def repl_template(tmp):
- return tmp.replace(
- "{class}", msg.class_name
- ).replace(
- "{id}", msg.flow_id
- )
-
- cls = json.loads(self.config["flow-classes"][msg.class_name])
-
- for kind in ("class", "flow"):
-
- for k, v in cls[kind].items():
-
- processor, variant = k.split(":", 1)
-
- variant = repl_template(variant)
-
- v = {
- repl_template(k2): repl_template(v2)
- for k2, v2 in v.items()
- }
-
- if processor in self.config["flows-active"]:
- target = json.loads(self.config["flows-active"][processor])
- else:
- target = {}
-
- if variant not in target:
- target[variant] = v
-
- self.config["flows-active"][processor] = json.dumps(target)
-
- def repl_interface(i):
- if isinstance(i, str):
- return repl_template(i)
- else:
- return {
- k: repl_template(v)
- for k, v in i.items()
- }
-
- if "interfaces" in cls:
- interfaces = {
- k: repl_interface(v)
- for k, v in cls["interfaces"].items()
- }
- else:
- interfaces = {}
-
- self.config["flows"][msg.flow_id] = json.dumps({
- "description": msg.description,
- "class-name": msg.class_name,
- "interfaces": interfaces,
- })
-
- await self.config.push()
-
- return FlowResponse(
- error = None,
- )
-
- async def handle_stop_flow(self, msg):
-
- if msg.flow_id is None:
- raise RuntimeError("No flow ID")
-
- if msg.flow_id not in self.config["flows"]:
- raise RuntimeError("Flow ID invalid")
-
- flow = json.loads(self.config["flows"][msg.flow_id])
-
- if "class-name" not in flow:
- raise RuntimeError("Internal error: flow has no flow class")
-
- class_name = flow["class-name"]
-
- cls = json.loads(self.config["flow-classes"][class_name])
-
- def repl_template(tmp):
- return tmp.replace(
- "{class}", class_name
- ).replace(
- "{id}", msg.flow_id
- )
-
- for kind in ("flow",):
-
- for k, v in cls[kind].items():
-
- processor, variant = k.split(":", 1)
-
- variant = repl_template(variant)
-
- if processor in self.config["flows-active"]:
- target = json.loads(self.config["flows-active"][processor])
- else:
- target = {}
-
- if variant in target:
- del target[variant]
-
- self.config["flows-active"][processor] = json.dumps(target)
-
- if msg.flow_id in self.config["flows"]:
- del self.config["flows"][msg.flow_id]
-
- await self.config.push()
-
- return FlowResponse(
- error = None,
- )
-
- async def handle(self, msg):
-
- print("Handle message ", msg.operation)
-
- if msg.operation == "list-classes":
- resp = await self.handle_list_classes(msg)
- elif msg.operation == "get-class":
- resp = await self.handle_get_class(msg)
- elif msg.operation == "put-class":
- resp = await self.handle_put_class(msg)
- elif msg.operation == "delete-class":
- resp = await self.handle_delete_class(msg)
- elif msg.operation == "list-flows":
- resp = await self.handle_list_flows(msg)
- elif msg.operation == "get-flow":
- resp = await self.handle_get_flow(msg)
- elif msg.operation == "start-flow":
- resp = await self.handle_start_flow(msg)
- elif msg.operation == "stop-flow":
- resp = await self.handle_stop_flow(msg)
- else:
-
- resp = FlowResponse(
- value=None,
- directory=None,
- values=None,
- error=Error(
- type = "bad-operation",
- message = "Bad operation"
- )
- )
-
- return resp
-
diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py
index c0268389..ee0c960e 100644
--- a/trustgraph-flow/trustgraph/config/service/service.py
+++ b/trustgraph-flow/trustgraph/config/service/service.py
@@ -1,148 +1,214 @@
"""
-Config service. Manages system global configuration state
+Config service. Fetchs an extract from the Wikipedia page
+using the API.
"""
from pulsar.schema import JsonSchema
-from trustgraph.schema import Error
-
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
+from trustgraph.schema import ConfigValue, Error
from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.schema import config_push_queue
-
-from trustgraph.schema import FlowRequest, FlowResponse
-from trustgraph.schema import flow_request_queue, flow_response_queue
-
from trustgraph.log_level import LogLevel
-from trustgraph.base import AsyncProcessor, Consumer, Producer
+from trustgraph.base import ConsumerProducer
-from . config import Configuration
-from . flow import FlowConfig
+module = ".".join(__name__.split(".")[1:-1])
-from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
-from ... base import Consumer, Producer
+default_input_queue = config_request_queue
+default_output_queue = config_response_queue
+default_push_queue = config_push_queue
+default_subscriber = module
-default_ident = "config-svc"
+# This behaves just like a dict, should be easier to add persistent storage
+# later
-default_config_request_queue = config_request_queue
-default_config_response_queue = config_response_queue
-default_config_push_queue = config_push_queue
+class ConfigurationItems(dict):
+ pass
-default_flow_request_queue = flow_request_queue
-default_flow_response_queue = flow_response_queue
+class Configuration(dict):
-class Processor(AsyncProcessor):
+ def __getitem__(self, key):
+ if key not in self:
+ self[key] = ConfigurationItems()
+ return dict.__getitem__(self, key)
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
-
- config_request_queue = params.get(
- "config_request_queue", default_config_request_queue
- )
- config_response_queue = params.get(
- "config_response_queue", default_config_response_queue
- )
- config_push_queue = params.get(
- "config_push_queue", default_config_push_queue
- )
- flow_request_queue = params.get(
- "flow_request_queue", default_flow_request_queue
- )
- flow_response_queue = params.get(
- "flow_response_queue", default_flow_response_queue
- )
-
- id = params.get("id")
-
- flow_request_schema = FlowRequest
- flow_response_schema = FlowResponse
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ push_queue = params.get("push_queue", default_push_queue)
+ subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
- "config_request_schema": ConfigRequest.__name__,
- "config_response_schema": ConfigResponse.__name__,
- "config_push_schema": ConfigPush.__name__,
- "flow_request_schema": FlowRequest.__name__,
- "flow_response_schema": FlowResponse.__name__,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "push_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": ConfigRequest,
+ "output_schema": ConfigResponse,
+ "push_schema": ConfigPush,
}
)
- config_request_metrics = ConsumerMetrics(
- processor = self.id, flow = None, name = "config-request"
- )
- config_response_metrics = ProducerMetrics(
- processor = self.id, flow = None, name = "config-response"
- )
- config_push_metrics = ProducerMetrics(
- processor = self.id, flow = None, name = "config-push"
+ self.push_prod = self.client.create_producer(
+ topic=push_queue,
+ schema=JsonSchema(ConfigPush),
)
- flow_request_metrics = ConsumerMetrics(
- processor = self.id, flow = None, name = "flow-request"
- )
- flow_response_metrics = ProducerMetrics(
- processor = self.id, flow = None, name = "flow-response"
- )
+ # FIXME: The state is held internally. This only works if there's
+ # one config service. Should be more than one, and use a
+ # back-end state store.
+ self.config = Configuration()
- self.config_request_consumer = Consumer(
- taskgroup = self.taskgroup,
- client = self.pulsar_client,
- flow = None,
- topic = config_request_queue,
- subscriber = id,
- schema = ConfigRequest,
- handler = self.on_config_request,
- metrics = config_request_metrics,
- )
-
- self.config_response_producer = Producer(
- client = self.pulsar_client,
- topic = config_response_queue,
- schema = ConfigResponse,
- metrics = config_response_metrics,
- )
-
- self.config_push_producer = Producer(
- client = self.pulsar_client,
- topic = config_push_queue,
- schema = ConfigPush,
- metrics = config_push_metrics,
- )
-
- self.flow_request_consumer = Consumer(
- taskgroup = self.taskgroup,
- client = self.pulsar_client,
- flow = None,
- topic = flow_request_queue,
- subscriber = id,
- schema = FlowRequest,
- handler = self.on_flow_request,
- metrics = flow_request_metrics,
- )
-
- self.flow_response_producer = Producer(
- client = self.pulsar_client,
- topic = flow_response_queue,
- schema = FlowResponse,
- metrics = flow_response_metrics,
- )
-
- self.config = Configuration(self.push)
- self.flow = FlowConfig(self.config)
-
- print("Service initialised.")
+ # Version counter
+ self.version = 0
async def start(self):
+ await self.push()
+
+ async def handle_get(self, v, id):
+
+ for k in v.keys:
+ if k.type not in self.config or k.key not in self.config[k.type]:
+ return ConfigResponse(
+ version = None,
+ values = None,
+ directory = None,
+ config = None,
+ error = Error(
+ type = "key-error",
+ message = f"Key error"
+ )
+ )
+
+ values = [
+ ConfigValue(
+ type = k.type,
+ key = k.key,
+ value = self.config[k.type][k.key]
+ )
+ for k in v.keys
+ ]
+
+ return ConfigResponse(
+ version = self.version,
+ values = values,
+ directory = None,
+ config = None,
+ error = None,
+ )
+
+ async def handle_list(self, v, id):
+
+ if v.type not in self.config:
+
+ return ConfigResponse(
+ version = None,
+ values = None,
+ directory = None,
+ config = None,
+ error = Error(
+ type = "key-error",
+ message = "No such type",
+ ),
+ )
+
+ return ConfigResponse(
+ version = self.version,
+ values = None,
+ directory = list(self.config[v.type].keys()),
+ config = None,
+ error = None,
+ )
+
+ async def handle_getvalues(self, v, id):
+
+ if v.type not in self.config:
+
+ return ConfigResponse(
+ version = None,
+ values = None,
+ directory = None,
+ config = None,
+ error = Error(
+ type = "key-error",
+ message = f"Key error"
+ )
+ )
+
+ values = [
+ ConfigValue(
+ type = v.type,
+ key = k,
+ value = self.config[v.type][k],
+ )
+ for k in self.config[v.type]
+ ]
+
+ return ConfigResponse(
+ version = self.version,
+ values = values,
+ directory = None,
+ config = None,
+ error = None,
+ )
+
+ async def handle_delete(self, v, id):
+
+ for k in v.keys:
+ if k.type not in self.config or k.key not in self.config[k.type]:
+ return ConfigResponse(
+ version = None,
+ values = None,
+ directory = None,
+ config = None,
+ error = Error(
+ type = "key-error",
+ message = f"Key error"
+ )
+ )
+
+ for k in v.keys:
+ del self.config[k.type][k.key]
+
+ self.version += 1
await self.push()
- await self.config_request_consumer.start()
- await self.flow_request_consumer.start()
-
- async def push(self):
- resp = ConfigPush(
- version = self.config.version,
+ return ConfigResponse(
+ version = None,
+ value = None,
+ directory = None,
+ values = None,
+ config = None,
+ error = None,
+ )
+
+ async def handle_put(self, v, id):
+
+ for k in v.values:
+ self.config[k.type][k.key] = k.value
+
+ self.version += 1
+
+ await self.push()
+
+ return ConfigResponse(
+ version = None,
+ value = None,
+ directory = None,
+ values = None,
+ error = None,
+ )
+
+ async def handle_config(self, v, id):
+
+ return ConfigResponse(
+ version = self.version,
value = None,
directory = None,
values = None,
@@ -150,108 +216,97 @@ class Processor(AsyncProcessor):
error = None,
)
- await self.config_push_producer.send(resp)
+ async def push(self):
- print("Pushed version ", self.config.version)
+ resp = ConfigPush(
+ version = self.version,
+ value = None,
+ directory = None,
+ values = None,
+ config = self.config,
+ error = None,
+ )
+ self.push_prod.send(resp)
+ print("Pushed.")
- async def on_config_request(self, msg, consumer, flow):
+ async def handle(self, msg):
+
+ v = msg.value()
+
+ # Sender-produced ID
+ id = msg.properties()["id"]
+
+ print(f"Handling {id}...", flush=True)
try:
- v = msg.value()
+ if v.operation == "get":
- # Sender-produced ID
- id = msg.properties()["id"]
+ resp = await self.handle_get(v, id)
- print(f"Handling {id}...", flush=True)
+ elif v.operation == "list":
- resp = await self.config.handle(v)
+ resp = await self.handle_list(v, id)
- await self.config_response_producer.send(
- resp, properties={"id": id}
- )
+ elif v.operation == "getvalues":
+
+ resp = await self.handle_getvalues(v, id)
+
+ elif v.operation == "delete":
+
+ resp = await self.handle_delete(v, id)
+
+ elif v.operation == "put":
+
+ resp = await self.handle_put(v, id)
+
+ elif v.operation == "config":
+
+ resp = await self.handle_config(v, id)
+
+ else:
+
+ resp = ConfigResponse(
+ value=None,
+ directory=None,
+ values=None,
+ error=Error(
+ type = "bad-operation",
+ message = "Bad operation"
+ )
+ )
+
+ await self.send(resp, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
except Exception as e:
-
+
resp = ConfigResponse(
error=Error(
- type = "config-error",
+ type = "unexpected-error",
message = str(e),
),
text=None,
)
-
- await self.config_response_producer.send(
- resp, properties={"id": id}
- )
-
- async def on_flow_request(self, msg, consumer, flow):
-
- try:
-
- v = msg.value()
-
- # Sender-produced ID
- id = msg.properties()["id"]
-
- print(f"Handling {id}...", flush=True)
-
- resp = await self.flow.handle(v)
-
- await self.flow_response_producer.send(
- resp, properties={"id": id}
- )
-
- except Exception as e:
-
- resp = FlowResponse(
- error=Error(
- type = "flow-error",
- message = str(e),
- ),
- text=None,
- )
-
- await self.flow_response_producer.send(
- resp, properties={"id": id}
- )
+ await self.send(resp, properties={"id": id})
+ self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
- AsyncProcessor.add_args(parser)
-
- parser.add_argument(
- '--config-request-queue',
- default=default_config_request_queue,
- help=f'Config request queue (default: {default_config_request_queue})'
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
)
parser.add_argument(
- '--config-response-queue',
- default=default_config_response_queue,
- help=f'Config response queue {default_config_response_queue}',
- )
-
- parser.add_argument(
- '--push-queue',
- default=default_config_push_queue,
- help=f'Config push queue (default: {default_config_push_queue})'
- )
-
- parser.add_argument(
- '--flow-request-queue',
- default=default_flow_request_queue,
- help=f'Flow request queue (default: {default_flow_request_queue})'
- )
-
- parser.add_argument(
- '--flow-response-queue',
- default=default_flow_response_queue,
- help=f'Flow response queue {default_flow_response_queue}',
+ '-q', '--push-queue',
+ default=default_push_queue,
+ help=f'Config push queue (default: {default_push_queue})'
)
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
index e42d1601..f5100244 100755
--- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
+++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
@@ -17,10 +17,12 @@ from mistralai.models import OCRResponse
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
-from ... base import InputOutputProcessor
+from ... base import ConsumerProducer
-module = "ocr"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = document_ingest_queue
+default_output_queue = text_ingest_queue
default_subscriber = module
default_api_key = os.getenv("MISTRAL_TOKEN")
@@ -69,17 +71,19 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str:
return "\n\n".join(markdowns)
-class Processor(InputOutputProcessor):
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
api_key = params.get("api_key", default_api_key)
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
@@ -147,7 +151,7 @@ class Processor(InputOutputProcessor):
return markdown
- async def on_message(self, msg, consumer):
+ async def handle(self, msg):
print("PDF message received")
@@ -162,14 +166,17 @@ class Processor(InputOutputProcessor):
text=markdown.encode("utf-8"),
)
- await consumer.q.output.send(r)
+ await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
- InputOutputProcessor.add_args(parser, default_subscriber)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-k', '--api-key',
diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
index d0669a59..5e5e3612 100755
--- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
+++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
@@ -9,43 +9,39 @@ import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata
+from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
+from ... base import ConsumerProducer
-default_ident = "pdf-decoder"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = document_ingest_queue
+default_output_queue = text_ingest_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id", default_ident)
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": Document,
+ "output_schema": TextDocument,
}
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = Document,
- handler = self.on_message,
- )
- )
+ print("PDF inited")
- self.register_specification(
- ProducerSpec(
- name = "output",
- schema = TextDocument,
- )
- )
+ async def handle(self, msg):
- print("PDF inited", flush=True)
-
- async def on_message(self, msg, consumer, flow):
-
- print("PDF message received", flush=True)
+ print("PDF message received")
v = msg.value()
@@ -63,22 +59,24 @@ class Processor(FlowProcessor):
for ix, page in enumerate(pages):
- print("page", ix, flush=True)
-
r = TextDocument(
metadata=v.metadata,
text=page.page_content.encode("utf-8"),
)
- await flow("output").send(r)
+ await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/document_rag.py b/trustgraph-flow/trustgraph/document_rag.py
new file mode 100644
index 00000000..4fc4850a
--- /dev/null
+++ b/trustgraph-flow/trustgraph/document_rag.py
@@ -0,0 +1,153 @@
+
+from . clients.document_embeddings_client import DocumentEmbeddingsClient
+from . clients.triples_query_client import TriplesQueryClient
+from . clients.embeddings_client import EmbeddingsClient
+from . clients.prompt_client import PromptClient
+
+from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
+from . schema import TriplesQueryRequest, TriplesQueryResponse
+from . schema import prompt_request_queue
+from . schema import prompt_response_queue
+from . schema import embeddings_request_queue
+from . schema import embeddings_response_queue
+from . schema import document_embeddings_request_queue
+from . schema import document_embeddings_response_queue
+
+LABEL="http://www.w3.org/2000/01/rdf-schema#label"
+DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
+
+class Query:
+
+ def __init__(
+ self, rag, user, collection, verbose,
+ doc_limit=20
+ ):
+ self.rag = rag
+ self.user = user
+ self.collection = collection
+ self.verbose = verbose
+ self.doc_limit = doc_limit
+
+ def get_vector(self, query):
+
+ if self.verbose:
+ print("Compute embeddings...", flush=True)
+
+ qembeds = self.rag.embeddings.request(query)
+
+ if self.verbose:
+ print("Done.", flush=True)
+
+ return qembeds
+
+ def get_docs(self, query):
+
+ vectors = self.get_vector(query)
+
+ if self.verbose:
+ print("Get entities...", flush=True)
+
+ docs = self.rag.de_client.request(
+ vectors, limit=self.doc_limit
+ )
+
+ if self.verbose:
+ print("Docs:", flush=True)
+ for doc in docs:
+ print(doc, flush=True)
+
+ return docs
+
+class DocumentRag:
+
+ def __init__(
+ self,
+ pulsar_host="pulsar://pulsar:6650",
+ pulsar_api_key=None,
+ pr_request_queue=None,
+ pr_response_queue=None,
+ emb_request_queue=None,
+ emb_response_queue=None,
+ de_request_queue=None,
+ de_response_queue=None,
+ verbose=False,
+ module="test",
+ ):
+
+ self.verbose=verbose
+
+ if pr_request_queue is None:
+ pr_request_queue = prompt_request_queue
+
+ if pr_response_queue is None:
+ pr_response_queue = prompt_response_queue
+
+ if emb_request_queue is None:
+ emb_request_queue = embeddings_request_queue
+
+ if emb_response_queue is None:
+ emb_response_queue = embeddings_response_queue
+
+ if de_request_queue is None:
+ de_request_queue = document_embeddings_request_queue
+
+ if de_response_queue is None:
+ de_response_queue = document_embeddings_response_queue
+
+ if self.verbose:
+ print("Initialising...", flush=True)
+
+ self.de_client = DocumentEmbeddingsClient(
+ pulsar_host=pulsar_host,
+ subscriber=module + "-de",
+ input_queue=de_request_queue,
+ output_queue=de_response_queue,
+ pulsar_api_key=pulsar_api_key,
+ )
+
+ self.embeddings = EmbeddingsClient(
+ pulsar_host=pulsar_host,
+ input_queue=emb_request_queue,
+ output_queue=emb_response_queue,
+ subscriber=module + "-emb",
+ pulsar_api_key=pulsar_api_key,
+ )
+
+ self.lang = PromptClient(
+ pulsar_host=pulsar_host,
+ input_queue=pr_request_queue,
+ output_queue=pr_response_queue,
+ subscriber=module + "-de-prompt",
+ pulsar_api_key=pulsar_api_key,
+ )
+
+ if self.verbose:
+ print("Initialised", flush=True)
+
+ def query(
+ self, query, user="trustgraph", collection="default",
+ doc_limit=20,
+ ):
+
+ if self.verbose:
+ print("Construct prompt...", flush=True)
+
+ q = Query(
+ rag=self, user=user, collection=collection, verbose=self.verbose,
+ doc_limit=doc_limit
+ )
+
+ docs = q.get_docs(query)
+
+ if self.verbose:
+ print("Invoke LLM...", flush=True)
+ print(docs)
+ print(query)
+
+ resp = self.lang.request_document_prompt(query, docs)
+
+ if self.verbose:
+ print("Done", flush=True)
+
+ return resp
+
diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py
index 95e5462d..70f53e07 100755
--- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py
+++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py
@@ -6,63 +6,61 @@ Output is chunk plus embedding.
"""
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
-from ... schema import EmbeddingsRequest, EmbeddingsResponse
+from ... schema import chunk_ingest_queue
+from ... schema import document_embeddings_store_queue
+from ... schema import embeddings_request_queue, embeddings_response_queue
+from ... clients.embeddings_client import EmbeddingsClient
+from ... log_level import LogLevel
+from ... base import ConsumerProducer
-from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec
-from ... base import ProducerSpec
+module = ".".join(__name__.split(".")[1:-1])
-default_ident = "document-embeddings"
+default_input_queue = chunk_ingest_queue
+default_output_queue = document_embeddings_store_queue
+default_subscriber = module
-class Processor(FlowProcessor):
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+ emb_request_queue = params.get(
+ "embeddings_request_queue", embeddings_request_queue
+ )
+ emb_response_queue = params.get(
+ "embeddings_response_queue", embeddings_response_queue
+ )
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "embeddings_request_queue": emb_request_queue,
+ "embeddings_response_queue": emb_response_queue,
+ "subscriber": subscriber,
+ "input_schema": Chunk,
+ "output_schema": DocumentEmbeddings,
}
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = Chunk,
- handler = self.on_message,
- )
+ self.embeddings = EmbeddingsClient(
+ pulsar_host=self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ input_queue=emb_request_queue,
+ output_queue=emb_response_queue,
+ subscriber=module + "-emb",
)
- self.register_specification(
- RequestResponseSpec(
- request_name = "embeddings-request",
- request_schema = EmbeddingsRequest,
- response_name = "embeddings-response",
- response_schema = EmbeddingsResponse,
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "output",
- schema = DocumentEmbeddings
- )
- )
-
- async def on_message(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
try:
- resp = await flow("embeddings-request").request(
- EmbeddingsRequest(
- text = v.chunk
- )
- )
-
- vectors = resp.vectors
+ vectors = self.embeddings.request(v.chunk)
embeds = [
ChunkEmbeddings(
@@ -76,7 +74,7 @@ class Processor(FlowProcessor):
chunks=embeds,
)
- await flow("output").send(r)
+ await self.send(r)
except Exception as e:
print("Exception:", e, flush=True)
@@ -89,9 +87,24 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
+
+ parser.add_argument(
+ '--embeddings-request-queue',
+ default=embeddings_request_queue,
+ help=f'Embeddings request queue (default: {embeddings_request_queue})',
+ )
+
+ parser.add_argument(
+ '--embeddings-response-queue',
+ default=embeddings_response_queue,
+        help=f'Embeddings response queue (default: {embeddings_response_queue})',
+ )
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py
index a4ae35dc..bc164fa0 100755
--- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py
+++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py
@@ -1,43 +1,81 @@
"""
-Embeddings service, applies an embeddings model using fastembed
+Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector.
"""
-from ... base import EmbeddingsService
-
+from ... schema import EmbeddingsRequest, EmbeddingsResponse
+from ... schema import embeddings_request_queue, embeddings_response_queue
+from ... log_level import LogLevel
+from ... base import ConsumerProducer
from fastembed import TextEmbedding
+import os
-default_ident = "embeddings"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = embeddings_request_queue
+default_output_queue = embeddings_response_queue
+default_subscriber = module
default_model="sentence-transformers/all-MiniLM-L6-v2"
-class Processor(EmbeddingsService):
+class Processor(ConsumerProducer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+
model = params.get("model", default_model)
super(Processor, self).__init__(
- **params | { "model": model }
+ **params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": EmbeddingsRequest,
+ "output_schema": EmbeddingsResponse,
+ "model": model,
+ }
)
- print("Get model...", flush=True)
self.embeddings = TextEmbedding(model_name = model)
- async def on_embeddings(self, text):
+ async def handle(self, msg):
+ v = msg.value()
+
+ # Sender-produced ID
+
+ id = msg.properties()["id"]
+
+ print(f"Handling input {id}...", flush=True)
+
+ text = v.text
vecs = self.embeddings.embed([text])
- return [
+ vecs = [
v.tolist()
for v in vecs
]
+ print("Send response...", flush=True)
+ r = EmbeddingsResponse(
+ vectors=list(vecs),
+ error=None,
+ )
+
+ await self.send(r, properties={"id": id})
+
+ print("Done.", flush=True)
+
@staticmethod
def add_args(parser):
- EmbeddingsService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-m', '--model',
@@ -47,5 +85,5 @@ class Processor(EmbeddingsService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py
index 043be3a7..2cbe9907 100755
--- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py
+++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py
@@ -6,48 +6,53 @@ Output is entity plus embedding.
"""
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
-from ... schema import EmbeddingsRequest, EmbeddingsResponse
+from ... schema import entity_contexts_ingest_queue
+from ... schema import graph_embeddings_store_queue
+from ... schema import embeddings_request_queue, embeddings_response_queue
+from ... clients.embeddings_client import EmbeddingsClient
+from ... log_level import LogLevel
+from ... base import ConsumerProducer
-from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
-from ... base import ProducerSpec
+module = ".".join(__name__.split(".")[1:-1])
-default_ident = "graph-embeddings"
+default_input_queue = entity_contexts_ingest_queue
+default_output_queue = graph_embeddings_store_queue
+default_subscriber = module
-class Processor(FlowProcessor):
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+ emb_request_queue = params.get(
+ "embeddings_request_queue", embeddings_request_queue
+ )
+ emb_response_queue = params.get(
+ "embeddings_response_queue", embeddings_response_queue
+ )
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "embeddings_request_queue": emb_request_queue,
+ "embeddings_response_queue": emb_response_queue,
+ "subscriber": subscriber,
+ "input_schema": EntityContexts,
+ "output_schema": GraphEmbeddings,
}
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = EntityContexts,
- handler = self.on_message,
- )
+ self.embeddings = EmbeddingsClient(
+ pulsar_host=self.pulsar_host,
+ input_queue=emb_request_queue,
+ output_queue=emb_response_queue,
+ subscriber=module + "-emb",
)
- self.register_specification(
- EmbeddingsClientSpec(
- request_name = "embeddings-request",
- response_name = "embeddings-response",
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "output",
- schema = GraphEmbeddings
- )
- )
-
- async def on_message(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
@@ -58,9 +63,7 @@ class Processor(FlowProcessor):
for entity in v.entities:
- vectors = await flow("embeddings-request").embed(
- text = entity.context
- )
+ vectors = self.embeddings.request(entity.context)
entities.append(
EntityEmbeddings(
@@ -74,7 +77,7 @@ class Processor(FlowProcessor):
entities=entities,
)
- await flow("output").send(r)
+ await self.send(r)
except Exception as e:
print("Exception:", e, flush=True)
@@ -87,9 +90,24 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
+
+ parser.add_argument(
+ '--embeddings-request-queue',
+ default=embeddings_request_queue,
+ help=f'Embeddings request queue (default: {embeddings_request_queue})',
+ )
+
+ parser.add_argument(
+ '--embeddings-response-queue',
+ default=embeddings_response_queue,
+ help=f'Embeddings request queue (default: {embeddings_response_queue})',
+ )
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py
index 86787316..c441b9c6 100755
--- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py
+++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py
@@ -11,7 +11,7 @@ from ... base import ConsumerProducer
from ollama import Client
import os
-module = "embeddings"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
diff --git a/trustgraph-flow/trustgraph/external/wikipedia/service.py b/trustgraph-flow/trustgraph/external/wikipedia/service.py
index f7de78da..cc002765 100644
--- a/trustgraph-flow/trustgraph/external/wikipedia/service.py
+++ b/trustgraph-flow/trustgraph/external/wikipedia/service.py
@@ -11,7 +11,7 @@ from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
import requests
-module = "wikipedia"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = encyclopedia_lookup_request_queue
default_output_queue = encyclopedia_lookup_response_queue
diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
index f95dadf9..47c99802 100755
--- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
@@ -5,62 +5,84 @@ get entity definitions which are output as graph edges along with
entity/context definitions for embedding.
"""
-import json
import urllib.parse
+from pulsar.schema import JsonSchema
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
-from .... schema import PromptRequest, PromptResponse
+from .... schema import chunk_ingest_queue, triples_store_queue
+from .... schema import entity_contexts_ingest_queue
+from .... schema import prompt_request_queue
+from .... schema import prompt_response_queue
+from .... log_level import LogLevel
+from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
-
-from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from .... base import PromptClientSpec
+from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
-default_ident = "kg-extract-definitions"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = chunk_ingest_queue
+default_output_queue = triples_store_queue
+default_entity_context_queue = entity_contexts_ingest_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ ec_queue = params.get(
+ "entity_context_queue",
+ default_entity_context_queue
+ )
+ subscriber = params.get("subscriber", default_subscriber)
+ pr_request_queue = params.get(
+ "prompt_request_queue", prompt_request_queue
+ )
+ pr_response_queue = params.get(
+ "prompt_response_queue", prompt_response_queue
+ )
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": Chunk,
+ "output_schema": Triples,
+ "prompt_request_queue": pr_request_queue,
+ "prompt_response_queue": pr_response_queue,
}
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = Chunk,
- handler = self.on_message
- )
+ self.ec_prod = self.client.create_producer(
+ topic=ec_queue,
+ schema=JsonSchema(EntityContexts),
)
- self.register_specification(
- PromptClientSpec(
- request_name = "prompt-request",
- response_name = "prompt-response",
- )
- )
+ __class__.pubsub_metric.info({
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "entity_context_queue": ec_queue,
+ "prompt_request_queue": pr_request_queue,
+ "prompt_response_queue": pr_response_queue,
+ "subscriber": subscriber,
+ "input_schema": Chunk.__name__,
+ "output_schema": Triples.__name__,
+ "vector_schema": EntityContexts.__name__,
+ })
- self.register_specification(
- ProducerSpec(
- name = "triples",
- schema = Triples
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "entity-contexts",
- schema = EntityContexts
- )
+ self.prompt = PromptClient(
+ pulsar_host=self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ input_queue=pr_request_queue,
+ output_queue=pr_response_queue,
+ subscriber = module + "-prompt",
)
def to_uri(self, text):
@@ -71,47 +93,36 @@ class Processor(FlowProcessor):
return uri
- async def emit_triples(self, pub, metadata, triples):
+ def get_definitions(self, chunk):
+
+ return self.prompt.request_definitions(chunk)
+
+ async def emit_edges(self, metadata, triples):
t = Triples(
metadata=metadata,
triples=triples,
)
- await pub.send(t)
+ await self.send(t)
- async def emit_ecs(self, pub, metadata, entities):
+ async def emit_ecs(self, metadata, entities):
t = EntityContexts(
metadata=metadata,
entities=entities,
)
- await pub.send(t)
+ self.ec_prod.send(t)
- async def on_message(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
- print(chunk, flush=True)
-
try:
- try:
-
- defs = await flow("prompt-request").extract_definitions(
- text = chunk
- )
-
- print("Response", defs, flush=True)
-
- if type(defs) != list:
- raise RuntimeError("Expecting array in prompt response")
-
- except Exception as e:
- print("Prompt exception:", e, flush=True)
- raise e
+ defs = self.get_definitions(chunk)
triples = []
entities = []
@@ -123,8 +134,8 @@ class Processor(FlowProcessor):
for defn in defs:
- s = defn["entity"]
- o = defn["definition"]
+ s = defn.name
+ o = defn.definition
if s == "": continue
if o == "": continue
@@ -155,13 +166,13 @@ class Processor(FlowProcessor):
ec = EntityContext(
entity=s_value,
- context=defn["definition"],
+ context=defn.definition,
)
entities.append(ec)
+
- await self.emit_triples(
- flow("triples"),
+ await self.emit_edges(
Metadata(
id=v.metadata.id,
metadata=[],
@@ -172,7 +183,6 @@ class Processor(FlowProcessor):
)
await self.emit_ecs(
- flow("entity-contexts"),
Metadata(
id=v.metadata.id,
metadata=[],
@@ -190,9 +200,30 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
+
+ parser.add_argument(
+ '-e', '--entity-context-queue',
+ default=default_entity_context_queue,
+ help=f'Entity context queue (default: {default_entity_context_queue})'
+ )
+
+ parser.add_argument(
+ '--prompt-request-queue',
+ default=prompt_request_queue,
+ help=f'Prompt request queue (default: {prompt_request_queue})',
+ )
+
+ parser.add_argument(
+        '--prompt-response-queue',
+ default=prompt_response_queue,
+ help=f'Prompt response queue (default: {prompt_response_queue})',
+ )
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
index ac2929a3..2f293527 100755
--- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
@@ -5,54 +5,59 @@ relationship analysis to get entity relationship edges which are output as
graph edges.
"""
-import json
import urllib.parse
from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value
-from .... schema import PromptRequest, PromptResponse
+from .... schema import chunk_ingest_queue, triples_store_queue
+from .... schema import prompt_request_queue
+from .... schema import prompt_response_queue
+from .... log_level import LogLevel
+from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
-
-from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from .... base import PromptClientSpec
+from .... base import ConsumerProducer
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
-default_ident = "kg-extract-relationships"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = chunk_ingest_queue
+default_output_queue = triples_store_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+ pr_request_queue = params.get(
+ "prompt_request_queue", prompt_request_queue
+ )
+ pr_response_queue = params.get(
+ "prompt_response_queue", prompt_response_queue
+ )
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": Chunk,
+ "output_schema": Triples,
+ "prompt_request_queue": pr_request_queue,
+ "prompt_response_queue": pr_response_queue,
}
)
- self.register_specification(
- ConsumerSpec(
- name = "input",
- schema = Chunk,
- handler = self.on_message
- )
- )
-
- self.register_specification(
- PromptClientSpec(
- request_name = "prompt-request",
- response_name = "prompt-response",
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "triples",
- schema = Triples
- )
+ self.prompt = PromptClient(
+ pulsar_host=self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ input_queue=pr_request_queue,
+ output_queue=pr_response_queue,
+ subscriber = module + "-prompt",
)
def to_uri(self, text):
@@ -63,39 +68,28 @@ class Processor(FlowProcessor):
return uri
- async def emit_triples(self, pub, metadata, triples):
+ def get_relationships(self, chunk):
+
+ return self.prompt.request_relationships(chunk)
+
+ async def emit_edges(self, metadata, triples):
t = Triples(
metadata=metadata,
triples=triples,
)
- await pub.send(t)
+ await self.send(t)
- async def on_message(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
- print(chunk, flush=True)
-
try:
- try:
-
- rels = await flow("prompt-request").extract_relationships(
- text = chunk
- )
-
- print("Response", rels, flush=True)
-
- if type(rels) != list:
- raise RuntimeError("Expecting array in prompt response")
-
- except Exception as e:
- print("Prompt exception:", e, flush=True)
- raise e
+ rels = self.get_relationships(chunk)
triples = []
@@ -106,9 +100,9 @@ class Processor(FlowProcessor):
for rel in rels:
- s = rel["subject"]
- p = rel["predicate"]
- o = rel["object"]
+ s = rel.s
+ p = rel.p
+ o = rel.o
if s == "": continue
if p == "": continue
@@ -124,7 +118,7 @@ class Processor(FlowProcessor):
p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
- if rel["object-entity"]:
+ if rel.o_entity:
o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
else:
@@ -150,7 +144,7 @@ class Processor(FlowProcessor):
o=Value(value=str(p), is_uri=False)
))
- if rel["object-entity"]:
+ if rel.o_entity:
# Label for o
triples.append(Triple(
s=o_value,
@@ -165,7 +159,7 @@ class Processor(FlowProcessor):
o=Value(value=v.metadata.id, is_uri=True)
))
- if rel["object-entity"]:
+ if rel.o_entity:
# 'Subject of' for o
triples.append(Triple(
s=o_value,
@@ -174,7 +168,6 @@ class Processor(FlowProcessor):
))
await self.emit_edges(
- flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
@@ -192,9 +185,24 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
+
+ parser.add_argument(
+ '--prompt-request-queue',
+ default=prompt_request_queue,
+ help=f'Prompt request queue (default: {prompt_request_queue})',
+ )
+
+ parser.add_argument(
+ '--prompt-response-queue',
+ default=prompt_response_queue,
+ help=f'Prompt response queue (default: {prompt_response_queue})',
+ )
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py
index 84ab6681..7424abe2 100755
--- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py
@@ -18,7 +18,7 @@ from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
-module = "kg-extract-topics"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
diff --git a/trustgraph-flow/trustgraph/gateway/agent.py b/trustgraph-flow/trustgraph/gateway/agent.py
index 5a54931b..150b970e 100644
--- a/trustgraph-flow/trustgraph/gateway/agent.py
+++ b/trustgraph-flow/trustgraph/gateway/agent.py
@@ -39,3 +39,4 @@ class AgentRequestor(ServiceRequestor):
# The 2nd boolean expression indicates whether we're done responding
return resp, (message.answer is not None)
+
diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py
index bbfb51a3..6b4b4838 100644
--- a/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py
+++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py
@@ -1,5 +1,6 @@
import asyncio
+from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@@ -25,12 +26,12 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, document_embeddings_store_queue,
- schema=DocumentEmbeddings
+ schema=JsonSchema(DocumentEmbeddings)
)
async def start(self):
- await self.publisher.start()
+ self.publisher.start()
async def listener(self, ws, running):
@@ -58,6 +59,6 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
],
)
- await self.publisher.send(None, elt)
+ self.publisher.send(None, elt)
running.stop()
diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py
index e59a0370..6d7db576 100644
--- a/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py
+++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py
@@ -1,6 +1,7 @@
import asyncio
import queue
+from pulsar.schema import JsonSchema
import uuid
from .. schema import DocumentEmbeddings
@@ -26,7 +27,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway",
- schema=DocumentEmbeddings,
+ schema=JsonSchema(DocumentEmbeddings),
)
async def listener(self, ws, running):
@@ -43,17 +44,17 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self):
- await self.subscriber.start()
+ self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
- q = await self.subscriber.subscribe_all(id)
+ q = self.subscriber.subscribe_all(id)
while running.get():
try:
- resp = await asyncio.wait_for(q.get(), timeout=0.5)
+ resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
@@ -66,7 +67,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True)
break
- await self.subscriber.unsubscribe_all(id)
+ self.subscriber.unsubscribe_all(id)
running.stop()
diff --git a/trustgraph-flow/trustgraph/gateway/endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint.py
index 94980e8b..5005463c 100644
--- a/trustgraph-flow/trustgraph/gateway/endpoint.py
+++ b/trustgraph-flow/trustgraph/gateway/endpoint.py
@@ -1,9 +1,13 @@
import asyncio
+from pulsar.schema import JsonSchema
from aiohttp import web
import uuid
import logging
+from .. base import Publisher
+from .. base import Subscriber
+
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
diff --git a/trustgraph-flow/trustgraph/gateway/flow.py b/trustgraph-flow/trustgraph/gateway/flow.py
deleted file mode 100644
index c666d99c..00000000
--- a/trustgraph-flow/trustgraph/gateway/flow.py
+++ /dev/null
@@ -1,51 +0,0 @@
-
-from .. schema import FlowRequest, FlowResponse, ConfigKey, ConfigValue
-from .. schema import flow_request_queue
-from .. schema import flow_response_queue
-
-from . endpoint import ServiceEndpoint
-from . requestor import ServiceRequestor
-
-class FlowRequestor(ServiceRequestor):
- def __init__(self, pulsar_client, timeout, auth):
-
- super(FlowRequestor, self).__init__(
- pulsar_client=pulsar_client,
- request_queue=flow_request_queue,
- response_queue=flow_response_queue,
- request_schema=FlowRequest,
- response_schema=FlowResponse,
- timeout=timeout,
- )
-
- def to_request(self, body):
-
- return FlowRequest(
- operation = body.get("operation", None),
- class_name = body.get("class-name", None),
- class_definition = body.get("class-definition", None),
- description = body.get("description", None),
- flow_id = body.get("flow-id", None),
- )
-
- def from_response(self, message):
-
- response = { }
-
- if message.class_names is not None:
- response["class-names"] = message.class_names
-
- if message.flow_ids is not None:
- response["flow-ids"] = message.flow_ids
-
- if message.class_definition is not None:
- response["class-definition"] = message.class_definition
-
- if message.flow is not None:
- response["flow"] = message.flow
-
- if message.description is not None:
- response["description"] = message.description
-
- return response, True
-
diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py
index 27e92a30..c1354ce5 100644
--- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py
+++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py
@@ -1,5 +1,6 @@
import asyncio
+from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@@ -25,12 +26,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, graph_embeddings_store_queue,
- schema=GraphEmbeddings
+ schema=JsonSchema(GraphEmbeddings)
)
async def start(self):
- await self.publisher.start()
+ self.publisher.start()
async def listener(self, ws, running):
@@ -59,6 +60,6 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
]
)
- await self.publisher.send(None, elt)
+ self.publisher.send(None, elt)
running.stop()
diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py
index 37edc2bb..385eb9f4 100644
--- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py
+++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py
@@ -1,6 +1,7 @@
import asyncio
import queue
+from pulsar.schema import JsonSchema
import uuid
from .. schema import GraphEmbeddings
@@ -25,7 +26,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
- schema=GraphEmbeddings
+ schema=JsonSchema(GraphEmbeddings)
)
async def listener(self, ws, running):
@@ -40,17 +41,17 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self):
- await self.subscriber.start()
+ self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
- q = await self.subscriber.subscribe_all(id)
+ q = self.subscriber.subscribe_all(id)
while running.get():
try:
- resp = await asyncio.wait_for(q.get, timeout=0.5)
+ resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
@@ -63,7 +64,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True)
break
- await self.subscriber.unsubscribe_all(id)
+ self.subscriber.unsubscribe_all(id)
running.stop()
diff --git a/trustgraph-flow/trustgraph/gateway/metrics.py b/trustgraph-flow/trustgraph/gateway/metrics.py
index d8a1ef62..33c1fe3a 100644
--- a/trustgraph-flow/trustgraph/gateway/metrics.py
+++ b/trustgraph-flow/trustgraph/gateway/metrics.py
@@ -7,6 +7,7 @@
import aiohttp
from aiohttp import web
import asyncio
+from pulsar.schema import JsonSchema
import uuid
import logging
diff --git a/trustgraph-flow/trustgraph/gateway/mux.py b/trustgraph-flow/trustgraph/gateway/mux.py
index 1afc3225..23b693ab 100644
--- a/trustgraph-flow/trustgraph/gateway/mux.py
+++ b/trustgraph-flow/trustgraph/gateway/mux.py
@@ -1,10 +1,12 @@
import asyncio
import queue
+from pulsar.schema import JsonSchema
import uuid
from aiohttp import web, WSMsgType
from . socket import SocketEndpoint
+from . text_completion import TextCompletionRequestor
MAX_OUTSTANDING_REQUESTS = 15
WORKER_CLOSE_WAIT = 0.01
diff --git a/trustgraph-flow/trustgraph/gateway/requestor.py b/trustgraph-flow/trustgraph/gateway/requestor.py
index 63395203..dc74667d 100644
--- a/trustgraph-flow/trustgraph/gateway/requestor.py
+++ b/trustgraph-flow/trustgraph/gateway/requestor.py
@@ -1,5 +1,6 @@
import asyncio
+from pulsar.schema import JsonSchema
import uuid
import logging
@@ -22,21 +23,21 @@ class ServiceRequestor:
self.pub = Publisher(
pulsar_client, request_queue,
- schema=request_schema,
+ schema=JsonSchema(request_schema),
)
self.sub = Subscriber(
pulsar_client, response_queue,
subscription, consumer_name,
- response_schema
+ JsonSchema(response_schema)
)
self.timeout = timeout
async def start(self):
- await self.pub.start()
- await self.sub.start()
+ self.pub.start()
+ self.sub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
@@ -50,15 +51,18 @@ class ServiceRequestor:
try:
- q = await self.sub.subscribe(id)
+ q = self.sub.subscribe(id)
- await self.pub.send(id, self.to_request(request))
+ await asyncio.to_thread(
+ self.pub.send, id, self.to_request(request)
+ )
while True:
try:
- resp = await asyncio.wait_for(
- q.get(), timeout=self.timeout
+ resp = await asyncio.to_thread(
+ q.get,
+ timeout=self.timeout
)
except Exception as e:
raise RuntimeError("Timeout")
@@ -95,5 +99,5 @@ class ServiceRequestor:
return err
finally:
- await self.sub.unsubscribe(id)
+ self.sub.unsubscribe(id)
diff --git a/trustgraph-flow/trustgraph/gateway/sender.py b/trustgraph-flow/trustgraph/gateway/sender.py
index 81b64e6d..32c586b1 100644
--- a/trustgraph-flow/trustgraph/gateway/sender.py
+++ b/trustgraph-flow/trustgraph/gateway/sender.py
@@ -2,6 +2,7 @@
# Like ServiceRequestor, but just fire-and-forget instead of request/response
import asyncio
+from pulsar.schema import JsonSchema
import uuid
import logging
@@ -20,12 +21,12 @@ class ServiceSender:
self.pub = Publisher(
pulsar_client, request_queue,
- schema=request_schema,
+ schema=JsonSchema(request_schema),
)
async def start(self):
- await self.pub.start()
+ self.pub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
@@ -34,7 +35,9 @@ class ServiceSender:
try:
- await self.pub.send(None, self.to_request(request))
+ await asyncio.to_thread(
+ self.pub.send, None, self.to_request(request)
+ )
if responder:
await responder({}, True)
diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py
index d7df3240..e997f83e 100755
--- a/trustgraph-flow/trustgraph/gateway/service.py
+++ b/trustgraph-flow/trustgraph/gateway/service.py
@@ -3,7 +3,7 @@ API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
-module = "api-gateway"
+module = ".".join(__name__.split(".")[1:-1])
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
@@ -19,36 +19,35 @@ import os
import base64
import pulsar
+from pulsar.schema import JsonSchema
from prometheus_client import start_http_server
from .. log_level import LogLevel
from . serialize import to_subgraph
from . running import Running
-
-#from . text_completion import TextCompletionRequestor
-#from . prompt import PromptRequestor
-#from . graph_rag import GraphRagRequestor
-#from . document_rag import DocumentRagRequestor
-#from . triples_query import TriplesQueryRequestor
-#from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
-#from . embeddings import EmbeddingsRequestor
-#from . encyclopedia import EncyclopediaRequestor
-#from . agent import AgentRequestor
-#from . dbpedia import DbpediaRequestor
-#from . internet_search import InternetSearchRequestor
-#from . librarian import LibrarianRequestor
+from . text_completion import TextCompletionRequestor
+from . prompt import PromptRequestor
+from . graph_rag import GraphRagRequestor
+from . document_rag import DocumentRagRequestor
+from . triples_query import TriplesQueryRequestor
+from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
+from . embeddings import EmbeddingsRequestor
+from . encyclopedia import EncyclopediaRequestor
+from . agent import AgentRequestor
+from . dbpedia import DbpediaRequestor
+from . internet_search import InternetSearchRequestor
+from . librarian import LibrarianRequestor
from . config import ConfigRequestor
-from . flow import FlowRequestor
-#from . triples_stream import TriplesStreamEndpoint
-#from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
-#from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
-#from . triples_load import TriplesLoadEndpoint
-#from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
-#from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
+from . triples_stream import TriplesStreamEndpoint
+from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
+from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
+from . triples_load import TriplesLoadEndpoint
+from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
+from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
from . mux import MuxEndpoint
-#from . document_load import DocumentLoadSender
-#from . text_load import TextLoadSender
+from . document_load import DocumentLoadSender
+from . text_load import TextLoadSender
from . metrics import MetricsEndpoint
from . endpoint import ServiceEndpoint
@@ -107,165 +106,157 @@ class Api:
self.auth = Authenticator(allow_all=True)
self.services = {
- # "text-completion": TextCompletionRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "prompt": PromptRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "graph-rag": GraphRagRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "document-rag": DocumentRagRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "triples-query": TriplesQueryRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "graph-embeddings-query": GraphEmbeddingsQueryRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "embeddings": EmbeddingsRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "agent": AgentRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "librarian": LibrarianRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
+ "text-completion": TextCompletionRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "prompt": PromptRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "graph-rag": GraphRagRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "document-rag": DocumentRagRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "triples-query": TriplesQueryRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "graph-embeddings-query": GraphEmbeddingsQueryRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "embeddings": EmbeddingsRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "agent": AgentRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "librarian": LibrarianRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
"config": ConfigRequestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
- "flow": FlowRequestor(
+ "encyclopedia": EncyclopediaRequestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
- # "encyclopedia": EncyclopediaRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "dbpedia": DbpediaRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "internet-search": InternetSearchRequestor(
- # pulsar_client=self.pulsar_client, timeout=self.timeout,
- # auth = self.auth,
- # ),
- # "document-load": DocumentLoadSender(
- # pulsar_client=self.pulsar_client,
- # ),
- # "text-load": TextLoadSender(
- # pulsar_client=self.pulsar_client,
- # ),
+ "dbpedia": DbpediaRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "internet-search": InternetSearchRequestor(
+ pulsar_client=self.pulsar_client, timeout=self.timeout,
+ auth = self.auth,
+ ),
+ "document-load": DocumentLoadSender(
+ pulsar_client=self.pulsar_client,
+ ),
+ "text-load": TextLoadSender(
+ pulsar_client=self.pulsar_client,
+ ),
}
self.endpoints = [
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/text-completion", auth=self.auth,
- # requestor = self.services["text-completion"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/prompt", auth=self.auth,
- # requestor = self.services["prompt"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/graph-rag", auth=self.auth,
- # requestor = self.services["graph-rag"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/document-rag", auth=self.auth,
- # requestor = self.services["document-rag"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/triples-query", auth=self.auth,
- # requestor = self.services["triples-query"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/graph-embeddings-query",
- # auth=self.auth,
- # requestor = self.services["graph-embeddings-query"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/embeddings", auth=self.auth,
- # requestor = self.services["embeddings"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/agent", auth=self.auth,
- # requestor = self.services["agent"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/librarian", auth=self.auth,
- # requestor = self.services["librarian"],
- # ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/text-completion", auth=self.auth,
+ requestor = self.services["text-completion"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/prompt", auth=self.auth,
+ requestor = self.services["prompt"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/graph-rag", auth=self.auth,
+ requestor = self.services["graph-rag"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/document-rag", auth=self.auth,
+ requestor = self.services["document-rag"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/triples-query", auth=self.auth,
+ requestor = self.services["triples-query"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/graph-embeddings-query",
+ auth=self.auth,
+ requestor = self.services["graph-embeddings-query"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/embeddings", auth=self.auth,
+ requestor = self.services["embeddings"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/agent", auth=self.auth,
+ requestor = self.services["agent"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/librarian", auth=self.auth,
+ requestor = self.services["librarian"],
+ ),
ServiceEndpoint(
endpoint_path = "/api/v1/config", auth=self.auth,
requestor = self.services["config"],
),
ServiceEndpoint(
- endpoint_path = "/api/v1/flow", auth=self.auth,
- requestor = self.services["flow"],
+ endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
+ requestor = self.services["encyclopedia"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/dbpedia", auth=self.auth,
+ requestor = self.services["dbpedia"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/internet-search", auth=self.auth,
+ requestor = self.services["internet-search"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/load/document", auth=self.auth,
+ requestor = self.services["document-load"],
+ ),
+ ServiceEndpoint(
+ endpoint_path = "/api/v1/load/text", auth=self.auth,
+ requestor = self.services["text-load"],
+ ),
+ TriplesStreamEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ ),
+ GraphEmbeddingsStreamEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ ),
+ DocumentEmbeddingsStreamEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ ),
+ TriplesLoadEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ ),
+ GraphEmbeddingsLoadEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ ),
+ DocumentEmbeddingsLoadEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ ),
+ MuxEndpoint(
+ pulsar_client=self.pulsar_client,
+ auth = self.auth,
+ services = self.services,
),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
- # requestor = self.services["encyclopedia"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/dbpedia", auth=self.auth,
- # requestor = self.services["dbpedia"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/internet-search", auth=self.auth,
- # requestor = self.services["internet-search"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/load/document", auth=self.auth,
- # requestor = self.services["document-load"],
- # ),
- # ServiceEndpoint(
- # endpoint_path = "/api/v1/load/text", auth=self.auth,
- # requestor = self.services["text-load"],
- # ),
- # TriplesStreamEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # ),
- # GraphEmbeddingsStreamEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # ),
- # DocumentEmbeddingsStreamEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # ),
- # TriplesLoadEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # ),
- # GraphEmbeddingsLoadEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # ),
- # DocumentEmbeddingsLoadEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # ),
- # MuxEndpoint(
- # pulsar_client=self.pulsar_client,
- # auth = self.auth,
- # services = self.services,
- # ),
MetricsEndpoint(
endpoint_path = "/api/v1/metrics",
prometheus_url = self.prometheus_url,
diff --git a/trustgraph-flow/trustgraph/gateway/triples_load.py b/trustgraph-flow/trustgraph/gateway/triples_load.py
index 81c8ea82..bc69975e 100644
--- a/trustgraph-flow/trustgraph/gateway/triples_load.py
+++ b/trustgraph-flow/trustgraph/gateway/triples_load.py
@@ -1,5 +1,6 @@
import asyncio
+from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@@ -23,12 +24,12 @@ class TriplesLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, triples_store_queue,
- schema=Triples
+ schema=JsonSchema(Triples)
)
async def start(self):
- await self.publisher.start()
+ self.publisher.start()
async def listener(self, ws, running):
@@ -50,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
triples=to_subgraph(data["triples"]),
)
- await self.publisher.send(None, elt)
+ self.publisher.send(None, elt)
running.stop()
diff --git a/trustgraph-flow/trustgraph/gateway/triples_stream.py b/trustgraph-flow/trustgraph/gateway/triples_stream.py
index a660591e..a5d5ad0a 100644
--- a/trustgraph-flow/trustgraph/gateway/triples_stream.py
+++ b/trustgraph-flow/trustgraph/gateway/triples_stream.py
@@ -1,6 +1,7 @@
import asyncio
import queue
+from pulsar.schema import JsonSchema
import uuid
from .. schema import Triples
@@ -23,7 +24,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway",
- schema=Triples
+ schema=JsonSchema(Triples)
)
async def listener(self, ws, running):
@@ -38,7 +39,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
async def start(self):
- await self.subscriber.start()
+ self.subscriber.start()
async def async_thread(self, ws, running):
diff --git a/trustgraph-flow/trustgraph/graph_rag.py b/trustgraph-flow/trustgraph/graph_rag.py
new file mode 100644
index 00000000..6a4e11c5
--- /dev/null
+++ b/trustgraph-flow/trustgraph/graph_rag.py
@@ -0,0 +1,295 @@
+
+from . clients.graph_embeddings_client import GraphEmbeddingsClient
+from . clients.triples_query_client import TriplesQueryClient
+from . clients.embeddings_client import EmbeddingsClient
+from . clients.prompt_client import PromptClient
+
+from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
+from . schema import TriplesQueryRequest, TriplesQueryResponse
+from . schema import prompt_request_queue
+from . schema import prompt_response_queue
+from . schema import embeddings_request_queue
+from . schema import embeddings_response_queue
+from . schema import graph_embeddings_request_queue
+from . schema import graph_embeddings_response_queue
+from . schema import triples_request_queue
+from . schema import triples_response_queue
+
+LABEL="http://www.w3.org/2000/01/rdf-schema#label"
+DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
+
+class Query:
+
+ def __init__(
+ self, rag, user, collection, verbose,
+ entity_limit=50, triple_limit=30, max_subgraph_size=1000,
+ max_path_length=2,
+ ):
+ self.rag = rag
+ self.user = user
+ self.collection = collection
+ self.verbose = verbose
+ self.entity_limit = entity_limit
+ self.triple_limit = triple_limit
+ self.max_subgraph_size = max_subgraph_size
+ self.max_path_length = max_path_length
+
+ def get_vector(self, query):
+
+ if self.verbose:
+ print("Compute embeddings...", flush=True)
+
+ qembeds = self.rag.embeddings.request(query)
+
+ if self.verbose:
+ print("Done.", flush=True)
+
+ return qembeds
+
+ def get_entities(self, query):
+
+ vectors = self.get_vector(query)
+
+ if self.verbose:
+ print("Get entities...", flush=True)
+
+ entities = self.rag.ge_client.request(
+ user=self.user, collection=self.collection,
+ vectors=vectors, limit=self.entity_limit,
+ )
+
+ entities = [
+ e.value
+ for e in entities
+ ]
+
+ if self.verbose:
+ print("Entities:", flush=True)
+ for ent in entities:
+ print(" ", ent, flush=True)
+
+ return entities
+
+ def maybe_label(self, e):
+
+ if e in self.rag.label_cache:
+ return self.rag.label_cache[e]
+
+ res = self.rag.triples_client.request(
+ user=self.user, collection=self.collection,
+ s=e, p=LABEL, o=None, limit=1,
+ )
+
+ if len(res) == 0:
+ self.rag.label_cache[e] = e
+ return e
+
+ self.rag.label_cache[e] = res[0].o.value
+ return self.rag.label_cache[e]
+
+ def follow_edges(self, ent, subgraph, path_length):
+
+ # Not needed?
+ if path_length <= 0:
+ return
+
+ # Stop spanning around if the subgraph is already maxed out
+ if len(subgraph) >= self.max_subgraph_size:
+ return
+
+ res = self.rag.triples_client.request(
+ user=self.user, collection=self.collection,
+ s=ent, p=None, o=None,
+ limit=self.triple_limit
+ )
+
+ for triple in res:
+ subgraph.add(
+ (triple.s.value, triple.p.value, triple.o.value)
+ )
+ if path_length > 1:
+ self.follow_edges(triple.o.value, subgraph, path_length-1)
+
+ res = self.rag.triples_client.request(
+ user=self.user, collection=self.collection,
+ s=None, p=ent, o=None,
+ limit=self.triple_limit
+ )
+
+ for triple in res:
+ subgraph.add(
+ (triple.s.value, triple.p.value, triple.o.value)
+ )
+
+ res = self.rag.triples_client.request(
+ user=self.user, collection=self.collection,
+ s=None, p=None, o=ent,
+ limit=self.triple_limit,
+ )
+
+ for triple in res:
+ subgraph.add(
+ (triple.s.value, triple.p.value, triple.o.value)
+ )
+ if path_length > 1:
+ self.follow_edges(triple.s.value, subgraph, path_length-1)
+
+ def get_subgraph(self, query):
+
+ entities = self.get_entities(query)
+
+ if self.verbose:
+ print("Get subgraph...", flush=True)
+
+ subgraph = set()
+
+ for ent in entities:
+ self.follow_edges(ent, subgraph, self.max_path_length)
+
+ subgraph = list(subgraph)
+
+ return subgraph
+
+ def get_labelgraph(self, query):
+
+ subgraph = self.get_subgraph(query)
+
+ sg2 = []
+
+ for edge in subgraph:
+
+ if edge[1] == LABEL:
+ continue
+
+ s = self.maybe_label(edge[0])
+ p = self.maybe_label(edge[1])
+ o = self.maybe_label(edge[2])
+
+ sg2.append((s, p, o))
+
+ sg2 = sg2[0:self.max_subgraph_size]
+
+ if self.verbose:
+ print("Subgraph:", flush=True)
+ for edge in sg2:
+ print(" ", str(edge), flush=True)
+
+ if self.verbose:
+ print("Done.", flush=True)
+
+ return sg2
+
+class GraphRag:
+
+ def __init__(
+ self,
+ pulsar_host="pulsar://pulsar:6650",
+ pulsar_api_key=None,
+ pr_request_queue=None,
+ pr_response_queue=None,
+ emb_request_queue=None,
+ emb_response_queue=None,
+ ge_request_queue=None,
+ ge_response_queue=None,
+ tpl_request_queue=None,
+ tpl_response_queue=None,
+ verbose=False,
+ module="test",
+ ):
+
+ self.verbose=verbose
+
+ if pr_request_queue is None:
+ pr_request_queue = prompt_request_queue
+
+ if pr_response_queue is None:
+ pr_response_queue = prompt_response_queue
+
+ if emb_request_queue is None:
+ emb_request_queue = embeddings_request_queue
+
+ if emb_response_queue is None:
+ emb_response_queue = embeddings_response_queue
+
+ if ge_request_queue is None:
+ ge_request_queue = graph_embeddings_request_queue
+
+ if ge_response_queue is None:
+ ge_response_queue = graph_embeddings_response_queue
+
+ if tpl_request_queue is None:
+ tpl_request_queue = triples_request_queue
+
+ if tpl_response_queue is None:
+ tpl_response_queue = triples_response_queue
+
+ if self.verbose:
+ print("Initialising...", flush=True)
+
+ self.ge_client = GraphEmbeddingsClient(
+ pulsar_host=pulsar_host,
+ pulsar_api_key=pulsar_api_key,
+ subscriber=module + "-ge",
+ input_queue=ge_request_queue,
+ output_queue=ge_response_queue,
+ )
+
+ self.triples_client = TriplesQueryClient(
+ pulsar_host=pulsar_host,
+ pulsar_api_key=pulsar_api_key,
+ subscriber=module + "-tpl",
+ input_queue=tpl_request_queue,
+ output_queue=tpl_response_queue
+ )
+
+ self.embeddings = EmbeddingsClient(
+ pulsar_host=pulsar_host,
+ pulsar_api_key=pulsar_api_key,
+ input_queue=emb_request_queue,
+ output_queue=emb_response_queue,
+ subscriber=module + "-emb",
+ )
+
+ self.label_cache = {}
+
+ self.prompt = PromptClient(
+ pulsar_host=pulsar_host,
+ pulsar_api_key=pulsar_api_key,
+ input_queue=pr_request_queue,
+ output_queue=pr_response_queue,
+ subscriber=module + "-prompt",
+ )
+
+ if self.verbose:
+ print("Initialised", flush=True)
+
+ def query(
+ self, query, user="trustgraph", collection="default",
+ entity_limit=50, triple_limit=30, max_subgraph_size=1000,
+ max_path_length=2,
+ ):
+
+ if self.verbose:
+ print("Construct prompt...", flush=True)
+
+ q = Query(
+ rag=self, user=user, collection=collection, verbose=self.verbose,
+ entity_limit=entity_limit, triple_limit=triple_limit,
+ max_subgraph_size=max_subgraph_size,
+ max_path_length=max_path_length,
+ )
+
+ kg = q.get_labelgraph(query)
+
+ if self.verbose:
+ print("Invoke LLM...", flush=True)
+ print(kg)
+ print(query)
+
+ resp = self.prompt.request_kg_prompt(query, kg)
+
+ if self.verbose:
+ print("Done", flush=True)
+
+ return resp
+
diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py
index 587dcbf3..b42123a5 100755
--- a/trustgraph-flow/trustgraph/librarian/service.py
+++ b/trustgraph-flow/trustgraph/librarian/service.py
@@ -35,7 +35,7 @@ from .. exceptions import RequestError
from . librarian import Librarian
-module = "librarian"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = librarian_request_queue
default_output_queue = librarian_response_queue
diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py
index c721c065..68ddf441 100644
--- a/trustgraph-flow/trustgraph/metering/counter.py
+++ b/trustgraph-flow/trustgraph/metering/counter.py
@@ -10,11 +10,12 @@ from .. schema import text_completion_response_queue
from .. log_level import LogLevel
from .. base import Consumer
-module = "metering"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_response_queue
default_subscriber = module
+
class Processor(Consumer):
def __init__(self, **params):
diff --git a/trustgraph-flow/trustgraph/model/prompt/generic/service.py b/trustgraph-flow/trustgraph/model/prompt/generic/service.py
index b10da491..b143b759 100755
--- a/trustgraph-flow/trustgraph/model/prompt/generic/service.py
+++ b/trustgraph-flow/trustgraph/model/prompt/generic/service.py
@@ -27,7 +27,7 @@ from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_topics
from . prompts import to_kg_query, to_document_query, to_rows
-module = "prompt"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
diff --git a/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py b/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py
index c5c32395..d8a032ca 100644
--- a/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py
+++ b/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py
@@ -4,6 +4,8 @@ import json
from jsonschema import validate
import re
+from trustgraph.clients.llm_client import LlmClient
+
class PromptConfiguration:
def __init__(self, system_template, global_terms={}, prompts={}):
self.system_template = system_template
@@ -19,7 +21,8 @@ class Prompt:
class PromptManager:
- def __init__(self, config):
+ def __init__(self, llm, config):
+ self.llm = llm
self.config = config
self.terms = config.global_terms
@@ -51,9 +54,7 @@ class PromptManager:
return json.loads(json_str)
- async def invoke(self, id, input, llm):
-
- print("Invoke...", flush=True)
+ def invoke(self, id, input):
if id not in self.prompts:
raise RuntimeError("ID invalid")
@@ -67,7 +68,9 @@ class PromptManager:
"prompt": self.templates[id].render(terms)
}
- resp = await llm(**prompt)
+ resp = self.llm.request(**prompt)
+
+ print(resp, flush=True)
if resp_type == "text":
return resp
@@ -78,13 +81,13 @@ class PromptManager:
try:
obj = self.parse_json(resp)
except:
- print("Parse fail:", resp, flush=True)
raise RuntimeError("JSON parse fail")
+ print(obj, flush=True)
if self.prompts[id].schema:
try:
+ print(self.prompts[id].schema)
validate(instance=obj, schema=self.prompts[id].schema)
- print("Validated", flush=True)
except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}")
diff --git a/trustgraph-flow/trustgraph/model/prompt/template/service.py b/trustgraph-flow/trustgraph/model/prompt/template/service.py
index 67590c1c..a1267114 100755
--- a/trustgraph-flow/trustgraph/model/prompt/template/service.py
+++ b/trustgraph-flow/trustgraph/model/prompt/template/service.py
@@ -3,7 +3,6 @@
Language service abstracts prompt engineering from LLM.
"""
-import asyncio
import json
import re
@@ -11,59 +10,74 @@ from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
-
-from .... base import FlowProcessor
-from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
+from .... schema import text_completion_request_queue
+from .... schema import text_completion_response_queue
+from .... schema import prompt_request_queue, prompt_response_queue
+from .... base import ConsumerProducer
+from .... clients.llm_client import LlmClient
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
-default_ident = "prompt"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = prompt_request_queue
+default_output_queue = prompt_response_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id")
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+ tc_request_queue = params.get(
+ "text_completion_request_queue", text_completion_request_queue
+ )
+ tc_response_queue = params.get(
+ "text_completion_response_queue", text_completion_response_queue
+ )
- # Config key for prompts
self.config_key = params.get("config_type", "prompt")
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": PromptRequest,
+ "output_schema": PromptResponse,
+ "text_completion_request_queue": tc_request_queue,
+ "text_completion_response_queue": tc_response_queue,
}
)
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = PromptRequest,
- handler = self.on_request
- )
+ self.llm = LlmClient(
+ subscriber=subscriber,
+ input_queue=tc_request_queue,
+ output_queue=tc_response_queue,
+ pulsar_host = self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
)
- self.register_specification(
- TextCompletionClientSpec(
- request_name = "text-completion-request",
- response_name = "text-completion-response",
- )
- )
+ # System prompt hack
+ class Llm:
+ def __init__(self, llm):
+ self.llm = llm
+ def request(self, system, prompt):
+ print(system)
+ print(prompt, flush=True)
+ return self.llm.request(system, prompt)
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = PromptResponse
- )
- )
-
- self.register_config_handler(self.on_prompt_config)
+ self.llm = Llm(self.llm)
# Null configuration, should reload quickly
self.manager = PromptManager(
+ llm = self.llm,
config = PromptConfiguration("", {}, {})
)
- async def on_prompt_config(self, config, version):
+ async def on_config(self, version, config):
print("Loading configuration version", version)
@@ -97,6 +111,7 @@ class Processor(FlowProcessor):
)
self.manager = PromptManager(
+ self.llm,
PromptConfiguration(
system,
{},
@@ -111,7 +126,7 @@ class Processor(FlowProcessor):
print("Exception:", e, flush=True)
print("Configuration reload failed", flush=True)
- async def on_request(self, msg, consumer, flow):
+ async def handle(self, msg):
v = msg.value()
@@ -123,7 +138,7 @@ class Processor(FlowProcessor):
try:
- print(v.terms, flush=True)
+ print(v.terms)
input = {
k: json.loads(v)
@@ -131,33 +146,14 @@ class Processor(FlowProcessor):
}
print(f"Handling kind {kind}...", flush=True)
+ print(input, flush=True)
- async def llm(system, prompt):
-
- print(system, flush=True)
- print(prompt, flush=True)
-
- resp = await flow("text-completion-request").text_completion(
- system = system, prompt = prompt,
- )
-
- try:
- return resp
- except Exception as e:
- print("LLM Exception:", e, flush=True)
- return None
-
- try:
- resp = await self.manager.invoke(kind, input, llm)
- except Exception as e:
- print("Invocation exception:", e, flush=True)
- raise e
-
- print(resp, flush=True)
+ resp = self.manager.invoke(kind, input)
if isinstance(resp, str):
print("Send text response...", flush=True)
+ print(resp, flush=True)
r = PromptResponse(
text=resp,
@@ -165,7 +161,7 @@ class Processor(FlowProcessor):
error=None,
)
- await flow("response").send(r, properties={"id": id})
+ await self.send(r, properties={"id": id})
return
@@ -180,13 +176,13 @@ class Processor(FlowProcessor):
error=None,
)
- await flow("response").send(r, properties={"id": id})
+ await self.send(r, properties={"id": id})
return
except Exception as e:
- print(f"Exception: {e}", flush=True)
+ print(f"Exception: {e}")
print("Send error response...", flush=True)
@@ -198,11 +194,11 @@ class Processor(FlowProcessor):
response=None,
)
- await flow("response").send(r, properties={"id": id})
+ await self.send(r, properties={"id": id})
except Exception as e:
- print(f"Exception: {e}", flush=True)
+ print(f"Exception: {e}")
print("Send error response...", flush=True)
@@ -219,7 +215,22 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
+
+ parser.add_argument(
+ '--text-completion-request-queue',
+ default=text_completion_request_queue,
+ help=f'Text completion request queue (default: {text_completion_request_queue})',
+ )
+
+ parser.add_argument(
+ '--text-completion-response-queue',
+ default=text_completion_response_queue,
+ help=f'Text completion response queue (default: {text_completion_response_queue})',
+ )
parser.add_argument(
'--config-type',
@@ -229,5 +240,5 @@ class Processor(FlowProcessor):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
index 79118cc8..33840378 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
@@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
index 734b20c5..252d58ad 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
@@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
index f60b70d7..195a39e4 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
@@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
index df104ada..d5dab142 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
@@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
index 9f382572..98ecaf0e 100644
--- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
@@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
index fd473564..483412a2 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
@@ -14,7 +14,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py
index 05ff18a6..16ff2df4 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py
@@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
index 10257cdf..45f1311c 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
@@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
index 91e627e3..6d825bac 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
@@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
index 2479034d..590c2e3f 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
@@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-module = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py
index 2fb416dd..b16399e9 100755
--- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py
+++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py
@@ -11,7 +11,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
-module = "de-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py
index 74c52055..6a88671c 100755
--- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py
+++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py
@@ -16,7 +16,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
-module = "de-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
index c5543690..128203ad 100755
--- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
@@ -7,51 +7,71 @@ of chunks
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
+import uuid
-from .... schema import DocumentEmbeddingsResponse
+from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import Error, Value
-from .... base import DocumentEmbeddingsQueryService
+from .... schema import document_embeddings_request_queue
+from .... schema import document_embeddings_response_queue
+from .... base import ConsumerProducer
-default_ident = "de-query"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = document_embeddings_request_queue
+default_output_queue = document_embeddings_response_queue
+default_subscriber = module
default_store_uri = 'http://localhost:6333'
-class Processor(DocumentEmbeddingsQueryService):
+class Processor(ConsumerProducer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
-
#optional api key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": DocumentEmbeddingsRequest,
+ "output_schema": DocumentEmbeddingsResponse,
"store_uri": store_uri,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.client = QdrantClient(url=store_uri, api_key=api_key)
- async def query_document_embeddings(self, msg):
+ async def handle(self, msg):
try:
+ v = msg.value()
+
+ # Sender-produced ID
+ id = msg.properties()["id"]
+
+ print(f"Handling input {id}...", flush=True)
+
chunks = []
- for vec in msg.vectors:
+ for vec in v.vectors:
dim = len(vec)
collection = (
- "d_" + msg.user + "_" + msg.collection + "_" +
+ "d_" + v.user + "_" + v.collection + "_" +
str(dim)
)
- search_result = self.qdrant.query_points(
+ search_result = self.client.query_points(
collection_name=collection,
query=vec,
- limit=msg.limit,
+ limit=v.limit,
with_payload=True,
).points
@@ -59,17 +79,37 @@ class Processor(DocumentEmbeddingsQueryService):
ent = r.payload["doc"]
chunks.append(ent)
- return chunks
+ print("Send response...", flush=True)
+ r = DocumentEmbeddingsResponse(documents=chunks, error=None)
+ await self.send(r, properties={"id": id})
+
+ print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
- raise e
+
+ print("Send error response...", flush=True)
+
+ r = DocumentEmbeddingsResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
+ ),
+ documents=None,
+ )
+
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
- DocumentEmbeddingsQueryService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-t', '--store-uri',
@@ -85,5 +125,5 @@ class Processor(DocumentEmbeddingsQueryService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py
index d2cec084..8dd8d04d 100755
--- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py
+++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py
@@ -11,7 +11,7 @@ from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
-module = "ge-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py
index 942a1e69..90cfc6de 100755
--- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py
+++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py
@@ -16,7 +16,7 @@ from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
-module = "ge-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
index 32da00e5..dc3e28f3 100755
--- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
@@ -7,32 +7,44 @@ entities
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
+import uuid
-from .... schema import GraphEmbeddingsResponse
+from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
-from .... base import GraphEmbeddingsQueryService
+from .... schema import graph_embeddings_request_queue
+from .... schema import graph_embeddings_response_queue
+from .... base import ConsumerProducer
-default_ident = "ge-query"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = graph_embeddings_request_queue
+default_output_queue = graph_embeddings_response_queue
+default_subscriber = module
default_store_uri = 'http://localhost:6333'
-class Processor(GraphEmbeddingsQueryService):
+class Processor(ConsumerProducer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
-
- #optional api key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": GraphEmbeddingsRequest,
+ "output_schema": GraphEmbeddingsResponse,
"store_uri": store_uri,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.client = QdrantClient(url=store_uri, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@@ -40,27 +52,34 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Value(value=ent, is_uri=False)
- async def query_graph_embeddings(self, msg):
+ async def handle(self, msg):
try:
+ v = msg.value()
+
+ # Sender-produced ID
+ id = msg.properties()["id"]
+
+ print(f"Handling input {id}...", flush=True)
+
entity_set = set()
entities = []
- for vec in msg.vectors:
+ for vec in v.vectors:
dim = len(vec)
collection = (
- "t_" + msg.user + "_" + msg.collection + "_" +
+ "t_" + v.user + "_" + v.collection + "_" +
str(dim)
)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
- search_result = self.qdrant.query_points(
+ search_result = self.client.query_points(
collection_name=collection,
query=vec,
- limit=msg.limit * 2,
+ limit=v.limit * 2,
with_payload=True,
).points
@@ -73,10 +92,10 @@ class Processor(GraphEmbeddingsQueryService):
entities.append(ent)
# Keep adding entities until limit
- if len(entity_set) >= msg.limit: break
+ if len(entity_set) >= v.limit: break
# Keep adding entities until limit
- if len(entity_set) >= msg.limit: break
+ if len(entity_set) >= v.limit: break
ents2 = []
@@ -86,19 +105,36 @@ class Processor(GraphEmbeddingsQueryService):
entities = ents2
print("Send response...", flush=True)
- return entities
+ r = GraphEmbeddingsResponse(entities=entities, error=None)
+ await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
- raise e
+
+ print("Send error response...", flush=True)
+
+ r = GraphEmbeddingsResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
+ ),
+ entities=None,
+ )
+
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
- GraphEmbeddingsQueryService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-t', '--store-uri',
@@ -114,5 +150,5 @@ class Processor(GraphEmbeddingsQueryService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py
index 6fcf4a19..e3687756 100755
--- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py
+++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py
@@ -7,24 +7,38 @@ null. Output is a list of triples.
from .... direct.cassandra import TrustGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
-from .... base import TriplesQueryService
+from .... schema import triples_request_queue
+from .... schema import triples_response_queue
+from .... base import ConsumerProducer
-default_ident = "triples-query"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = triples_request_queue
+default_output_queue = triples_response_queue
+default_subscriber = module
default_graph_host='localhost'
-class Processor(TriplesQueryService):
+class Processor(ConsumerProducer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
super(Processor, self).__init__(
**params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": TriplesQueryRequest,
+ "output_schema": TriplesQueryResponse,
"graph_host": graph_host,
"graph_username": graph_username,
+ "graph_password": graph_password,
}
)
@@ -39,85 +53,92 @@ class Processor(TriplesQueryService):
else:
return Value(value=ent, is_uri=False)
- async def query_triples(self, query):
+ async def handle(self, msg):
try:
- table = (query.user, query.collection)
+ v = msg.value()
+
+ table = (v.user, v.collection)
if table != self.table:
if self.username and self.password:
self.tg = TrustGraph(
hosts=self.graph_host,
- keyspace=query.user, table=query.collection,
+ keyspace=v.user, table=v.collection,
username=self.username, password=self.password
)
else:
self.tg = TrustGraph(
hosts=self.graph_host,
- keyspace=query.user, table=query.collection,
+ keyspace=v.user, table=v.collection,
)
self.table = table
+ # Sender-produced ID
+ id = msg.properties()["id"]
+
+ print(f"Handling input {id}...", flush=True)
+
triples = []
- if query.s is not None:
- if query.p is not None:
- if query.o is not None:
+ if v.s is not None:
+ if v.p is not None:
+ if v.o is not None:
resp = self.tg.get_spo(
- query.s.value, query.p.value, query.o.value,
- limit=query.limit
+ v.s.value, v.p.value, v.o.value,
+ limit=v.limit
)
- triples.append((query.s.value, query.p.value, query.o.value))
+ triples.append((v.s.value, v.p.value, v.o.value))
else:
resp = self.tg.get_sp(
- query.s.value, query.p.value,
- limit=query.limit
+ v.s.value, v.p.value,
+ limit=v.limit
)
for t in resp:
- triples.append((query.s.value, query.p.value, t.o))
+ triples.append((v.s.value, v.p.value, t.o))
else:
- if query.o is not None:
+ if v.o is not None:
resp = self.tg.get_os(
- query.o.value, query.s.value,
- limit=query.limit
+ v.o.value, v.s.value,
+ limit=v.limit
)
for t in resp:
- triples.append((query.s.value, t.p, query.o.value))
+ triples.append((v.s.value, t.p, v.o.value))
else:
resp = self.tg.get_s(
- query.s.value,
- limit=query.limit
+ v.s.value,
+ limit=v.limit
)
for t in resp:
- triples.append((query.s.value, t.p, t.o))
+ triples.append((v.s.value, t.p, t.o))
else:
- if query.p is not None:
- if query.o is not None:
+ if v.p is not None:
+ if v.o is not None:
resp = self.tg.get_po(
- query.p.value, query.o.value,
- limit=query.limit
+ v.p.value, v.o.value,
+ limit=v.limit
)
for t in resp:
- triples.append((t.s, query.p.value, query.o.value))
+ triples.append((t.s, v.p.value, v.o.value))
else:
resp = self.tg.get_p(
- query.p.value,
- limit=query.limit
+ v.p.value,
+ limit=v.limit
)
for t in resp:
- triples.append((t.s, query.p.value, t.o))
+ triples.append((t.s, v.p.value, t.o))
else:
- if query.o is not None:
+ if v.o is not None:
resp = self.tg.get_o(
- query.o.value,
- limit=query.limit
+ v.o.value,
+ limit=v.limit
)
for t in resp:
- triples.append((t.s, t.p, query.o.value))
+ triples.append((t.s, t.p, v.o.value))
else:
resp = self.tg.get_all(
- limit=query.limit
+ limit=v.limit
)
for t in resp:
triples.append((t.s, t.p, t.o))
@@ -131,17 +152,37 @@ class Processor(TriplesQueryService):
for t in triples
]
- return triples
+ print("Send response...", flush=True)
+ r = TriplesQueryResponse(triples=triples, error=None)
+ await self.send(r, properties={"id": id})
+
+ print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
- raise e
+
+ print("Send error response...", flush=True)
+
+ r = TriplesQueryResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
+ ),
+ triples=None,
+ )
+
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
- TriplesQueryService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-g', '--graph-host',
@@ -164,5 +205,5 @@ class Processor(TriplesQueryService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py
index c62c28c1..56fed6d3 100755
--- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py
+++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py
@@ -13,7 +13,7 @@ from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
-module = "triples-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py
index 594c9130..f442c4ef 100755
--- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py
+++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py
@@ -13,7 +13,7 @@ from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
-module = "triples-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py
index 591361ce..49ba0345 100755
--- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py
+++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py
@@ -13,7 +13,7 @@ from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
-module = "triples-query"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
deleted file mode 100644
index 5e3c9b41..00000000
--- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
+++ /dev/null
@@ -1,94 +0,0 @@
-
-import asyncio
-
-LABEL="http://www.w3.org/2000/01/rdf-schema#label"
-
-class Query:
-
- def __init__(
- self, rag, user, collection, verbose,
- doc_limit=20
- ):
- self.rag = rag
- self.user = user
- self.collection = collection
- self.verbose = verbose
- self.doc_limit = doc_limit
-
- async def get_vector(self, query):
-
- if self.verbose:
- print("Compute embeddings...", flush=True)
-
- qembeds = await self.rag.embeddings_client.embed(query)
-
- if self.verbose:
- print("Done.", flush=True)
-
- return qembeds
-
- async def get_docs(self, query):
-
- vectors = await self.get_vector(query)
-
- if self.verbose:
- print("Get docs...", flush=True)
-
- docs = await self.rag.doc_embeddings_client.query(
- vectors, limit=self.doc_limit,
- user=self.user, collection=self.collection,
- )
-
- if self.verbose:
- print("Docs:", flush=True)
- for doc in docs:
- print(doc, flush=True)
-
- return docs
-
-class DocumentRag:
-
- def __init__(
- self, prompt_client, embeddings_client, doc_embeddings_client,
- verbose=False,
- ):
-
- self.verbose = verbose
-
- self.prompt_client = prompt_client
- self.embeddings_client = embeddings_client
- self.doc_embeddings_client = doc_embeddings_client
-
- if self.verbose:
- print("Initialised", flush=True)
-
- async def query(
- self, query, user="trustgraph", collection="default",
- doc_limit=20,
- ):
-
- if self.verbose:
- print("Construct prompt...", flush=True)
-
- q = Query(
- rag=self, user=user, collection=collection, verbose=self.verbose,
- doc_limit=doc_limit
- )
-
- docs = await q.get_docs(query)
-
- if self.verbose:
- print("Invoke LLM...", flush=True)
- print(docs)
- print(query)
-
- resp = await self.prompt_client.document_prompt(
- query = query,
- documents = docs
- )
-
- if self.verbose:
- print("Done", flush=True)
-
- return resp
-
diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
index 8c478874..bb8b008e 100755
--- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
@@ -5,77 +5,88 @@ Input is query, output is response.
"""
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
-from . document_rag import DocumentRag
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import PromptClientSpec, EmbeddingsClientSpec
-from ... base import DocumentEmbeddingsClientSpec
+from ... schema import document_rag_request_queue, document_rag_response_queue
+from ... schema import prompt_request_queue
+from ... schema import prompt_response_queue
+from ... schema import embeddings_request_queue
+from ... schema import embeddings_response_queue
+from ... schema import document_embeddings_request_queue
+from ... schema import document_embeddings_response_queue
+from ... log_level import LogLevel
+from ... document_rag import DocumentRag
+from ... base import ConsumerProducer
-default_ident = "document-rag"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = document_rag_request_queue
+default_output_queue = document_rag_response_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id", default_ident)
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+ pr_request_queue = params.get(
+ "prompt_request_queue", prompt_request_queue
+ )
+ pr_response_queue = params.get(
+ "prompt_response_queue", prompt_response_queue
+ )
+ emb_request_queue = params.get(
+ "embeddings_request_queue", embeddings_request_queue
+ )
+ emb_response_queue = params.get(
+ "embeddings_response_queue", embeddings_response_queue
+ )
+ de_request_queue = params.get(
+ "document_embeddings_request_queue",
+ document_embeddings_request_queue
+ )
+ de_response_queue = params.get(
+ "document_embeddings_response_queue",
+ document_embeddings_response_queue
+ )
- doc_limit = params.get("doc_limit", 5)
+ doc_limit = params.get("doc_limit", 10)
super(Processor, self).__init__(
**params | {
- "id": id,
- "doc_limit": doc_limit,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": DocumentRagQuery,
+ "output_schema": DocumentRagResponse,
+ "prompt_request_queue": pr_request_queue,
+ "prompt_response_queue": pr_response_queue,
+ "embeddings_request_queue": emb_request_queue,
+ "embeddings_response_queue": emb_response_queue,
+ "document_embeddings_request_queue": de_request_queue,
+ "document_embeddings_response_queue": de_response_queue,
}
)
+ self.rag = DocumentRag(
+ pulsar_host=self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ pr_request_queue=pr_request_queue,
+ pr_response_queue=pr_response_queue,
+ emb_request_queue=emb_request_queue,
+ emb_response_queue=emb_response_queue,
+ de_request_queue=de_request_queue,
+ de_response_queue=de_response_queue,
+ verbose=True,
+ module=module,
+ )
+
self.doc_limit = doc_limit
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = DocumentRagQuery,
- handler = self.on_request,
- )
- )
-
- self.register_specification(
- EmbeddingsClientSpec(
- request_name = "embeddings-request",
- response_name = "embeddings-response",
- )
- )
-
- self.register_specification(
- DocumentEmbeddingsClientSpec(
- request_name = "document-embeddings-request",
- response_name = "document-embeddings-response",
- )
- )
-
- self.register_specification(
- PromptClientSpec(
- request_name = "prompt-request",
- response_name = "prompt-response",
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = DocumentRagResponse,
- )
- )
-
- async def on_request(self, msg, consumer, flow):
+ async def handle(self, msg):
try:
- self.rag = DocumentRag(
- embeddings_client = flow("embeddings-request"),
- doc_embeddings_client = flow("document-embeddings-request"),
- prompt_client = flow("prompt-request"),
- verbose=True,
- )
-
v = msg.value()
# Sender-produced ID
@@ -88,15 +99,11 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
- response = await self.rag.query(v.query, doc_limit=doc_limit)
+ response = self.rag.query(v.query, doc_limit=doc_limit)
- await flow("response").send(
- DocumentRagResponse(
- response = response,
- error = None
- ),
- properties = {"id": id}
- )
+ print("Send response...", flush=True)
+ r = DocumentRagResponse(response = response, error=None)
+ await self.send(r, properties={"id": id})
print("Done.", flush=True)
@@ -106,21 +113,25 @@ class Processor(FlowProcessor):
print("Send error response...", flush=True)
- await flow("response").send(
- DocumentRagResponse(
- response = None,
- error = Error(
- type = "document-rag-error",
- message = str(e),
- ),
+ r = DocumentRagResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
),
- properties = {"id": id}
+ response=None,
)
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
+
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-d', '--doc-limit',
@@ -129,7 +140,43 @@ class Processor(FlowProcessor):
help=f'Default document fetch limit (default: 10)'
)
+ parser.add_argument(
+ '--prompt-request-queue',
+ default=prompt_request_queue,
+ help=f'Prompt request queue (default: {prompt_request_queue})',
+ )
+
+ parser.add_argument(
+ '--prompt-response-queue',
+ default=prompt_response_queue,
+ help=f'Prompt response queue (default: {prompt_response_queue})',
+ )
+
+ parser.add_argument(
+ '--embeddings-request-queue',
+ default=embeddings_request_queue,
+ help=f'Embeddings request queue (default: {embeddings_request_queue})',
+ )
+
+ parser.add_argument(
+ '--embeddings-response-queue',
+ default=embeddings_response_queue,
+ help=f'Embeddings response queue (default: {embeddings_response_queue})',
+ )
+
+ parser.add_argument(
+ '--document-embeddings-request-queue',
+ default=document_embeddings_request_queue,
+ help=f'Document embeddings request queue (default: {document_embeddings_request_queue})',
+ )
+
+ parser.add_argument(
+ '--document-embeddings-response-queue',
+ default=document_embeddings_response_queue,
+ help=f'Document embeddings response queue (default: {document_embeddings_response_queue})',
+ )
+
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
deleted file mode 100644
index 6879023a..00000000
--- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
+++ /dev/null
@@ -1,218 +0,0 @@
-
-import asyncio
-
-LABEL="http://www.w3.org/2000/01/rdf-schema#label"
-
-class Query:
-
- def __init__(
- self, rag, user, collection, verbose,
- entity_limit=50, triple_limit=30, max_subgraph_size=1000,
- max_path_length=2,
- ):
- self.rag = rag
- self.user = user
- self.collection = collection
- self.verbose = verbose
- self.entity_limit = entity_limit
- self.triple_limit = triple_limit
- self.max_subgraph_size = max_subgraph_size
- self.max_path_length = max_path_length
-
- async def get_vector(self, query):
-
- if self.verbose:
- print("Compute embeddings...", flush=True)
-
- qembeds = await self.rag.embeddings_client.embed(query)
-
- if self.verbose:
- print("Done.", flush=True)
-
- return qembeds
-
- async def get_entities(self, query):
-
- vectors = await self.get_vector(query)
-
- if self.verbose:
- print("Get entities...", flush=True)
-
- entities = await self.rag.graph_embeddings_client.query(
- vectors=vectors, limit=self.entity_limit,
- user=self.user, collection=self.collection,
- )
-
- entities = [
- str(e)
- for e in entities
- ]
-
- if self.verbose:
- print("Entities:", flush=True)
- for ent in entities:
- print(" ", ent, flush=True)
-
- return entities
-
- async def maybe_label(self, e):
-
- if e in self.rag.label_cache:
- return self.rag.label_cache[e]
-
- res = await self.rag.triples_client.query(
- s=e, p=LABEL, o=None, limit=1,
- user=self.user, collection=self.collection,
- )
-
- if len(res) == 0:
- self.rag.label_cache[e] = e
- return e
-
- self.rag.label_cache[e] = str(res[0].o)
- return self.rag.label_cache[e]
-
- async def follow_edges(self, ent, subgraph, path_length):
-
- # Not needed?
- if path_length <= 0:
- return
-
- # Stop spanning around if the subgraph is already maxed out
- if len(subgraph) >= self.max_subgraph_size:
- return
-
- res = await self.rag.triples_client.query(
- s=ent, p=None, o=None,
- limit=self.triple_limit,
- user=self.user, collection=self.collection,
- )
-
- for triple in res:
- subgraph.add(
- (str(triple.s), str(triple.p), str(triple.o))
- )
- if path_length > 1:
- await self.follow_edges(str(triple.o), subgraph, path_length-1)
-
- res = await self.rag.triples_client.query(
- s=None, p=ent, o=None,
- limit=self.triple_limit,
- user=self.user, collection=self.collection,
- )
-
- for triple in res:
- subgraph.add(
- (str(triple.s), str(triple.p), str(triple.o))
- )
-
- res = await self.rag.triples_client.query(
- s=None, p=None, o=ent,
- limit=self.triple_limit,
- user=self.user, collection=self.collection,
- )
-
- for triple in res:
- subgraph.add(
- (str(triple.s), str(triple.p), str(triple.o))
- )
- if path_length > 1:
- await self.follow_edges(
- str(triple.s), subgraph, path_length-1
- )
-
- async def get_subgraph(self, query):
-
- entities = await self.get_entities(query)
-
- if self.verbose:
- print("Get subgraph...", flush=True)
-
- subgraph = set()
-
- for ent in entities:
- await self.follow_edges(ent, subgraph, self.max_path_length)
-
- subgraph = list(subgraph)
-
- return subgraph
-
- async def get_labelgraph(self, query):
-
- subgraph = await self.get_subgraph(query)
-
- sg2 = []
-
- for edge in subgraph:
-
- if edge[1] == LABEL:
- continue
-
- s = await self.maybe_label(edge[0])
- p = await self.maybe_label(edge[1])
- o = await self.maybe_label(edge[2])
-
- sg2.append((s, p, o))
-
- sg2 = sg2[0:self.max_subgraph_size]
-
- if self.verbose:
- print("Subgraph:", flush=True)
- for edge in sg2:
- print(" ", str(edge), flush=True)
-
- if self.verbose:
- print("Done.", flush=True)
-
- return sg2
-
-class GraphRag:
-
- def __init__(
- self, prompt_client, embeddings_client, graph_embeddings_client,
- triples_client, verbose=False,
- ):
-
- self.verbose = verbose
-
- self.prompt_client = prompt_client
- self.embeddings_client = embeddings_client
- self.graph_embeddings_client = graph_embeddings_client
- self.triples_client = triples_client
-
- self.label_cache = {}
-
- if self.verbose:
- print("Initialised", flush=True)
-
- async def query(
- self, query, user = "trustgraph", collection = "default",
- entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
- max_path_length = 2,
- ):
-
- if self.verbose:
- print("Construct prompt...", flush=True)
-
- q = Query(
- rag = self, user = user, collection = collection,
- verbose = self.verbose, entity_limit = entity_limit,
- triple_limit = triple_limit,
- max_subgraph_size = max_subgraph_size,
- max_path_length = max_path_length,
- )
-
- kg = await q.get_labelgraph(query)
-
- if self.verbose:
- print("Invoke LLM...", flush=True)
- print(kg)
- print(query)
-
- resp = await self.prompt_client.kg_prompt(query, kg)
-
- if self.verbose:
- print("Done", flush=True)
-
- return resp
-
diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
index 5d3cc2f4..2c45ecd4 100755
--- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
@@ -5,18 +5,57 @@ Input is query, output is response.
"""
from ... schema import GraphRagQuery, GraphRagResponse, Error
-from . graph_rag import GraphRag
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import PromptClientSpec, EmbeddingsClientSpec
-from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
+from ... schema import graph_rag_request_queue, graph_rag_response_queue
+from ... schema import prompt_request_queue
+from ... schema import prompt_response_queue
+from ... schema import embeddings_request_queue
+from ... schema import embeddings_response_queue
+from ... schema import graph_embeddings_request_queue
+from ... schema import graph_embeddings_response_queue
+from ... schema import triples_request_queue
+from ... schema import triples_response_queue
+from ... log_level import LogLevel
+from ... graph_rag import GraphRag
+from ... base import ConsumerProducer
-default_ident = "graph-rag"
+module = ".".join(__name__.split(".")[1:-1])
-class Processor(FlowProcessor):
+default_input_queue = graph_rag_request_queue
+default_output_queue = graph_rag_response_queue
+default_subscriber = module
+
+class Processor(ConsumerProducer):
def __init__(self, **params):
- id = params.get("id", default_ident)
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
+
+ pr_request_queue = params.get(
+ "prompt_request_queue", prompt_request_queue
+ )
+ pr_response_queue = params.get(
+ "prompt_response_queue", prompt_response_queue
+ )
+ emb_request_queue = params.get(
+ "embeddings_request_queue", embeddings_request_queue
+ )
+ emb_response_queue = params.get(
+ "embeddings_response_queue", embeddings_response_queue
+ )
+ ge_request_queue = params.get(
+ "graph_embeddings_request_queue", graph_embeddings_request_queue
+ )
+ ge_response_queue = params.get(
+ "graph_embeddings_response_queue", graph_embeddings_response_queue
+ )
+ tpl_request_queue = params.get(
+ "triples_request_queue", triples_request_queue
+ )
+ tpl_response_queue = params.get(
+ "triples_response_queue", triples_response_queue
+ )
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
@@ -25,74 +64,49 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
- "id": id,
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": GraphRagQuery,
+ "output_schema": GraphRagResponse,
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
- "max_path_length": max_path_length,
+ "prompt_request_queue": pr_request_queue,
+ "prompt_response_queue": pr_response_queue,
+ "embeddings_request_queue": emb_request_queue,
+ "embeddings_response_queue": emb_response_queue,
+ "graph_embeddings_request_queue": ge_request_queue,
+ "graph_embeddings_response_queue": ge_response_queue,
+ "triples_request_queue": triples_request_queue,
+ "triples_response_queue": triples_response_queue,
}
)
+ self.rag = GraphRag(
+ pulsar_host=self.pulsar_host,
+ pulsar_api_key=self.pulsar_api_key,
+ pr_request_queue=pr_request_queue,
+ pr_response_queue=pr_response_queue,
+ emb_request_queue=emb_request_queue,
+ emb_response_queue=emb_response_queue,
+ ge_request_queue=ge_request_queue,
+ ge_response_queue=ge_response_queue,
+ tpl_request_queue=triples_request_queue,
+ tpl_response_queue=triples_response_queue,
+ verbose=True,
+ module=module,
+ )
+
self.default_entity_limit = entity_limit
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
- self.register_specification(
- ConsumerSpec(
- name = "request",
- schema = GraphRagQuery,
- handler = self.on_request,
- )
- )
-
- self.register_specification(
- EmbeddingsClientSpec(
- request_name = "embeddings-request",
- response_name = "embeddings-response",
- )
- )
-
- self.register_specification(
- GraphEmbeddingsClientSpec(
- request_name = "graph-embeddings-request",
- response_name = "graph-embeddings-response",
- )
- )
-
- self.register_specification(
- TriplesClientSpec(
- request_name = "triples-request",
- response_name = "triples-response",
- )
- )
-
- self.register_specification(
- PromptClientSpec(
- request_name = "prompt-request",
- response_name = "prompt-response",
- )
- )
-
- self.register_specification(
- ProducerSpec(
- name = "response",
- schema = GraphRagResponse,
- )
- )
-
- async def on_request(self, msg, consumer, flow):
+ async def handle(self, msg):
try:
- self.rag = GraphRag(
- embeddings_client = flow("embeddings-request"),
- graph_embeddings_client = flow("graph-embeddings-request"),
- triples_client = flow("triples-request"),
- prompt_client = flow("prompt-request"),
- verbose=True,
- )
-
v = msg.value()
# Sender-produced ID
@@ -120,20 +134,16 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
- response = await self.rag.query(
- query = v.query, user = v.user, collection = v.collection,
- entity_limit = entity_limit, triple_limit = triple_limit,
- max_subgraph_size = max_subgraph_size,
- max_path_length = max_path_length,
+ response = self.rag.query(
+ query=v.query, user=v.user, collection=v.collection,
+ entity_limit=entity_limit, triple_limit=triple_limit,
+ max_subgraph_size=max_subgraph_size,
+ max_path_length=max_path_length,
)
- await flow("response").send(
- GraphRagResponse(
- response = response,
- error = None
- ),
- properties = {"id": id}
- )
+ print("Send response...", flush=True)
+ r = GraphRagResponse(response=response, error=None)
+ await self.send(r, properties={"id": id})
print("Done.", flush=True)
@@ -143,21 +153,25 @@ class Processor(FlowProcessor):
print("Send error response...", flush=True)
- await flow("response").send(
- GraphRagResponse(
- response = None,
- error = Error(
- type = "graph-rag-error",
- message = str(e),
- ),
+ r = GraphRagResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
),
- properties = {"id": id}
+ response=None,
)
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
+
@staticmethod
def add_args(parser):
- FlowProcessor.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-e', '--entity-limit',
@@ -187,7 +201,55 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
+ parser.add_argument(
+ '--prompt-request-queue',
+ default=prompt_request_queue,
+ help=f'Prompt request queue (default: {prompt_request_queue})',
+ )
+
+ parser.add_argument(
+ '--prompt-response-queue',
+ default=prompt_response_queue,
+ help=f'Prompt response queue (default: {prompt_response_queue})',
+ )
+
+ parser.add_argument(
+ '--embeddings-request-queue',
+ default=embeddings_request_queue,
+ help=f'Embeddings request queue (default: {embeddings_request_queue})',
+ )
+
+ parser.add_argument(
+ '--embeddings-response-queue',
+ default=embeddings_response_queue,
+ help=f'Embeddings response queue (default: {embeddings_response_queue})',
+ )
+
+ parser.add_argument(
+ '--graph-embeddings-request-queue',
+ default=graph_embeddings_request_queue,
+ help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})',
+ )
+
+ parser.add_argument(
+ '--graph-embeddings-response-queue',
+ default=graph_embeddings_response_queue,
+ help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})',
+ )
+
+ parser.add_argument(
+ '--triples-request-queue',
+ default=triples_request_queue,
+ help=f'Triples request queue (default: {triples_request_queue})',
+ )
+
+ parser.add_argument(
+ '--triples-response-queue',
+ default=triples_response_queue,
+ help=f'Triples response queue (default: {triples_response_queue})',
+ )
+
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py
index 2949263a..b4dbc486 100755
--- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py
+++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py
@@ -10,7 +10,7 @@ from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "de-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py
index 128323aa..9e91db9a 100644
--- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py
+++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py
@@ -16,7 +16,7 @@ from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "de-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
index d65a75eb..810c1931 100644
--- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
@@ -8,21 +8,31 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
-from .... base import DocumentEmbeddingsStoreService
+from .... schema import DocumentEmbeddings
+from .... schema import document_embeddings_store_queue
+from .... log_level import LogLevel
+from .... base import Consumer
-default_ident = "de-write"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = document_embeddings_store_queue
+default_subscriber = module
default_store_uri = 'http://localhost:6333'
-class Processor(DocumentEmbeddingsStoreService):
+class Processor(Consumer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
+ "input_queue": input_queue,
+ "subscriber": subscriber,
+ "input_schema": DocumentEmbeddings,
"store_uri": store_uri,
"api_key": api_key,
}
@@ -30,11 +40,13 @@ class Processor(DocumentEmbeddingsStoreService):
self.last_collection = None
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.client = QdrantClient(url=store_uri)
- async def store_document_embeddings(self, message):
+ async def handle(self, msg):
- for emb in message.chunks:
+ v = msg.value()
+
+ for emb in v.chunks:
chunk = emb.chunk.decode("utf-8")
if chunk == "": return
@@ -43,17 +55,16 @@ class Processor(DocumentEmbeddingsStoreService):
dim = len(vec)
collection = (
- "d_" + message.metadata.user + "_" +
- message.metadata.collection + "_" +
+ "d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
if collection != self.last_collection:
- if not self.qdrant.collection_exists(collection):
+ if not self.client.collection_exists(collection):
try:
- self.qdrant.create_collection(
+ self.client.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
@@ -65,7 +76,7 @@ class Processor(DocumentEmbeddingsStoreService):
self.last_collection = collection
- self.qdrant.upsert(
+ self.client.upsert(
collection_name=collection,
points=[
PointStruct(
@@ -81,7 +92,9 @@ class Processor(DocumentEmbeddingsStoreService):
@staticmethod
def add_args(parser):
- DocumentEmbeddingsStoreService.add_args(parser)
+ Consumer.add_args(
+ parser, default_input_queue, default_subscriber,
+ )
parser.add_argument(
'-t', '--store-uri',
@@ -97,5 +110,5 @@ class Processor(DocumentEmbeddingsStoreService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py
index 8d8b68b0..b2d40306 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py
@@ -9,7 +9,7 @@ from .... log_level import LogLevel
from .... direct.milvus_graph_embeddings import EntityVectors
from .... base import Consumer
-module = "ge-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py
index 400acf26..83861b54 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py
@@ -15,7 +15,7 @@ from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "ge-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
index ecefee4f..6b0d7371 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
@@ -8,21 +8,31 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
-from .... base import GraphEmbeddingsStoreService
+from .... schema import GraphEmbeddings
+from .... schema import graph_embeddings_store_queue
+from .... log_level import LogLevel
+from .... base import Consumer
-default_ident = "ge-write"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = graph_embeddings_store_queue
+default_subscriber = module
default_store_uri = 'http://localhost:6333'
-class Processor(GraphEmbeddingsStoreService):
+class Processor(Consumer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
+ "input_queue": input_queue,
+ "subscriber": subscriber,
+ "input_schema": GraphEmbeddings,
"store_uri": store_uri,
"api_key": api_key,
}
@@ -30,7 +40,7 @@ class Processor(GraphEmbeddingsStoreService):
self.last_collection = None
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.client = QdrantClient(url=store_uri, api_key=api_key)
def get_collection(self, dim, user, collection):
@@ -40,10 +50,10 @@ class Processor(GraphEmbeddingsStoreService):
if cname != self.last_collection:
- if not self.qdrant.collection_exists(cname):
+ if not self.client.collection_exists(cname):
try:
- self.qdrant.create_collection(
+ self.client.create_collection(
collection_name=cname,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
@@ -57,9 +67,11 @@ class Processor(GraphEmbeddingsStoreService):
return cname
- async def store_graph_embeddings(self, message):
+ async def handle(self, msg):
- for entity in message.entities:
+ v = msg.value()
+
+ for entity in v.entities:
if entity.entity.value == "" or entity.entity.value is None: return
@@ -68,10 +80,10 @@ class Processor(GraphEmbeddingsStoreService):
dim = len(vec)
collection = self.get_collection(
- dim, message.metadata.user, message.metadata.collection
+ dim, v.metadata.user, v.metadata.collection
)
- self.qdrant.upsert(
+ self.client.upsert(
collection_name=collection,
points=[
PointStruct(
@@ -87,7 +99,9 @@ class Processor(GraphEmbeddingsStoreService):
@staticmethod
def add_args(parser):
- GraphEmbeddingsStoreService.add_args(parser)
+ Consumer.add_args(
+ parser, default_input_queue, default_subscriber,
+ )
parser.add_argument(
'-t', '--store-uri',
@@ -103,5 +117,5 @@ class Processor(GraphEmbeddingsStoreService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py
index d1ad139a..5490af97 100755
--- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py
+++ b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py
@@ -9,7 +9,7 @@ from .... log_level import LogLevel
from .... direct.milvus_object_embeddings import ObjectVectors
from .... base import Consumer
-module = "oe-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = object_embeddings_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
index a84aefde..e6536e6c 100755
--- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
@@ -17,7 +17,7 @@ from .... schema import rows_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "rows-write"
+module = ".".join(__name__.split(".")[1:-1])
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
default_input_queue = rows_store_queue
diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
index f8396692..17b5ae9a 100755
--- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
@@ -10,26 +10,35 @@ import argparse
import time
from .... direct.cassandra import TrustGraph
-from .... base import TriplesStoreService
+from .... schema import Triples
+from .... schema import triples_store_queue
+from .... log_level import LogLevel
+from .... base import Consumer
-default_ident = "triples-write"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = triples_store_queue
+default_subscriber = module
default_graph_host='localhost'
-class Processor(TriplesStoreService):
+class Processor(Consumer):
def __init__(self, **params):
- id = params.get("id", default_ident)
-
+ input_queue = params.get("input_queue", default_input_queue)
+ subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
super(Processor, self).__init__(
**params | {
+ "input_queue": input_queue,
+ "subscriber": subscriber,
+ "input_schema": Triples,
"graph_host": graph_host,
- "graph_username": graph_username
+ "graph_username": graph_username,
+ "graph_password": graph_password,
}
)
@@ -38,9 +47,11 @@ class Processor(TriplesStoreService):
self.password = graph_password
self.table = None
- async def store_triples(self, message):
+ async def handle(self, msg):
- table = (message.metadata.user, message.metadata.collection)
+ v = msg.value()
+
+ table = (v.metadata.user, v.metadata.collection)
if self.table is None or self.table != table:
@@ -50,15 +61,13 @@ class Processor(TriplesStoreService):
if self.username and self.password:
self.tg = TrustGraph(
hosts=self.graph_host,
- keyspace=message.metadata.user,
- table=message.metadata.collection,
+ keyspace=v.metadata.user, table=v.metadata.collection,
username=self.username, password=self.password
)
else:
self.tg = TrustGraph(
hosts=self.graph_host,
- keyspace=message.metadata.user,
- table=message.metadata.collection,
+ keyspace=v.metadata.user, table=v.metadata.collection,
)
except Exception as e:
print("Exception", e, flush=True)
@@ -67,7 +76,7 @@ class Processor(TriplesStoreService):
self.table = table
- for t in message.triples:
+ for t in v.triples:
self.tg.insert(
t.s.value,
t.p.value,
@@ -77,7 +86,9 @@ class Processor(TriplesStoreService):
@staticmethod
def add_args(parser):
- TriplesStoreService.add_args(parser)
+ Consumer.add_args(
+ parser, default_input_queue, default_subscriber,
+ )
parser.add_argument(
'-g', '--graph-host',
@@ -99,5 +110,5 @@ class Processor(TriplesStoreService):
def run():
- Processor.launch(default_ident, __doc__)
+ Processor.launch(module, __doc__)
diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py
index b3996b91..2d0ae38a 100755
--- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py
+++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py
@@ -16,7 +16,7 @@ from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "triples-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py
index 8c88ea8f..620e669e 100755
--- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py
+++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py
@@ -16,7 +16,7 @@ from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "triples-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module
diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py
index 84a4d923..3323f912 100755
--- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py
+++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py
@@ -16,7 +16,7 @@ from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
-module = "triples-write"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module
diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
index 5fa436b8..f8926589 100755
--- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
+++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
@@ -14,7 +14,7 @@ from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
-module = "ocr"
+module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index 3594b76d..4d38c8c0 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -4,30 +4,50 @@ Simple LLM service, performs text prompt completion using VertexAI on
Google Cloud. Input is prompt, output is response.
"""
+import vertexai
+import time
+from prometheus_client import Histogram
+import os
+
from google.oauth2 import service_account
import google
-import vertexai
from vertexai.preview.generative_models import (
- Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
- HarmCategory, HarmBlockThreshold, Part, Tool,
+ Content,
+ FunctionDeclaration,
+ GenerativeModel,
+ GenerationConfig,
+ HarmCategory,
+ HarmBlockThreshold,
+ Part,
+ Tool,
)
+from .... schema import TextCompletionRequest, TextCompletionResponse, Error
+from .... schema import text_completion_request_queue
+from .... schema import text_completion_response_queue
+from .... log_level import LogLevel
+from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
-from .... base import LlmService, LlmResult
-default_ident = "text-completion"
+module = ".".join(__name__.split(".")[1:-1])
+default_input_queue = text_completion_request_queue
+default_output_queue = text_completion_response_queue
+default_subscriber = module
default_model = 'gemini-1.0-pro-001'
default_region = 'us-central1'
default_temperature = 0.0
default_max_output = 8192
default_private_key = "private.json"
-class Processor(LlmService):
+class Processor(ConsumerProducer):
def __init__(self, **params):
+ input_queue = params.get("input_queue", default_input_queue)
+ output_queue = params.get("output_queue", default_output_queue)
+ subscriber = params.get("subscriber", default_subscriber)
region = params.get("region", default_region)
model = params.get("model", default_model)
private_key = params.get("private_key", default_private_key)
@@ -37,7 +57,28 @@ class Processor(LlmService):
if private_key is None:
raise RuntimeError("Private key file not specified")
- super(Processor, self).__init__(**params)
+ super(Processor, self).__init__(
+ **params | {
+ "input_queue": input_queue,
+ "output_queue": output_queue,
+ "subscriber": subscriber,
+ "input_schema": TextCompletionRequest,
+ "output_schema": TextCompletionResponse,
+ }
+ )
+
+ if not hasattr(__class__, "text_completion_metric"):
+ __class__.text_completion_metric = Histogram(
+ 'text_completion_duration',
+ 'Text completion duration (seconds)',
+ buckets=[
+ 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
+ 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
+ 120.0
+ ]
+ )
self.parameters = {
"temperature": temperature,
@@ -69,11 +110,7 @@ class Processor(LlmService):
print("Initialise VertexAI...", flush=True)
if private_key:
- credentials = (
- service_account.Credentials.from_service_account_file(
- private_key
- )
- )
+ credentials = service_account.Credentials.from_service_account_file(private_key)
else:
credentials = None
@@ -94,29 +131,50 @@ class Processor(LlmService):
print("Initialisation complete", flush=True)
- async def generate_content(self, system, prompt):
+ async def handle(self, msg):
try:
- prompt = system + "\n\n" + prompt
+ v = msg.value()
- response = self.llm.generate_content(
- prompt, generation_config=self.generation_config,
- safety_settings=self.safety_settings
- )
+ # Sender-produced ID
- resp = LlmResult()
- resp.text = response.text
- resp.in_token = response.usage_metadata.prompt_token_count
- resp.out_token = response.usage_metadata.candidates_token_count
- resp.model = self.model
+ id = msg.properties()["id"]
- print(f"Input Tokens: {resp.in_token}", flush=True)
- print(f"Output Tokens: {resp.out_token}", flush=True)
+ print(f"Handling prompt {id}...", flush=True)
+
+ prompt = v.system + "\n\n" + v.prompt
+
+ with __class__.text_completion_metric.time():
+
+ response = self.llm.generate_content(
+ prompt, generation_config=self.generation_config,
+ safety_settings=self.safety_settings
+ )
+
+ resp = response.text
+ inputtokens = int(response.usage_metadata.prompt_token_count)
+ outputtokens = int(response.usage_metadata.candidates_token_count)
+ print(resp, flush=True)
+ print(f"Input Tokens: {inputtokens}", flush=True)
+ print(f"Output Tokens: {outputtokens}", flush=True)
print("Send response...", flush=True)
- return resp
+ r = TextCompletionResponse(
+ error=None,
+ response=resp,
+ in_token=inputtokens,
+ out_token=outputtokens,
+ model=self.model
+ )
+
+ await self.send(r, properties={"id": id})
+
+ print("Done.", flush=True)
+
+ # Acknowledge successful processing of the message
+ self.consumer.acknowledge(msg)
except google.api_core.exceptions.ResourceExhausted as e:
@@ -128,19 +186,40 @@ class Processor(LlmService):
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
+
print(f"Exception: {e}")
- raise e
+
+ print("Send error response...", flush=True)
+
+ r = TextCompletionResponse(
+ error=Error(
+ type = "llm-error",
+ message = str(e),
+ ),
+ response=None,
+ in_token=None,
+ out_token=None,
+ model=None,
+ )
+
+ await self.send(r, properties={"id": id})
+
+ self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
- LlmService.add_args(parser)
+ ConsumerProducer.add_args(
+ parser, default_input_queue, default_subscriber,
+ default_output_queue,
+ )
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'LLM model (default: {default_model})'
)
+ # Also: text-bison-32k
parser.add_argument(
'-k', '--private-key',
@@ -168,5 +247,6 @@ class Processor(LlmService):
)
def run():
- Processor.launch(default_ident, __doc__)
+
+ Processor.launch(module, __doc__)