Feature/configure flows (#345)

- Keeps processing in different flows separate, so that data can go to different stores, collections, etc.
- Potentially supports different processing flows
- Tidies the processing API with common base classes (e.g. for LLM services), and automatic configuration of clients to use the right queue names in a flow
cybermaggedon 2025-04-22 20:21:38 +01:00 committed by GitHub
parent a06a814a41
commit a9197d11ee
125 changed files with 3751 additions and 2628 deletions
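
For illustration of the new shape, a processor now declares its queues as named specifications, and per-flow configuration binds those logical names to concrete Pulsar topics. A minimal sketch using the base classes added in this change (the service name and handler body are invented):

from trustgraph.base import FlowProcessor, ConsumerSpec, ProducerSpec
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse

class ExampleService(FlowProcessor):

    def __init__(self, **params):
        super(ExampleService, self).__init__(**params)
        # Logical names only; the topics are assigned per flow by the
        # configuration service
        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = EmbeddingsRequest,
                handler = self.on_request,
            )
        )
        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = EmbeddingsResponse,
            )
        )

    async def on_request(self, msg, consumer, flow):
        id = msg.properties()["id"]
        # flow("response") resolves to this flow's response producer
        await flow("response").send(
            EmbeddingsResponse(error=None, vectors=[[0.0]]),
            properties={"id": id},
        )

def run():
    ExampleService.launch("example-service", __doc__)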

View file

@@ -60,6 +60,14 @@ container: update-package-versions
${DOCKER} build -f containers/Containerfile.ocr \
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
some-containers:
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.vertexai \
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
basic-containers: update-package-versions
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .

View file

@@ -20,7 +20,11 @@ def output(text, prefix="> ", width=78):
)
print(out)
p = AgentClient(pulsar_host="pulsar://localhost:6650")
p = AgentClient(
pulsar_host="pulsar://pulsar:6650",
input_queue = "non-persistent://tg/request/agent:0000",
output_queue = "non-persistent://tg/response/agent:0000",
)
q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese."

View file

@@ -3,7 +3,12 @@
import pulsar
from trustgraph.clients.document_rag_client import DocumentRagClient
rag = DocumentRagClient(pulsar_host="pulsar://localhost:6650")
rag = DocumentRagClient(
pulsar_host="pulsar://localhost:6650",
subscriber="test1",
input_queue = "non-persistent://tg/request/document-rag:default",
output_queue = "non-persistent://tg/response/document-rag:default",
)
query="""
What was the cause of the space shuttle disaster?"""

View file

@@ -3,7 +3,12 @@
import pulsar
from trustgraph.clients.embeddings_client import EmbeddingsClient
embed = EmbeddingsClient(pulsar_host="pulsar://localhost:6650")
embed = EmbeddingsClient(
pulsar_host="pulsar://pulsar:6650",
input_queue="non-persistent://tg/request/embeddings:default",
output_queue="non-persistent://tg/response/embeddings:default",
subscriber="test1",
)
prompt="Write a funny limerick about a llama"
@@ -11,5 +16,3 @@ resp = embed.request(prompt)
print(resp)

View file

@@ -3,11 +3,18 @@
import pulsar
from trustgraph.clients.graph_rag_client import GraphRagClient
rag = GraphRagClient(pulsar_host="pulsar://localhost:6650")
rag = GraphRagClient(
pulsar_host="pulsar://localhost:6650",
subscriber="test1",
input_queue = "non-persistent://tg/request/graph-rag:default",
output_queue = "non-persistent://tg/response/graph-rag:default",
)
query="""
This knowledge graph describes the Space Shuttle disaster.
Present 20 facts which are present in the knowledge graph."""
#query="""
#This knowledge graph describes the Space Shuttle disaster.
#Present 20 facts which are present in the knowledge graph."""
query = "How many cats does Mark have?"
resp = rag.request(query)

View file

@@ -3,14 +3,17 @@
import pulsar
from trustgraph.clients.llm_client import LlmClient
llm = LlmClient(pulsar_host="pulsar://localhost:6650")
llm = LlmClient(
pulsar_host="pulsar://pulsar:6650",
input_queue="non-persistent://tg/request/text-completion:default",
output_queue="non-persistent://tg/response/text-completion:default",
subscriber="test1",
)
system = "You are a lovely assistant."
prompt="Write a funny limerick about a llama"
prompt="what is 2 + 2 == 5"
resp = llm.request(system, prompt)
print(resp)

View file

@@ -3,7 +3,12 @@
import json
from trustgraph.clients.prompt_client import PromptClient
p = PromptClient(pulsar_host="pulsar://localhost:6650")
p = PromptClient(
pulsar_host="pulsar://localhost:6650",
input_queue="non-persistent://tg/request/prompt:default",
output_queue="non-persistent://tg/response/prompt:default",
subscriber="test1",
)
chunk="""
The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including:
@@ -31,8 +36,8 @@ The Space Shuttle's last mission was in 2011.
q = "Tell me some facts in the knowledge graph"
resp = p.request(
id="extract-definition",
terms = {
id="extract-definitions",
variables = {
"text": chunk,
}
)
@@ -40,7 +45,7 @@ resp = p.request(
print(resp)
for fact in resp:
print(fact["term"], "::")
print(fact["entity"], "::")
print(fact["definition"])
print()

View file

@@ -3,13 +3,18 @@
import pulsar
from trustgraph.clients.prompt_client import PromptClient
p = PromptClient(pulsar_host="pulsar://localhost:6650")
p = PromptClient(
pulsar_host="pulsar://localhost:6650",
input_queue="non-persistent://tg/request/prompt:default",
output_queue="non-persistent://tg/response/prompt:default",
subscriber="test1",
)
question = """What is the square root of 16?"""
resp = p.request(
id="question",
terms = {
variables = {
"question": question
}
)

View file

@@ -3,7 +3,9 @@
import pulsar
from trustgraph.clients.triples_query_client import TriplesQueryClient
tq = TriplesQueryClient(pulsar_host="pulsar://localhost:6650")
tq = TriplesQueryClient(
pulsar_host="pulsar://localhost:6650",
)
e = "http://trustgraph.ai/e/shuttle"

View file

@@ -1,8 +1,31 @@
from . base_processor import BaseProcessor
from . pubsub import PulsarClient
from . async_processor import AsyncProcessor
from . consumer import Consumer
from . producer import Producer
from . consumer_producer import ConsumerProducer
from . publisher import Publisher
from . subscriber import Subscriber
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . setting_spec import SettingSpec
from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec
from . prompt_client import PromptClientSpec
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService
from . triples_query_service import TriplesQueryService
from . graph_embeddings_query_service import GraphEmbeddingsQueryService
from . document_embeddings_query_service import DocumentEmbeddingsQueryService
from . graph_embeddings_client import GraphEmbeddingsClientSpec
from . triples_client import TriplesClientSpec
from . document_embeddings_client import DocumentEmbeddingsClientSpec
from . agent_service import AgentService
from . graph_rag_client import GraphRagClientSpec

View file

@@ -0,0 +1,39 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import AgentRequest, AgentResponse
from .. knowledge import Uri, Literal
class AgentClient(RequestResponse):
async def request(self, recipient, question, plan=None, state=None,
history=[], timeout=300):
resp = await super(AgentClient, self).request(
AgentRequest(
question = question,
plan = plan,
state = state,
history = history,
),
recipient=recipient,
timeout=timeout,
)
print(resp, flush=True)
if resp.error:
raise RuntimeError(resp.error.message)
return resp
class AgentClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(AgentClientSpec, self).__init__(
request_name = request_name,
request_schema = AgentRequest,
response_name = response_name,
response_schema = AgentResponse,
impl = AgentClient,
)

View file

@@ -0,0 +1,100 @@
"""
Agent manager service completion base class
"""
import time
from prometheus_client import Histogram
from .. schema import AgentRequest, AgentResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "agent-manager"
class AgentService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(AgentService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = AgentRequest,
handler = self.on_request
)
)
self.register_specification(
ProducerSpec(
name = "next",
schema = AgentRequest
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = AgentResponse
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
async def respond(resp):
await flow("response").send(
resp,
properties={"id": id}
)
async def next(resp):
await flow("next").send(
resp,
properties={"id": id}
)
await self.agent_request(
request = request, respond = respond, next = next,
flow = flow
)
except TooManyRequests as e:
raise e
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
print(f"on_request Exception: {e}")
print("Send error response...", flush=True)
await flow.producer["response"].send(
AgentResponse(
error=Error(
type = "agent-error",
message = str(e),
),
thought = None,
observation = None,
answer = None,
),
properties={"id": id}
)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@@ -0,0 +1,254 @@
# Base class for processors. Implements:
# - Pulsar client, subscribe and consume basic
# - the async startup logic
# - Initialising metrics
import asyncio
import argparse
import _pulsar
import time
import uuid
from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
from .. exceptions import TooManyRequests
from . pubsub import PulsarClient
from . producer import Producer
from . consumer import Consumer
from . metrics import ProcessorMetrics
default_config_queue = config_push_queue
# Async processor
class AsyncProcessor:
def __init__(self, **params):
# Store the identity
self.id = params.get("id")
# Register a pulsar client
self.pulsar_client = PulsarClient(**params)
# Initialise metrics, records the parameters
ProcessorMetrics(id=self.id).info({
k: str(params[k])
for k in params
if k != "id"
})
# The processor runs all activity in a taskgroup; it's mandatory
# that this is provided
self.taskgroup = params.get("taskgroup")
if self.taskgroup is None:
raise RuntimeError("Essential taskgroup missing")
# Get the configuration topic
self.config_push_queue = params.get(
"config_push_queue", default_config_queue
)
# This records registered configuration handlers
self.config_handlers = []
# Create a random ID for this subscription to the configuration
# service
config_subscriber_id = str(uuid.uuid4())
# Subscribe to config queue
self.config_sub_task = Consumer(
taskgroup = self.taskgroup,
client = self.client,
subscriber = config_subscriber_id,
flow = None,
topic = self.config_push_queue,
schema = ConfigPush,
handler = self.on_config_change,
# This causes new subscriptions to view the entire history of
# configuration
start_of_messages = True
)
self.running = True
# This is called to start dynamic behaviour. An over-ride point for
# extra functionality
async def start(self):
await self.config_sub_task.start()
# This is called to stop all threads. An over-ride point for extra
# functionality
def stop(self):
self.client.close()
self.running = False
# Returns the pulsar host
@property
def pulsar_host(self): return self.client.pulsar_host
# Returns the pulsar client
@property
def client(self): return self.pulsar_client.client
# Register a new event handler for configuration change
def register_config_handler(self, handler):
self.config_handlers.append(handler)
# Called when a new configuration message push occurs
async def on_config_change(self, message, consumer):
# Get configuration data and version number
config = message.value().config
version = message.value().version
# Acknowledge the message
consumer.acknowledge(message)
# Invoke message handlers
print("Config change event", config, version, flush=True)
for ch in self.config_handlers:
await ch(config, version)
# This is the 'main' body of the handler. It is a point to override
# if needed. By default does nothing. Processors are implemented
# by adding consumer/producer functionality so maybe nothing is needed
# in the run() body
async def run(self):
while self.running:
await asyncio.sleep(2)
# Startup fabric. This runs in 'async' mode, creates a taskgroup and
# runs the producer.
@classmethod
async def launch_async(cls, args):
try:
# Create a taskgroup. This seems complicated: when an exception
# occurs unhandled, it looks like it cancels all tasks in the
# taskgroup, so the exception needs to be caught in the right
# place.
async with asyncio.TaskGroup() as tg:
# Create a processor instance, and include the taskgroup
# as a parameter. The processor identity is used as:
# - the subscriber name
# - an identifier for flow configuration
p = cls(**args | { "taskgroup": tg })
# Start the processor
await p.start()
# Run the processor
task = tg.create_task(p.run())
# The taskgroup causes everything to wait until
# all threads have stopped
# This is here to output a debug message, shouldn't be needed.
except Exception as e:
print("Exception, closing taskgroup", flush=True)
raise e
# Startup fabric. launch calls launch_async in async mode.
@classmethod
def launch(cls, ident, doc):
# Start assembling CLI arguments
parser = argparse.ArgumentParser(
prog=ident,
description=doc
)
parser.add_argument(
'--id',
default=ident,
help=f'Configuration identity (default: {ident})',
)
# Invoke the class-specific add_args, which manages adding all the
# command-line arguments
cls.add_args(parser)
# Parse arguments
args = parser.parse_args()
args = vars(args)
# Debug
print(args, flush=True)
# Start the Prometheus metrics service if needed
if args["metrics"]:
start_http_server(args["metrics_port"])
# Loop forever, exception handler
while True:
print("Starting...", flush=True)
try:
# Launch the processor in an asyncio handler
asyncio.run(cls.launch_async(
args
))
except KeyboardInterrupt:
print("Keyboard interrupt.", flush=True)
return
except _pulsar.Interrupted:
print("Pulsar Interrupted.", flush=True)
return
# Exceptions from a taskgroup come in as an exception group
except ExceptionGroup as e:
print("Exception group:", flush=True)
for se in e.exceptions:
print(" Type:", type(se), flush=True)
print(f" Exception: {se}", flush=True)
except Exception as e:
print("Type:", type(e), flush=True)
print("Exception:", e, flush=True)
# Retry occurs here
print("Will retry...", flush=True)
time.sleep(4)
print("Retrying...", flush=True)
# The command-line arguments are built using a stack of add_args
# invocations
@staticmethod
def add_args(parser):
PulsarClient.add_args(parser)
parser.add_argument(
'--config-push-queue',
default=default_config_queue,
help=f'Config push queue (default: {default_config_queue})',
)
parser.add_argument(
'--metrics',
action=argparse.BooleanOptionalAction,
default=True,
help=f'Metrics enabled (default: true)',
)
parser.add_argument(
'-P', '--metrics-port',
type=int,
default=8000,
help=f'Metrics port (default: 8000)',
)
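
For completeness, a sketch of how a concrete processor is started through this fabric (the class and identity are invented for illustration):

from trustgraph.base import AsyncProcessor

class HeartbeatProcessor(AsyncProcessor):
    # Relies on the default run() loop above; a real processor would
    # register consumers/producers or override run()
    pass

def run():
    # Builds the CLI (--id, Pulsar options, metrics options), then
    # supervises launch_async with the retry loop above
    HeartbeatProcessor.launch("heartbeat", __doc__)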

View file

@@ -1,210 +0,0 @@
import asyncio
import os
import argparse
import pulsar
from pulsar.schema import JsonSchema
import _pulsar
import time
import uuid
from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
default_config_queue = config_push_queue
config_subscriber_id = str(uuid.uuid4())
class BaseProcessor:
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
def __init__(self, **params):
self.client = None
if not hasattr(__class__, "params_metric"):
__class__.params_metric = Info(
'params', 'Parameters configuration'
)
# FIXME: Maybe outputs information it should not
__class__.params_metric.info({
k: str(params[k])
for k in params
})
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
pulsar_listener = params.get("pulsar_listener", None)
pulsar_api_key = params.get("pulsar_api_key", None)
log_level = params.get("log_level", LogLevel.INFO)
self.config_push_queue = params.get(
"config_push_queue",
default_config_queue
)
self.pulsar_host = pulsar_host
self.pulsar_api_key = pulsar_api_key
if pulsar_api_key:
auth = pulsar.AuthenticationToken(pulsar_api_key)
self.client = pulsar.Client(
pulsar_host,
authentication=auth,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
else:
self.client = pulsar.Client(
pulsar_host,
listener_name=pulsar_listener,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.pulsar_listener = pulsar_listener
self.config_subscriber = self.client.subscribe(
self.config_push_queue, config_subscriber_id,
consumer_type=pulsar.ConsumerType.Shared,
initial_position=pulsar.InitialPosition.Earliest,
schema=JsonSchema(ConfigPush),
)
def __del__(self):
if hasattr(self, "client"):
if self.client:
self.client.close()
@staticmethod
def add_args(parser):
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=__class__.default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'--config-push-queue',
default=default_config_queue,
help=f'Config push queue {default_config_queue}',
)
parser.add_argument(
'--pulsar-listener',
help=f'Pulsar listener (default: none)',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'--metrics',
action=argparse.BooleanOptionalAction,
default=True,
help=f'Metrics enabled (default: true)',
)
parser.add_argument(
'-P', '--metrics-port',
type=int,
default=8000,
help=f'Pulsar host (default: 8000)',
)
async def start(self):
pass
async def run_config_queue(self):
if self.module == "config.service":
print("I am config-svc, not looking at config queue", flush=True)
return
print("Config thread running", flush=True)
while True:
try:
msg = await asyncio.to_thread(
self.config_subscriber.receive, timeout_millis=2000
)
except pulsar.Timeout:
continue
v = msg.value()
print("Got config version", v.version, flush=True)
await self.on_config(v.version, v.config)
async def on_config(self, version, config):
pass
async def run(self):
raise RuntimeError("Something should have implemented the run method")
@classmethod
async def launch_async(cls, args, prog):
p = cls(**args)
p.module = prog
await p.start()
task1 = asyncio.create_task(p.run_config_queue())
task2 = asyncio.create_task(p.run())
await asyncio.gather(task1, task2)
@classmethod
def launch(cls, prog, doc):
parser = argparse.ArgumentParser(
prog=prog,
description=doc
)
cls.add_args(parser)
args = parser.parse_args()
args = vars(args)
print(args)
if args["metrics"]:
start_http_server(args["metrics_port"])
while True:
try:
asyncio.run(cls.launch_async(args, prog))
except KeyboardInterrupt:
print("Keyboard interrupt.")
return
except _pulsar.Interrupted:
print("Pulsar Interrupted.")
return
except Exception as e:
print(type(e))
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(4)

View file

@@ -1,93 +1,136 @@
import asyncio
from pulsar.schema import JsonSchema
import pulsar
from prometheus_client import Histogram, Info, Counter, Enum
import _pulsar
import asyncio
import time
from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests
default_rate_limit_retry = 10
default_rate_limit_timeout = 7200
class Consumer:
class Consumer(BaseProcessor):
def __init__(
self, taskgroup, flow, client, topic, subscriber, schema,
handler,
metrics = None,
start_of_messages=False,
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
reconnect_time = 5,
):
def __init__(self, **params):
self.taskgroup = taskgroup
self.flow = flow
self.client = client
self.topic = topic
self.subscriber = subscriber
self.schema = schema
self.handler = handler
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'processor_state', 'Processor state',
states=['starting', 'running', 'stopped']
)
__class__.state_metric.state('starting')
self.rate_limit_retry_time = rate_limit_retry_time
self.rate_limit_timeout = rate_limit_timeout
__class__.state_metric.state('starting')
self.reconnect_time = 5
super(Consumer, self).__init__(**params)
self.start_of_messages = start_of_messages
self.input_queue = params.get("input_queue")
self.subscriber = params.get("subscriber")
self.input_schema = params.get("input_schema")
self.running = True
self.task = None
self.rate_limit_retry = params.get(
"rate_limit_retry", default_rate_limit_retry
)
self.rate_limit_timeout = params.get(
"rate_limit_timeout", default_rate_limit_timeout
)
self.metrics = metrics
if self.input_schema == None:
raise RuntimeError("input_schema must be specified")
self.consumer = None
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)'
)
def __del__(self):
self.running = False
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
if hasattr(self, "consumer"):
if self.consumer:
self.consumer.close()
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
)
async def stop(self):
if not hasattr(__class__, "rate_limit_metric"):
__class__.rate_limit_metric = Counter(
'rate_limit_count', 'Rate limit event count',
)
self.running = False
await self.task
__class__.pubsub_metric.info({
"input_queue": self.input_queue,
"subscriber": self.subscriber,
"input_schema": self.input_schema.__name__,
"rate_limit_retry": str(self.rate_limit_retry),
"rate_limit_timeout": str(self.rate_limit_timeout),
})
async def start(self):
self.consumer = self.client.subscribe(
self.input_queue, self.subscriber,
consumer_type=pulsar.ConsumerType.Shared,
schema=JsonSchema(self.input_schema),
)
self.running = True
print("Initialised consumer.", flush=True)
# Put it in the stopped state; the run task will set it running
if self.metrics:
self.metrics.state("stopped")
self.task = self.taskgroup.create_task(self.run())
async def run(self):
__class__.state_metric.state('running')
while self.running:
while True:
if self.metrics:
self.metrics.state("stopped")
msg = await asyncio.to_thread(self.consumer.receive)
try:
print(self.topic, "subscribing...", flush=True)
if self.start_of_messages:
pos = pulsar.InitialPosition.Earliest
else:
pos = pulsar.InitialPosition.Latest
self.consumer = await asyncio.to_thread(
self.client.subscribe,
topic = self.topic,
subscription_name = self.subscriber,
schema = JsonSchema(self.schema),
initial_position = pos,
consumer_type = pulsar.ConsumerType.Shared,
)
except Exception as e:
print("consumer subs Exception:", e, flush=True)
await asyncio.sleep(self.reconnect_time)
continue
print(self.topic, "subscribed", flush=True)
if self.metrics:
self.metrics.state("running")
try:
await self.consume()
if self.metrics:
self.metrics.state("stopped")
except Exception as e:
print("consumer loop exception:", e, flush=True)
self.consumer.close()
self.consumer = None
await asyncio.sleep(self.reconnect_time)
continue
async def consume(self):
while self.running:
try:
msg = await asyncio.to_thread(
self.consumer.receive,
timeout_millis=2000
)
except _pulsar.Timeout:
continue
except Exception as e:
raise e
expiry = time.time() + self.rate_limit_timeout
# This loop is for retry on rate-limit / resource limits
while True:
while self.running:
if time.time() > expiry:
@@ -97,20 +140,31 @@ class Consumer(BaseProcessor):
# be retried
self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc()
if self.metrics:
self.metrics.process("error")
# Break out of retry loop, processes next message
break
try:
with __class__.request_metric.time():
await self.handle(msg)
print("Handle...", flush=True)
if self.metrics:
with self.metrics.record_time():
await self.handler(msg, self, self.flow)
else:
await self.handler(msg, self.consumer)
print("Handled.", flush=True)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
__class__.processing_metric.labels(status="success").inc()
if self.metrics:
self.metrics.process("success")
# Break out of retry loop
break
@@ -119,55 +173,25 @@ class Consumer(BaseProcessor):
print("TooManyRequests: will retry...", flush=True)
__class__.rate_limit_metric.inc()
if self.metrics:
self.metrics.rate_limit()
# Sleep
time.sleep(self.rate_limit_retry)
await asyncio.sleep(self.rate_limit_retry_time)
# Continue from retry loop; just causes a reprocessing
continue
except Exception as e:
print("Exception:", e, flush=True)
print("consume exception:", e, flush=True)
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc()
if self.metrics:
self.metrics.process("error")
# Break out of retry loop, processes next message
break
@staticmethod
def add_args(parser, default_input_queue, default_subscriber):
BaseProcessor.add_args(parser)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)
parser.add_argument(
'--rate-limit-retry',
type=int,
default=default_rate_limit_retry,
help=f'Rate limit retry (default: {default_rate_limit_retry})'
)
parser.add_argument(
'--rate-limit-timeout',
type=int,
default=default_rate_limit_timeout,
help=f'Rate limit timeout (default: {default_rate_limit_timeout})'
)
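
To make the retry contract concrete: a handler signals a retryable condition by raising TooManyRequests, which the consumer retries after rate_limit_retry_time; any other exception negatively acknowledges the message. A hedged sketch (the upstream call and its error type are invented):

import asyncio
from trustgraph.exceptions import TooManyRequests

class UpstreamRateLimit(Exception):
    # Hypothetical error raised by some rate-limited upstream client
    pass

async def call_upstream(request):
    # Hypothetical stand-in for the real upstream call
    await asyncio.sleep(0)
    return request

async def on_request(msg, consumer, flow):
    try:
        result = await call_upstream(msg.value())
    except UpstreamRateLimit:
        # The Consumer sleeps rate_limit_retry_time and re-runs this
        # handler, until rate_limit_timeout seconds have passed for
        # the message
        raise TooManyRequests()
    await flow("response").send(result, properties=msg.properties())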

View file

@@ -1,62 +0,0 @@
from pulsar.schema import JsonSchema
import pulsar
from prometheus_client import Histogram, Info, Counter, Enum
import time
from . consumer import Consumer
from .. exceptions import TooManyRequests
class ConsumerProducer(Consumer):
def __init__(self, **params):
super(ConsumerProducer, self).__init__(**params)
self.output_queue = params.get("output_queue")
self.output_schema = params.get("output_schema")
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
__class__.pubsub_metric.info({
"input_queue": self.input_queue,
"output_queue": self.output_queue,
"subscriber": self.subscriber,
"input_schema": self.input_schema.__name__,
"output_schema": self.output_schema.__name__,
"rate_limit_retry": str(self.rate_limit_retry),
"rate_limit_timeout": str(self.rate_limit_timeout),
})
if self.output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=self.output_queue,
schema=JsonSchema(self.output_schema),
chunking_enabled=True,
)
print("Initialised consumer/producer.")
async def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
Consumer.add_args(parser, default_input_queue, default_subscriber)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@@ -0,0 +1,36 @@
from . metrics import ConsumerMetrics
from . consumer import Consumer
from . spec import Spec
class ConsumerSpec(Spec):
def __init__(self, name, schema, handler):
self.name = name
self.schema = schema
self.handler = handler
def add(self, flow, processor, definition):
consumer_metrics = ConsumerMetrics(
flow.id, f"{flow.name}-{self.name}"
)
consumer = Consumer(
taskgroup = processor.taskgroup,
flow = flow,
client = processor.client,
topic = definition[self.name],
subscriber = processor.id + "--" + self.name,
schema = self.schema,
handler = self.handler,
metrics = consumer_metrics,
)
# Consumer handle gets access to producers and other
# metadata
consumer.id = flow.id
consumer.name = self.name
consumer.flow = flow
flow.consumer[self.name] = consumer

View file

@@ -0,0 +1,38 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. knowledge import Uri, Literal
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vectors = vectors,
limit = limit,
user = user,
collection = collection
),
timeout=timeout
)
print(resp, flush=True)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.documents
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(DocumentEmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = DocumentEmbeddingsRequest,
response_name = response_name,
response_schema = DocumentEmbeddingsResponse,
impl = DocumentEmbeddingsClient,
)

View file

@@ -0,0 +1,84 @@
"""
Document embeddings query service. Input is vectors. Output is list of
embeddings.
"""
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. schema import Error, Value
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . producer_spec import ProducerSpec
default_ident = "ge-query"
class DocumentEmbeddingsQueryService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(DocumentEmbeddingsQueryService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = DocumentEmbeddingsRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = DocumentEmbeddingsResponse,
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
docs = await self.query_document_embeddings(request)
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=docs, error=None)
await flow("response").send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "document-embeddings-query-error",
message = str(e),
),
documents=None,
)
await flow("response").send(r, properties={"id": id})
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
def run():
DocumentEmbeddingsQueryService.launch(default_ident, __doc__)

View file

@@ -0,0 +1,50 @@
"""
Document embeddings store base class
"""
from .. schema import DocumentEmbeddings
from .. base import FlowProcessor, ConsumerSpec
from .. exceptions import TooManyRequests
default_ident = "document-embeddings-write"
class DocumentEmbeddingsStoreService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(DocumentEmbeddingsStoreService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = DocumentEmbeddings,
handler = self.on_message
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
await self.store_document_embeddings(request)
except TooManyRequests as e:
raise e
except Exception as e:
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
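
A store back-end subclasses this and supplies store_document_embeddings; a minimal in-memory sketch (the storage itself is invented):

from trustgraph.base import DocumentEmbeddingsStoreService

class MemoryDocEmbeddingsStore(DocumentEmbeddingsStoreService):

    def __init__(self, **params):
        super(MemoryDocEmbeddingsStore, self).__init__(**params)
        self.rows = []

    async def store_document_embeddings(self, message):
        # message is a DocumentEmbeddings value from the "input" queue
        self.rows.append(message)

def run():
    MemoryDocEmbeddingsStore.launch("document-embeddings-write", __doc__)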

View file

@@ -0,0 +1,31 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import EmbeddingsRequest, EmbeddingsResponse
class EmbeddingsClient(RequestResponse):
async def embed(self, text, timeout=30):
resp = await self.request(
EmbeddingsRequest(
text = text
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.vectors
class EmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(EmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = EmbeddingsRequest,
response_name = response_name,
response_schema = EmbeddingsResponse,
impl = EmbeddingsClient,
)

View file

@@ -0,0 +1,90 @@
"""
Embeddings resolution base class
"""
import time
from prometheus_client import Histogram
from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "embeddings"
class EmbeddingsService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(EmbeddingsService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = EmbeddingsRequest,
handler = self.on_request
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = EmbeddingsResponse
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print("Handling request", id, "...", flush=True)
vectors = await self.on_embeddings(request.text)
await flow("response").send(
EmbeddingsResponse(
error = None,
vectors = vectors,
),
properties={"id": id}
)
print("Handled.", flush=True)
except TooManyRequests as e:
raise e
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
print(f"Exception: {e}", flush=True)
print("Send error response...", flush=True)
await flow.producer["response"].send(
EmbeddingsResponse(
error=Error(
type = "embeddings-error",
message = str(e),
),
vectors=None,
),
properties={"id": id}
)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@@ -0,0 +1,32 @@
import asyncio
class Flow:
def __init__(self, id, flow, processor, defn):
self.id = id
self.name = flow
self.producer = {}
# Consumers and publishers. Is this a bit untidy?
self.consumer = {}
self.setting = {}
for spec in processor.specifications:
spec.add(self, processor, defn)
async def start(self):
for c in self.consumer.values():
await c.start()
async def stop(self):
for c in self.consumer.values():
await c.stop()
def __call__(self, key):
if key in self.producer: return self.producer[key]
if key in self.consumer: return self.consumer[key]
if key in self.setting: return self.setting[key].value
return None
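
So, inside a handler, a Flow acts as a name lookup across everything the specifications wired in: producers first, then consumers (which include request/response clients), then settings. A short sketch of the lookup styles (the names are illustrative):

async def example_handler(msg, consumer, flow):
    # ProducerSpec(name = "response", ...) registered a producer:
    await flow("response").send(msg.value(), properties=msg.properties())

    # A SettingSpec(name = "model") would resolve to its plain value:
    model = flow("model")

    # Unknown names resolve to None rather than raising:
    assert flow("no-such-name") is None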

View file

@@ -0,0 +1,115 @@
# Base class for processor with management of flows in & out which are managed
# by configuration. This is probably all processor types, except for the
# configuration service which can't manage itself.
import json
from pulsar.schema import JsonSchema
from .. schema import Error
from .. schema import config_request_queue, config_response_queue
from .. schema import config_push_queue
from .. log_level import LogLevel
from . async_processor import AsyncProcessor
from . flow import Flow
# Parent class for configurable processors, configured with flows by
# the config service
class FlowProcessor(AsyncProcessor):
def __init__(self, **params):
# Initialise base class
super(FlowProcessor, self).__init__(**params)
# Register configuration handler
self.register_config_handler(self.on_configure_flows)
# Initialise flow information state
self.flows = {}
# These can be overridden by a derived class:
# Array of specifications: ConsumerSpec, ProducerSpec, SettingSpec
self.specifications = []
print("Service initialised.")
# Register a configuration variable
def register_specification(self, spec):
self.specifications.append(spec)
# Start processing for a new flow
async def start_flow(self, flow, defn):
self.flows[flow] = Flow(self.id, flow, self, defn)
await self.flows[flow].start()
print("Started flow: ", flow)
# Stop processing for an existing flow
async def stop_flow(self, flow):
if flow in self.flows:
await self.flows[flow].stop()
del self.flows[flow]
print("Stopped flow: ", flow, flush=True)
# Event handler - called for a configuration change
async def on_configure_flows(self, config, version):
print("Got config version", version, flush=True)
# Skip over invalid data
if "flows-active" not in config: return
# Check there's configuration information for me
if self.id in config["flows-active"]:
# Get my flow config
flow_config = json.loads(config["flows-active"][self.id])
else:
print("No configuration settings for me.", flush=True)
flow_config = {}
# Get list of flows which should be running and are currently
# running
wanted_flows = flow_config.keys()
current_flows = self.flows.keys()
# Start all the flows which aren't currently running
for flow in wanted_flows:
if flow not in current_flows:
await self.start_flow(flow, flow_config[flow])
# Stop all the running flows which are no longer wanted
for flow in current_flows:
if flow not in wanted_flows:
await self.stop_flow(flow)
print("Handled config update")
# Start threads, just call parent
async def start(self):
await super(FlowProcessor, self).start()
@staticmethod
def add_args(parser):
AsyncProcessor.add_args(parser)
# parser.add_argument(
# '--rate-limit-retry',
# type=int,
# default=default_rate_limit_retry,
# help=f'Rate limit retry (default: {default_rate_limit_retry})'
# )
# parser.add_argument(
# '--rate-limit-timeout',
# type=int,
# default=default_rate_limit_timeout,
# help=f'Rate limit timeout (default: {default_rate_limit_timeout})'
# )
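
The configuration consumed above arrives as config["flows-active"][id], a JSON string mapping each flow name to a definition; each definition maps a specification's logical name to a concrete topic. A hedged sketch of one plausible entry (topic names modelled on the updated test scripts):

import json

flows_active_entry = json.dumps({
    "default": {
        # Bound to ConsumerSpec(name = "request")
        "request": "non-persistent://tg/request/text-completion:default",
        # Bound to ProducerSpec(name = "response")
        "response": "non-persistent://tg/response/text-completion:default",
    }
})
# on_configure_flows() would json.loads() this and start flow "default"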

View file

@@ -0,0 +1,45 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. knowledge import Uri, Literal
def to_value(x):
if x.is_uri: return Uri(x.value)
return Literal(x.value)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vectors = vectors,
limit = limit,
user = user,
collection = collection
),
timeout=timeout
)
print(resp, flush=True)
if resp.error:
raise RuntimeError(resp.error.message)
return [
to_value(v)
for v in resp.entities
]
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(GraphEmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = GraphEmbeddingsRequest,
response_name = response_name,
response_schema = GraphEmbeddingsResponse,
impl = GraphEmbeddingsClient,
)

View file

@@ -0,0 +1,84 @@
"""
Graph embeddings query service. Input is vectors. Output is list of
embeddings.
"""
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import Error, Value
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . producer_spec import ProducerSpec
default_ident = "ge-query"
class GraphEmbeddingsQueryService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(GraphEmbeddingsQueryService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = GraphEmbeddingsRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = GraphEmbeddingsResponse,
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
entities = await self.query_graph_embeddings(request)
print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities, error=None)
await flow("response").send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "graph-embeddings-query-error",
message = str(e),
),
entities=None,
)
await flow("response").send(r, properties={"id": id})
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
def run():
GraphEmbeddingsQueryService.launch(default_ident, __doc__)

View file

@@ -0,0 +1,50 @@
"""
Graph embeddings store base class
"""
from .. schema import GraphEmbeddings
from .. base import FlowProcessor, ConsumerSpec
from .. exceptions import TooManyRequests
default_ident = "graph-embeddings-write"
class GraphEmbeddingsStoreService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(GraphEmbeddingsStoreService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = GraphEmbeddings,
handler = self.on_message
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
await self.store_graph_embeddings(request)
except TooManyRequests as e:
raise e
except Exception as e:
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@@ -0,0 +1,33 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import GraphRagQuery, GraphRagResponse
class GraphRagClient(RequestResponse):
async def rag(self, query, user="trustgraph", collection="default",
timeout=600):
resp = await self.request(
GraphRagQuery(
query = query,
user = user,
collection = collection,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.response
class GraphRagClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(GraphRagClientSpec, self).__init__(
request_name = request_name,
request_schema = GraphRagQuery,
response_name = response_name,
response_schema = GraphRagResponse,
impl = GraphRagClient,
)

View file

@@ -0,0 +1,114 @@
"""
LLM text completion base class
"""
import time
from prometheus_client import Histogram
from .. schema import TextCompletionRequest, TextCompletionResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "text-completion"
class LlmResult:
__slots__ = ["text", "in_token", "out_token", "model"]
class LlmService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(LlmService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = TextCompletionRequest,
handler = self.on_request
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = TextCompletionResponse
)
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
'Text completion duration (seconds)',
["id", "flow"],
buckets=[
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
120.0
]
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
with __class__.text_completion_metric.labels(
id=self.id,
flow=f"{flow.name}-{consumer.name}",
).time():
response = await self.generate_content(
request.system, request.prompt
)
await flow("response").send(
TextCompletionResponse(
error=None,
response=response.text,
in_token=response.in_token,
out_token=response.out_token,
model=response.model
),
properties={"id": id}
)
except TooManyRequests as e:
raise e
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
print(f"Exception: {e}")
print("Send error response...", flush=True)
await flow.producer["response"].send(
TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
),
properties={"id": id}
)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
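
Concrete LLM back-ends subclass LlmService and implement generate_content, returning an LlmResult; a minimal echoing sketch (the back-end behaviour is invented):

from trustgraph.base import LlmService, LlmResult

class EchoLlm(LlmService):

    async def generate_content(self, system, prompt):
        result = LlmResult()
        result.text = f"echo: {prompt}"
        # Token counts here are crude word counts, for illustration only
        result.in_token = len(prompt.split())
        result.out_token = len(result.text.split())
        result.model = "echo-0"
        return result

def run():
    EchoLlm.launch("text-completion", __doc__)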

View file

@@ -0,0 +1,82 @@
from prometheus_client import start_http_server, Info, Enum, Histogram
from prometheus_client import Counter
class ConsumerMetrics:
def __init__(self, id, flow=None):
self.id = id
self.flow = flow
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'consumer_state', 'Consumer state',
["id", "flow"],
states=['stopped', 'running']
)
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)',
["id", "flow"],
)
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count',
["id", "flow", "status"]
)
if not hasattr(__class__, "rate_limit_metric"):
__class__.rate_limit_metric = Counter(
'rate_limit_count', 'Rate limit event count',
["id", "flow"]
)
def process(self, status):
__class__.processing_metric.labels(
id=self.id, flow=self.flow, status=status
).inc()
def rate_limit(self):
__class__.rate_limit_metric.labels(
id=self.id, flow=self.flow
).inc()
def state(self, state):
__class__.state_metric.labels(
id=self.id, flow=self.flow
).state(state)
def record_time(self):
return __class__.request_metric.labels(
id=self.id, flow=self.flow
).time()
class ProducerMetrics:
def __init__(self, id, flow=None):
self.id = id
self.flow = flow
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created',
["id", "flow"]
)
def inc(self):
__class__.output_metric.labels(id=self.id, flow=self.flow).inc()
class ProcessorMetrics:
def __init__(self, id):
self.id = id
if not hasattr(__class__, "processor_metric"):
__class__.processor_metric = Info(
'processor', 'Processor configuration',
["id"]
)
def info(self, info):
__class__.processor_metric.labels(id=self.id).info(info)

View file

@@ -1,56 +1,69 @@
from pulsar.schema import JsonSchema
from prometheus_client import Info, Counter
import asyncio
from . base_processor import BaseProcessor
class Producer:
class Producer(BaseProcessor):
def __init__(self, client, topic, schema, metrics=None):
self.client = client
self.topic = topic
self.schema = schema
def __init__(self, **params):
self.metrics = metrics
output_queue = params.get("output_queue")
output_schema = params.get("output_schema")
self.running = True
self.producer = None
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
def __del__(self):
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
self.running = False
__class__.pubsub_metric.info({
"output_queue": output_queue,
"output_schema": output_schema.__name__,
})
if hasattr(self, "producer"):
if self.producer:
self.producer.close()
super(Producer, self).__init__(**params)
async def start(self):
self.running = True
if output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
chunking_enabled=True,
)
async def stop(self):
self.running = False
async def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
if not self.running: return
BaseProcessor.add_args(parser)
while self.running and self.producer is None:
try:
print("Connect publisher to", self.topic, "...", flush=True)
self.producer = self.client.create_producer(
topic = self.topic,
schema = JsonSchema(self.schema)
)
print("Connected to", self.topic, flush=True)
except Exception as e:
print("Exception:", e, flush=True)
await asyncio.sleep(2)
if not self.running: break
while self.running:
try:
await asyncio.to_thread(
self.producer.send,
msg, properties
)
if self.metrics:
self.metrics.inc()
# Delivery success, break out of loop
break
except Exception as e:
print("Exception:", e, flush=True)
self.producer.close()
self.producer = None
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@@ -0,0 +1,25 @@
from . producer import Producer
from . metrics import ProducerMetrics
from . spec import Spec
class ProducerSpec(Spec):
def __init__(self, name, schema):
self.name = name
self.schema = schema
def add(self, flow, processor, definition):
producer_metrics = ProducerMetrics(
flow.id, f"{flow.name}-{self.name}"
)
producer = Producer(
client = processor.client,
topic = definition[self.name],
schema = self.schema,
metrics = producer_metrics,
)
flow.producer[self.name] = producer

View file

@@ -0,0 +1,93 @@
import json
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600):
resp = await self.request(
PromptRequest(
id = id,
terms = {
k: json.dumps(v)
for k, v in variables.items()
}
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
return json.loads(resp.object)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
id = "extract-definitions",
variables = { "text": text },
timeout = timeout,
)
async def extract_relationships(self, text, timeout=600):
return await self.prompt(
id = "extract-relationships",
variables = { "text": text },
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600):
return await self.prompt(
id = "kg-prompt",
variables = {
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout = timeout,
)
async def document_prompt(self, query, documents, timeout=600):
return await self.prompt(
id = "document-prompt",
variables = {
"query": query,
"documents": documents,
},
timeout = timeout,
)
async def agent_react(self, variables, timeout=600):
return await self.prompt(
id = "agent-react",
variables = variables,
timeout = timeout,
)
async def question(self, question, timeout=600):
return await self.prompt(
id = "question",
variables = {
"question": question,
},
timeout = timeout,
)
class PromptClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(PromptClientSpec, self).__init__(
request_name = request_name,
request_schema = PromptRequest,
response_name = response_name,
response_schema = PromptResponse,
impl = PromptClient,
)
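
Within a flow, a PromptClientSpec registers one of these clients under its request name, so another service can call the convenience wrappers directly; a sketch (the spec's names are invented):

from trustgraph.base import FlowProcessor, PromptClientSpec

class ExtractorService(FlowProcessor):

    def __init__(self, **params):
        super(ExtractorService, self).__init__(**params)
        self.register_specification(
            PromptClientSpec(
                request_name = "prompt-request",
                response_name = "prompt-response",
            )
        )

    async def handle_chunk(self, flow, text):
        # flow("prompt-request") resolves to the PromptClient above
        return await flow("prompt-request").extract_definitions(text)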

View file

@@ -1,47 +1,52 @@
import queue
from pulsar.schema import JsonSchema
import asyncio
import time
import pulsar
import threading
class Publisher:
def __init__(self, pulsar_client, topic, schema=None, max_size=10,
def __init__(self, client, topic, schema=None, max_size=10,
chunking_enabled=True):
self.client = pulsar_client
self.client = client
self.topic = topic
self.schema = schema
self.q = queue.Queue(maxsize=max_size)
self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
self.running = True
def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()
async def start(self):
self.task = asyncio.create_task(self.run())
def stop(self):
async def stop(self):
self.running = False
def join(self):
self.stop()
self.task.join()
async def join(self):
await self.stop()
await self.task
def run(self):
async def run(self):
while self.running:
try:
producer = self.client.create_producer(
topic=self.topic,
schema=self.schema,
schema=JsonSchema(self.schema),
chunking_enabled=self.chunking_enabled,
)
while self.running:
try:
id, item = self.q.get(timeout=0.5)
except queue.Empty:
id, item = await asyncio.wait_for(
self.q.get(),
timeout=0.5
)
except asyncio.TimeoutError:
continue
except asyncio.QueueEmpty:
continue
if id:
@@ -55,7 +60,6 @@ class Publisher:
# If handler drops out, sleep a retry
time.sleep(2)
def send(self, id, msg):
self.q.put((id, msg))
async def send(self, id, item):
await self.q.put((id, item))
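
Usage-wise the Publisher keeps its queue-based shape but is now driven by an asyncio task; a sketch (topic and items are invented):

from trustgraph.base import Publisher

async def publish_items(client, schema, items):
    pub = Publisher(client, "persistent://tg/flow/example", schema=schema)
    await pub.start()                 # spawns the run() task
    for n, item in enumerate(items):
        await pub.send(str(n), item)  # enqueues (id, item) for the worker
    await pub.join()                  # signal stop, then await the task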

View file

@@ -0,0 +1,80 @@
import os
import pulsar
import uuid
from pulsar.schema import JsonSchema
from .. log_level import LogLevel
class PulsarClient:
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
def __init__(self, **params):
self.client = None
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
pulsar_listener = params.get("pulsar_listener", None)
pulsar_api_key = params.get(
"pulsar_api_key",
self.default_pulsar_api_key
)
log_level = params.get("log_level", LogLevel.INFO)
self.pulsar_host = pulsar_host
self.pulsar_api_key = pulsar_api_key
if pulsar_api_key:
auth = pulsar.AuthenticationToken(pulsar_api_key)
self.client = pulsar.Client(
pulsar_host,
authentication=auth,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
else:
self.client = pulsar.Client(
pulsar_host,
listener_name=pulsar_listener,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.pulsar_listener = pulsar_listener
def close(self):
self.client.close()
def __del__(self):
if hasattr(self, "client"):
if self.client:
self.client.close()
@staticmethod
def add_args(parser):
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=__class__.default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'--pulsar-listener',
help=f'Pulsar listener (default: none)',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Log level (default: info)'
)

View file

@@ -0,0 +1,136 @@
import uuid
import asyncio
from . subscriber import Subscriber
from . producer import Producer
from . spec import Spec
from . metrics import ConsumerMetrics, ProducerMetrics
class RequestResponse(Subscriber):
def __init__(
self, client, subscription, consumer_name,
request_topic, request_schema,
request_metrics,
response_topic, response_schema,
response_metrics,
):
super(RequestResponse, self).__init__(
client = client,
subscription = subscription,
consumer_name = consumer_name,
topic = response_topic,
schema = response_schema,
)
self.producer = Producer(
client = client,
topic = request_topic,
schema = request_schema,
metrics = request_metrics,
)
async def start(self):
await self.producer.start()
await super(RequestResponse, self).start()
async def stop(self):
await self.producer.stop()
await super(RequestResponse, self).stop()
async def request(self, req, timeout=300, recipient=None):
id = str(uuid.uuid4())
print("Request", id, "...", flush=True)
q = await self.subscribe(id)
try:
await self.producer.send(
req,
properties={"id": id}
)
except Exception as e:
print("Exception:", e)
raise e
try:
while True:
resp = await asyncio.wait_for(
q.get(),
timeout=timeout
)
print("Got response.", flush=True)
if recipient is None:
# If no recipient handler, just return the first
# response we get
return resp
else:
# Recipient handler gets to decide when we're done by
# returning a boolean
fin = await recipient(resp)
# If done, return the last result otherwise loop round for
# next response
if fin:
return resp
else:
continue
except Exception as e:
print("Exception:", e)
raise e
finally:
await self.unsubscribe(id)
# This deals with the request/response case. The caller needs to
# use another service in request/response mode. Uses two topics:
# - we send on the request topic as a producer
# - we receive on the response topic as a subscriber
class RequestResponseSpec(Spec):
def __init__(
self, request_name, request_schema, response_name,
response_schema, impl=RequestResponse
):
self.request_name = request_name
self.request_schema = request_schema
self.response_name = response_name
self.response_schema = response_schema
self.impl = impl
def add(self, flow, processor, definition):
producer_metrics = ProducerMetrics(
flow.id, f"{flow.name}-{self.response_name}"
)
rr = self.impl(
client = processor.client,
subscription = flow.id,
consumer_name = flow.id,
request_topic = definition[self.request_name],
request_schema = self.request_schema,
request_metrics = producer_metrics,
response_topic = definition[self.response_name],
response_schema = self.response_schema,
response_metrics = None,
)
flow.consumer[self.request_name] = rr
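
The recipient argument supports the multi-response case: the callback sees each correlated response and returns True when the exchange is finished. A hedged sketch in the style of the agent client above (field names taken from AgentResponse):

async def collect_agent_responses(client, req):
    thoughts = []

    async def recipient(resp):
        if resp.answer is not None:
            return True               # final answer: request() returns resp
        thoughts.append(resp.thought) # intermediate thought; keep waiting
        return False

    final = await client.request(req, recipient=recipient)
    return thoughts, final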

View file

@@ -0,0 +1,19 @@
from . spec import Spec
class Setting:
def __init__(self, value):
self.value = value
async def start(self):
pass
async def stop(self):
pass
class SettingSpec(Spec):
def __init__(self, name):
self.name = name
def add(self, flow, processor, definition):
flow.setting[self.name] = Setting(definition[self.name])

View file

@@ -0,0 +1,4 @@
class Spec:
pass

View file

@@ -1,14 +1,14 @@
import queue
import pulsar
import threading
from pulsar.schema import JsonSchema
import asyncio
import _pulsar
import time
class Subscriber:
def __init__(self, pulsar_client, topic, subscription, consumer_name,
def __init__(self, client, topic, subscription, consumer_name,
schema=None, max_size=100):
self.client = pulsar_client
self.client = client
self.topic = topic
self.subscription = subscription
self.consumer_name = consumer_name
@ -16,35 +16,50 @@ class Subscriber:
self.q = {}
self.full = {}
self.max_size = max_size
self.lock = threading.Lock()
self.lock = asyncio.Lock()
self.running = True
def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()
def stop(self):
async def __del__(self):
self.running = False
def join(self):
self.task.join()
async def start(self):
self.task = asyncio.create_task(self.run())
def run(self):
async def stop(self):
self.running = False
async def join(self):
await self.stop()
await self.task
async def run(self):
while self.running:
try:
consumer = self.client.subscribe(
topic=self.topic,
subscription_name=self.subscription,
consumer_name=self.consumer_name,
schema=self.schema,
topic = self.topic,
subscription_name = self.subscription,
consumer_name = self.consumer_name,
schema = JsonSchema(self.schema),
)
print("Subscriber running...", flush=True)
while self.running:
msg = consumer.receive()
try:
msg = await asyncio.to_thread(
consumer.receive,
timeout_millis=2000
)
except _pulsar.Timeout:
continue
except Exception as e:
print("Exception:", e, flush=True)
print(type(e))
raise e
# Acknowledge successful reception of the message
consumer.acknowledge(msg)
@ -56,57 +71,68 @@ class Subscriber:
value = msg.value()
with self.lock:
async with self.lock:
# FIXME: Hard-coded timeouts
if id in self.q:
try:
# FIXME: Timeout means data goes missing
self.q[id].put(value, timeout=0.5)
except:
pass
await asyncio.wait_for(
self.q[id].put(value),
timeout=2
)
except Exception as e:
print("Q Put:", e, flush=True)
for q in self.full.values():
try:
# FIXME: Timeout means data goes missing
q.put(value, timeout=0.5)
except:
pass
await asyncio.wait_for(
q.put(value),
timeout=2
)
except Exception as e:
print("Q Put:", e, flush=True)
except Exception as e:
print("Exception:", e, flush=True)
print("Subscriber exception:", e, flush=True)
consumer.close()
            # If the handler drops out, sleep briefly and retry
time.sleep(2)
def subscribe(self, id):
async def subscribe(self, id):
with self.lock:
async with self.lock:
q = queue.Queue(maxsize=self.max_size)
q = asyncio.Queue(maxsize=self.max_size)
self.q[id] = q
return q
def unsubscribe(self, id):
async def unsubscribe(self, id):
with self.lock:
async with self.lock:
if id in self.q:
# self.q[id].shutdown(immediate=True)
del self.q[id]
def subscribe_all(self, id):
async def subscribe_all(self, id):
with self.lock:
async with self.lock:
q = queue.Queue(maxsize=self.max_size)
q = asyncio.Queue(maxsize=self.max_size)
self.full[id] = q
return q
def unsubscribe_all(self, id):
async def unsubscribe_all(self, id):
with self.lock:
async with self.lock:
if id in self.full:
# self.full[id].shutdown(immediate=True)
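
A usage sketch of the async subscriber, following the same steps that
RequestResponse.request() takes above (the topic, schema and request ID
are hypothetical):

    sub = Subscriber(
        client = client,
        topic = "non-persistent://tg/response/foo",
        subscription = "sub1",
        consumer_name = "sub1",
        schema = MyResponse,
    )
    await sub.start()

    # Responses whose "id" property matches go to a dedicated queue
    q = await sub.subscribe("my-request-id")
    try:
        resp = await asyncio.wait_for(q.get(), timeout=30)
    finally:
        await sub.unsubscribe("my-request-id")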

View file

@ -0,0 +1,30 @@
from . metrics import ConsumerMetrics
from . subscriber import Subscriber
from . spec import Spec
class SubscriberSpec(Spec):
def __init__(self, name, schema):
self.name = name
self.schema = schema
def add(self, flow, processor, definition):
# FIXME: Metrics not used
subscriber_metrics = ConsumerMetrics(
flow.id, f"{flow.name}-{self.name}"
)
subscriber = Subscriber(
client = processor.client,
topic = definition[self.name],
subscription = flow.id,
consumer_name = flow.id,
schema = self.schema,
)
        # Register it in the consumer map so that it receives the
        # start/stop lifecycle calls.
flow.consumer[self.name] = subscriber

View file

@ -0,0 +1,30 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, timeout=600):
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.response
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(TextCompletionClientSpec, self).__init__(
request_name = request_name,
request_schema = TextCompletionRequest,
response_name = response_name,
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)
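
Registered under the names the agent service uses later in this change,
the client reads naturally from a handler; errors surface as RuntimeError
and the completion text is returned directly:

    self.register_specification(
        TextCompletionClientSpec(
            request_name = "text-completion-request",
            response_name = "text-completion-response",
        )
    )

    # Inside a handler:
    answer = await flow("text-completion-request").text_completion(
        system = "Answer concisely.",
        prompt = "What is a knowledge graph?",
    )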

View file

@ -0,0 +1,61 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. knowledge import Uri, Literal
class Triple:
def __init__(self, s, p, o):
self.s = s
self.p = p
self.o = o
def to_value(x):
if x.is_uri: return Uri(x.value)
return Literal(x.value)
def from_value(x):
if x is None: return None
if isinstance(x, Uri):
return Value(value=str(x), is_uri=True)
else:
return Value(value=str(x), is_uri=False)
class TriplesClient(RequestResponse):
async def query(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default",
timeout=30):
resp = await self.request(
TriplesQueryRequest(
s = from_value(s),
p = from_value(p),
o = from_value(o),
limit = limit,
user = user,
collection = collection,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
triples = [
Triple(to_value(v.s), to_value(v.p), to_value(v.o))
for v in resp.triples
]
return triples
class TriplesClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(TriplesClientSpec, self).__init__(
request_name = request_name,
request_schema = TriplesQueryRequest,
response_name = response_name,
response_schema = TriplesQueryResponse,
impl = TriplesClient,
)
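
A query sketch; the "triples-request" client name is an assumption for
illustration, and the predicate is the standard RDFS label URI:

    # s and o are left as wildcards (None)
    triples = await flow("triples-request").query(
        p = Uri("http://www.w3.org/2000/01/rdf-schema#label"),
        limit = 10,
    )
    for t in triples:
        # t.s / t.p / t.o are Uri or Literal values via to_value()
        print(t.s, t.p, t.o)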

View file

@ -0,0 +1,82 @@
"""
Triples query service. Input is an (s, p, o) triple; some values may be
null. Output is a list of triples.
"""
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .. schema import Value, Triple
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . producer_spec import ProducerSpec
default_ident = "triples-query"
class TriplesQueryService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(TriplesQueryService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = TriplesQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = TriplesQueryResponse,
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
triples = await self.query_triples(request)
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples, error=None)
await flow("response").send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error = Error(
type = "triples-query-error",
message = str(e),
),
triples = None,
)
await flow("response").send(r, properties={"id": id})
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
def run():
    TriplesQueryService.launch(default_ident, __doc__)
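
A hypothetical subclass sketch: a concrete back-end only implements
query_triples(), while consumer/producer wiring, request IDs and error
responses come from the base above.

    class MemoryTriplesQueryService(TriplesQueryService):

        def __init__(self, **params):
            super(MemoryTriplesQueryService, self).__init__(**params)
            self.triples = []    # in-memory list of Triple records

        async def query_triples(self, request):
            # Unset s/p/o act as wildcards, as the docstring describes
            def matches(t):
                if request.s and request.s.value and t.s.value != request.s.value:
                    return False
                if request.p and request.p.value and t.p.value != request.p.value:
                    return False
                if request.o and request.o.value and t.o.value != request.o.value:
                    return False
                return True
            return [t for t in self.triples if matches(t)][:request.limit]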

View file

@ -0,0 +1,47 @@
"""
Triples store base class
"""
from .. schema import Triples
from .. base import FlowProcessor, ConsumerSpec
from .. exceptions import TooManyRequests
default_ident = "triples-write"
class TriplesStoreService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(TriplesStoreService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "input",
schema = Triples,
handler = self.on_message
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
await self.store_triples(request)
except TooManyRequests as e:
raise e
except Exception as e:
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
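
Correspondingly, a hypothetical store sketch: only store_triples() is
needed, and the Triples message carries metadata plus an array of
(s, p, o) values.

    class LoggingTriplesStore(TriplesStoreService):

        async def store_triples(self, message):
            # A real implementation would write to a graph store here
            for t in message.triples:
                print(t.s.value, t.p.value, t.o.value, flush=True)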

View file

@ -26,12 +26,5 @@ class AgentResponse(Record):
thought = String()
observation = String()
agent_request_queue = topic(
'agent', kind='non-persistent', namespace='request'
)
agent_response_queue = topic(
'agent', kind='non-persistent', namespace='response'
)
############################################################################

View file

@ -2,7 +2,7 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
from . topic import topic
from . types import Error, RowSchema
from . types import Error
############################################################################

View file

@ -11,8 +11,6 @@ class Document(Record):
metadata = Metadata()
data = Bytes()
document_ingest_queue = topic('document-load')
############################################################################
# Text documents / text from PDF
@ -21,8 +19,6 @@ class TextDocument(Record):
metadata = Metadata()
text = Bytes()
text_ingest_queue = topic('text-document-load')
############################################################################
# Chunks of text
@ -31,8 +27,6 @@ class Chunk(Record):
metadata = Metadata()
chunk = Bytes()
chunk_ingest_queue = topic('chunk-load')
############################################################################
# Document embeddings are embeddings associated with a chunk
@ -46,8 +40,6 @@ class DocumentEmbeddings(Record):
metadata = Metadata()
chunks = Array(ChunkEmbeddings())
document_embeddings_store_queue = topic('document-embeddings-store')
############################################################################
# Doc embeddings query
@ -62,10 +54,3 @@ class DocumentEmbeddingsResponse(Record):
error = Error()
documents = Array(Bytes())
document_embeddings_request_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='request'
)
document_embeddings_response_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='response',
)

View file

@ -0,0 +1,66 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
from . topic import topic
from . types import Error
############################################################################
# Flow service:
# list_classes() -> (classname[])
# get_class(classname) -> (class)
# put_class(class) -> (class)
# delete_class(classname) -> ()
#
# list_flows() -> (flowid[])
# get_flow(flowid) -> (flow)
# start_flow(flowid, classname) -> ()
# stop_flow(flowid) -> ()
# Flow service requests and responses
class FlowRequest(Record):
operation = String() # list_classes, get_class, put_class, delete_class
# list_flows, get_flow, start_flow, stop_flow
# get_class, put_class, delete_class, start_flow
class_name = String()
    # put_class: 'class' is a reserved word in Python, so the field is
    # named class_definition here
    class_definition = String()
# start_flow
description = String()
# get_flow, start_flow, stop_flow
flow_id = String()
class FlowResponse(Record):
# list_classes
class_names = Array(String())
# list_flows
flow_ids = Array(String())
    # get_class (named class_definition for the same reserved-word reason)
    class_definition = String()
# get_flow
flow = String()
# get_flow
description = String()
# Everything
error = Error()
flow_request_queue = topic(
'flow', kind='non-persistent', namespace='request'
)
flow_response_queue = topic(
'flow', kind='non-persistent', namespace='response'
)
############################################################################
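
A request sketch following the operations listed above (the class and
flow names are hypothetical):

    # Start a flow instance from a named flow class
    req = FlowRequest(
        operation = "start_flow",
        class_name = "document-processing",
        flow_id = "flow-0001",
        description = "Ingest and index documents",
    )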

View file

@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -4,89 +4,37 @@ Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector.
"""
from ... base import EmbeddingsService
from langchain_huggingface import HuggingFaceEmbeddings
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
from trustgraph.schema import embeddings_request_queue
from trustgraph.schema import embeddings_response_queue
from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
default_ident = "embeddings"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="all-MiniLM-L6-v2"
class Processor(ConsumerProducer):
class Processor(EmbeddingsService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
}
**params | { "model": model }
)
print("Get model...", flush=True)
self.embeddings = HuggingFaceEmbeddings(model_name=model)
async def handle(self, msg):
async def on_embeddings(self, text):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
try:
text = v.text
embeds = self.embeddings.embed_documents([text])
print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=embeds, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = EmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
embeds = self.embeddings.embed_documents([text])
print("Done.", flush=True)
return embeds
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
EmbeddingsService.add_args(parser)
parser.add_argument(
'-m', '--model',
@ -96,5 +44,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)
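
The shape of the new contract, as a sketch: a subclass implements
on_embeddings(text) and returns a list of vectors; the queue, ID and
error plumbing visible in the removed handle() above is now the base
class's job.

    class DummyEmbeddings(EmbeddingsService):

        async def on_embeddings(self, text):
            # A real model would embed `text`; a fixed vector keeps the
            # sketch self-contained
            return [[0.0, 1.0, 2.0]]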

View file

@ -8,12 +8,11 @@ logger = logging.getLogger(__name__)
class AgentManager:
def __init__(self, context, tools, additional_context=None):
self.context = context
def __init__(self, tools, additional_context=None):
self.tools = tools
self.additional_context = additional_context
def reason(self, question, history):
async def reason(self, question, history, context):
tools = self.tools
@ -56,10 +55,7 @@ class AgentManager:
logger.info(f"prompt: {variables}")
obj = self.context.prompt.request(
"agent-react",
variables
)
obj = await context("prompt-request").agent_react(variables)
print(json.dumps(obj, indent=4), flush=True)
@ -85,9 +81,13 @@ class AgentManager:
return a
async def react(self, question, history, think, observe):
async def react(self, question, history, think, observe, context):
act = self.reason(question, history)
act = await self.reason(
question = question,
history = history,
context = context,
)
logger.info(f"act: {act}")
if isinstance(act, Final):
@ -104,7 +104,12 @@ class AgentManager:
else:
raise RuntimeError(f"No action for {act.name}!")
resp = action.implementation.invoke(**act.arguments)
print("TOOL>>>", act)
resp = await action.implementation(context).invoke(
**act.arguments
)
print("RSETUL", resp)
resp = resp.strip()

View file

@ -6,103 +6,68 @@ import json
import re
import sys
from pulsar.schema import JsonSchema
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec
from ... base import ConsumerProducer
from ... schema import Error
from ... schema import AgentRequest, AgentResponse, AgentStep
from ... schema import agent_request_queue, agent_response_queue
from ... schema import prompt_request_queue as pr_request_queue
from ... schema import prompt_response_queue as pr_response_queue
from ... schema import graph_rag_request_queue as gr_request_queue
from ... schema import graph_rag_response_queue as gr_response_queue
from ... clients.prompt_client import PromptClient
from ... clients.llm_client import LlmClient
from ... clients.graph_rag_client import GraphRagClient
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from . tools import KnowledgeQueryImpl, TextCompletionImpl
from . agent_manager import AgentManager
from . types import Final, Action, Tool, Argument
module = ".".join(__name__.split(".")[1:-1])
default_ident = "agent-manager"
default_max_iterations = 10
default_input_queue = agent_request_queue
default_output_queue = agent_response_queue
default_subscriber = module
default_max_iterations = 15
class Processor(ConsumerProducer):
class Processor(AgentService):
def __init__(self, **params):
id = params.get("id")
self.max_iterations = int(
params.get("max_iterations", default_max_iterations)
)
tools = {}
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
prompt_request_queue = params.get(
"prompt_request_queue", pr_request_queue
)
prompt_response_queue = params.get(
"prompt_response_queue", pr_response_queue
)
graph_rag_request_queue = params.get(
"graph_rag_request_queue", gr_request_queue
)
graph_rag_response_queue = params.get(
"graph_rag_response_queue", gr_response_queue
)
self.config_key = params.get("config_type", "agent")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": AgentRequest,
"output_schema": AgentResponse,
"prompt_request_queue": prompt_request_queue,
"prompt_response_queue": prompt_response_queue,
"graph_rag_request_queue": gr_request_queue,
"graph_rag_response_queue": gr_response_queue,
"id": id,
"max_iterations": self.max_iterations,
"config_type": self.config_key,
}
)
self.prompt = PromptClient(
subscriber=subscriber,
input_queue=prompt_request_queue,
output_queue=prompt_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
self.graph_rag = GraphRagClient(
subscriber=subscriber,
input_queue=graph_rag_request_queue,
output_queue=graph_rag_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
# Need to be able to feed requests to myself
self.recursive_input = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(AgentRequest),
)
self.agent = AgentManager(
context=self,
tools=[],
additional_context="",
)
async def on_config(self, version, config):
self.config_handlers.append(self.on_tools_config)
self.register_specification(
TextCompletionClientSpec(
request_name = "text-completion-request",
response_name = "text-completion-response",
)
)
self.register_specification(
GraphRagClientSpec(
request_name = "graph-rag-request",
response_name = "graph-rag-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
async def on_tools_config(self, config, version):
print("Loading configuration version", version)
@ -138,9 +103,9 @@ class Processor(ConsumerProducer):
impl_id = data.get("type")
if impl_id == "knowledge-query":
impl = KnowledgeQueryImpl(self)
impl = KnowledgeQueryImpl
elif impl_id == "text-completion":
impl = TextCompletionImpl(self)
impl = TextCompletionImpl
else:
raise RuntimeError(
f"Tool-kind {impl_id} not known"
@ -155,7 +120,6 @@ class Processor(ConsumerProducer):
)
self.agent = AgentManager(
context=self,
tools=tools,
additional_context=additional
)
@ -164,19 +128,14 @@ class Processor(ConsumerProducer):
except Exception as e:
print("Exception:", e, flush=True)
print("on_tools_config Exception:", e, flush=True)
print("Configuration reload failed", flush=True)
async def handle(self, msg):
async def agent_request(self, request, respond, next, flow):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
if v.history:
if request.history:
history = [
Action(
thought=h.thought,
@ -184,12 +143,12 @@ class Processor(ConsumerProducer):
arguments=h.arguments,
observation=h.observation
)
for h in v.history
for h in request.history
]
else:
history = []
print(f"Question: {v.question}", flush=True)
print(f"Question: {request.question}", flush=True)
if len(history) >= self.max_iterations:
raise RuntimeError("Too many agent iterations")
@ -207,7 +166,7 @@ class Processor(ConsumerProducer):
observation=None,
)
await self.send(r, properties={"id": id})
await respond(r)
async def observe(x):
@ -220,15 +179,21 @@ class Processor(ConsumerProducer):
observation=x,
)
await self.send(r, properties={"id": id})
await respond(r)
act = await self.agent.react(v.question, history, think, observe)
act = await self.agent.react(
question = request.question,
history = history,
think = think,
observe = observe,
context = flow,
)
print(f"Action: {act}", flush=True)
print("Send response...", flush=True)
if isinstance(act, Final):
if type(act) == Final:
print("Send final response...", flush=True)
r = AgentResponse(
answer=act.final,
@ -236,18 +201,20 @@ class Processor(ConsumerProducer):
thought=None,
)
await self.send(r, properties={"id": id})
await respond(r)
print("Done.", flush=True)
return
print("Send next...", flush=True)
history.append(act)
r = AgentRequest(
question=v.question,
plan=v.plan,
state=v.state,
question=request.question,
plan=request.plan,
state=request.state,
history=[
AgentStep(
thought=h.thought,
@ -259,7 +226,7 @@ class Processor(ConsumerProducer):
]
)
self.recursive_input.send(r, properties={"id": id})
await next(r)
print("Done.", flush=True)
@ -267,7 +234,7 @@ class Processor(ConsumerProducer):
except Exception as e:
print(f"Exception: {e}")
print(f"agent_request Exception: {e}")
print("Send error response...", flush=True)
@ -279,39 +246,12 @@ class Processor(ConsumerProducer):
response=None,
)
await self.send(r, properties={"id": id})
await respond(r)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=pr_request_queue,
help=f'Prompt request queue (default: {pr_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=pr_response_queue,
help=f'Prompt response queue (default: {pr_response_queue})',
)
parser.add_argument(
'--graph-rag-request-queue',
default=gr_request_queue,
help=f'Graph RAG request queue (default: {gr_request_queue})',
)
parser.add_argument(
'--graph-rag-response-queue',
default=gr_response_queue,
help=f'Graph RAG response queue (default: {gr_response_queue})',
)
AgentService.add_args(parser)
parser.add_argument(
'--max-iterations',
@ -327,5 +267,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -4,16 +4,22 @@
class KnowledgeQueryImpl:
def __init__(self, context):
self.context = context
def invoke(self, **arguments):
return self.context.graph_rag.request(arguments.get("question"))
async def invoke(self, **arguments):
client = self.context("graph-rag-request")
print("Graph RAG question...", flush=True)
return await client.rag(
arguments.get("question")
)
# This tool implementation knows how to do text completion. This uses
# the prompt service, rather than talking to TextCompletion directly.
class TextCompletionImpl:
def __init__(self, context):
self.context = context
def invoke(self, **arguments):
return self.context.prompt.request(
"question", { "question": arguments.get("question") }
async def invoke(self, **arguments):
client = self.context("prompt-request")
print("Prompt question...", flush=True)
return await client.question(
arguments.get("question")
)
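
Tool implementations now share one shape: constructed with the flow
context, then awaited via invoke(**arguments). A hypothetical extra tool
for illustration:

    class EchoImpl:

        def __init__(self, context):
            self.context = context

        async def invoke(self, **arguments):
            # No service call needed; just reflect the argument back
            return arguments.get("text", "")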

View file

@ -7,40 +7,27 @@ as text as separate output objects.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "chunker"
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id", default_ident)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
**params | { "id": id }
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
@ -52,7 +39,24 @@ class Processor(ConsumerProducer):
is_separator_regex=False,
)
async def handle(self, msg):
self.register_specification(
ConsumerSpec(
name = "input",
schema = TextDocument,
handler = self.on_message,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = Chunk,
)
)
print("Chunker initialised", flush=True)
async def on_message(self, msg, consumer, flow):
v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True)
@ -63,24 +67,25 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts):
print("Chunk", len(chunk.page_content), flush=True)
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.observe(len(chunk.page_content))
__class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
await self.send(r)
await flow("output").send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
FlowProcessor.add_args(parser)
parser.add_argument(
'-z', '--chunk-size',
@ -98,5 +103,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)
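
For reference, a standalone sketch of the splitter this chunker wraps,
using the defaults above (chunk_size=2000, chunk_overlap=100):

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=2000,
        chunk_overlap=100,
        is_separator_regex=False,
    )
    docs = splitter.create_documents(["...long document text..."])
    for d in docs:
        print(len(d.page_content))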

View file

@ -7,40 +7,27 @@ as text as separate output objects.
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "chunker"
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id")
chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
**params | { "id": id }
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
@ -51,7 +38,24 @@ class Processor(ConsumerProducer):
chunk_overlap=chunk_overlap,
)
async def handle(self, msg):
self.register_specification(
ConsumerSpec(
name = "input",
schema = TextDocument,
handler = self.on_message,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = Chunk,
)
)
print("Chunker initialised", flush=True)
async def on_message(self, msg, consumer, flow):
v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True)
@ -62,24 +66,25 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts):
print("Chunk", len(chunk.page_content), flush=True)
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.observe(len(chunk.page_content))
__class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
await self.send(r)
await flow("output").send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
FlowProcessor.add_args(parser)
parser.add_argument(
'-z', '--chunk-size',
@ -97,5 +102,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,215 @@
from trustgraph.schema import ConfigResponse
from trustgraph.schema import ConfigValue, Error
# This behaves just like a dict; keeping it separate should make it easier
# to add persistent storage later
class ConfigurationItems(dict):
pass
class Configuration(dict):
    # FIXME: The state is held internally, which only works with exactly
    # one config service. There should be support for more than one,
    # backed by an external state store.
def __init__(self, push):
# Version counter
self.version = 0
# External function to respond to update
self.push = push
def __getitem__(self, key):
if key not in self:
self[key] = ConfigurationItems()
return dict.__getitem__(self, key)
async def handle_get(self, v):
for k in v.keys:
if k.type not in self or k.key not in self[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = k.type,
key = k.key,
value = self[k.type][k.key]
)
for k in v.keys
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_list(self, v):
if v.type not in self:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = "No such type",
),
)
return ConfigResponse(
version = self.version,
values = None,
directory = list(self[v.type].keys()),
config = None,
error = None,
)
async def handle_getvalues(self, v):
if v.type not in self:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = v.type,
key = k,
value = self[v.type][k],
)
for k in self[v.type]
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_delete(self, v):
for k in v.keys:
if k.type not in self or k.key not in self[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
for k in v.keys:
del self[k.type][k.key]
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
config = None,
error = None,
)
async def handle_put(self, v):
for k in v.values:
self[k.type][k.key] = k.value
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
error = None,
)
async def handle_config(self, v):
return ConfigResponse(
version = self.version,
value = None,
directory = None,
values = None,
config = self,
error = None,
)
async def handle(self, msg):
print("Handle message ", msg.operation)
if msg.operation == "get":
resp = await self.handle_get(msg)
elif msg.operation == "list":
resp = await self.handle_list(msg)
elif msg.operation == "getvalues":
resp = await self.handle_getvalues(msg)
elif msg.operation == "delete":
resp = await self.handle_delete(msg)
elif msg.operation == "put":
resp = await self.handle_put(msg)
elif msg.operation == "config":
resp = await self.handle_config(msg)
else:
resp = ConfigResponse(
value=None,
directory=None,
values=None,
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
return resp
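
A round-trip sketch, exercised outside Pulsar; SimpleNamespace stands in
for the decoded request record, which is an assumption for illustration:

    import asyncio
    from types import SimpleNamespace

    async def demo():
        async def push():
            print("pushed new config version")

        config = Configuration(push)

        # __getitem__ auto-creates the per-type item map on first access
        config["agent"]["tool-0"] = "{}"

        resp = await config.handle_list(SimpleNamespace(type = "agent"))
        print(resp.directory)    # ["tool-0"]

    asyncio.run(demo())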

View file

@ -1,287 +1,118 @@
"""
Config service. Fetches an extract from the Wikipedia page
using the API.
Config service. Manages system global configuration state
"""
from pulsar.schema import JsonSchema
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
from trustgraph.schema import ConfigValue, Error
from trustgraph.schema import Error
from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.schema import config_push_queue
from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
from trustgraph.base import AsyncProcessor, Consumer, Producer
module = ".".join(__name__.split(".")[1:-1])
from . config import Configuration
from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from ... base import Consumer, Producer
default_input_queue = config_request_queue
default_output_queue = config_response_queue
default_ident = "config-svc"
default_request_queue = config_request_queue
default_response_queue = config_response_queue
default_push_queue = config_push_queue
default_subscriber = module
# This behaves just like a dict, should be easier to add persistent storage
# later
class ConfigurationItems(dict):
pass
class Configuration(dict):
def __getitem__(self, key):
if key not in self:
self[key] = ConfigurationItems()
return dict.__getitem__(self, key)
class Processor(ConsumerProducer):
class Processor(AsyncProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
request_queue = params.get("request_queue", default_request_queue)
response_queue = params.get("response_queue", default_response_queue)
push_queue = params.get("push_queue", default_push_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id")
request_schema = ConfigRequest
response_schema = ConfigResponse
push_schema = ConfigResponse
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"push_queue": output_queue,
"subscriber": subscriber,
"input_schema": ConfigRequest,
"output_schema": ConfigResponse,
"push_schema": ConfigPush,
"request_schema": request_schema.__name__,
"response_schema": response_schema.__name__,
"push_schema": push_schema.__name__,
}
)
self.push_prod = self.client.create_producer(
topic=push_queue,
schema=JsonSchema(ConfigPush),
request_metrics = ConsumerMetrics(id + "-request")
response_metrics = ProducerMetrics(id + "-response")
push_metrics = ProducerMetrics(id + "-push")
self.push_pub = Producer(
client = self.client,
topic = push_queue,
schema = ConfigPush,
metrics = push_metrics,
)
# FIXME: The state is held internally. This only works if there's
# one config service. Should be more than one, and use a
# back-end state store.
self.config = Configuration()
self.response_pub = Producer(
client = self.client,
topic = response_queue,
schema = ConfigResponse,
metrics = response_metrics,
)
# Version counter
self.version = 0
self.subs = Consumer(
taskgroup = self.taskgroup,
client = self.client,
flow = None,
topic = request_queue,
subscriber = id,
schema = request_schema,
handler = self.on_message,
metrics = request_metrics,
)
self.config = Configuration(self.push)
print("Service initialised.")
async def start(self):
await self.push()
await self.subs.start()
async def handle_get(self, v, id):
for k in v.keys:
if k.type not in self.config or k.key not in self.config[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = k.type,
key = k.key,
value = self.config[k.type][k.key]
)
for k in v.keys
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_list(self, v, id):
if v.type not in self.config:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = "No such type",
),
)
return ConfigResponse(
version = self.version,
values = None,
directory = list(self.config[v.type].keys()),
config = None,
error = None,
)
async def handle_getvalues(self, v, id):
if v.type not in self.config:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = v.type,
key = k,
value = self.config[v.type][k],
)
for k in self.config[v.type]
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_delete(self, v, id):
for k in v.keys:
if k.type not in self.config or k.key not in self.config[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
for k in v.keys:
del self.config[k.type][k.key]
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
config = None,
error = None,
)
async def handle_put(self, v, id):
for k in v.values:
self.config[k.type][k.key] = k.value
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
error = None,
)
async def handle_config(self, v, id):
return ConfigResponse(
version = self.version,
value = None,
directory = None,
values = None,
config = self.config,
error = None,
)
async def push(self):
resp = ConfigPush(
version = self.version,
version = self.config.version,
value = None,
directory = None,
values = None,
config = self.config,
error = None,
)
self.push_prod.send(resp)
print("Pushed.")
await self.push_pub.send(resp)
print("Pushed version ", self.config.version)
async def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling {id}...", flush=True)
async def on_message(self, msg, consumer, flow):
try:
if v.operation == "get":
v = msg.value()
resp = await self.handle_get(v, id)
# Sender-produced ID
id = msg.properties()["id"]
elif v.operation == "list":
print(f"Handling {id}...", flush=True)
resp = await self.handle_list(v, id)
resp = await self.config.handle(v)
elif v.operation == "getvalues":
resp = await self.handle_getvalues(v, id)
elif v.operation == "delete":
resp = await self.handle_delete(v, id)
elif v.operation == "put":
resp = await self.handle_put(v, id)
elif v.operation == "config":
resp = await self.handle_config(v, id)
else:
resp = ConfigResponse(
value=None,
directory=None,
values=None,
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
await self.send(resp, properties={"id": id})
self.consumer.acknowledge(msg)
await self.response_pub.send(resp, properties={"id": id})
except Exception as e:
resp = ConfigResponse(
error=Error(
type = "unexpected-error",
@ -289,24 +120,33 @@ class Processor(ConsumerProducer):
),
                values=None,
)
await self.send(resp, properties={"id": id})
self.consumer.acknowledge(msg)
await self.response_pub.send(resp, properties={"id": id})
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
AsyncProcessor.add_args(parser)
parser.add_argument(
'-q', '--request-queue',
default=default_request_queue,
help=f'Request queue (default: {default_request_queue})'
)
parser.add_argument(
'-q', '--push-queue',
'-r', '--response-queue',
default=default_response_queue,
help=f'Response queue {default_response_queue}',
)
parser.add_argument(
'--push-queue',
default=default_push_queue,
help=f'Config push queue (default: {default_push_queue})'
)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -17,12 +17,10 @@ from mistralai.models import OCRResponse
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... base import InputOutputProcessor
module = ".".join(__name__.split(".")[1:-1])
module = "ocr"
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
default_api_key = os.getenv("MISTRAL_TOKEN")
@ -71,19 +69,17 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str:
return "\n\n".join(markdowns)
class Processor(ConsumerProducer):
class Processor(InputOutputProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
id = params.get("id")
subscriber = params.get("subscriber", default_subscriber)
api_key = params.get("api_key", default_api_key)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"id": id,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
@ -151,7 +147,7 @@ class Processor(ConsumerProducer):
return markdown
async def handle(self, msg):
async def on_message(self, msg, consumer):
print("PDF message received")
@ -166,17 +162,14 @@ class Processor(ConsumerProducer):
text=markdown.encode("utf-8"),
)
await self.send(r)
await consumer.q.output.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
InputOutputProcessor.add_args(parser, default_subscriber)
parser.add_argument(
'-k', '--api-key',

View file

@ -9,39 +9,43 @@ import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "pdf-decoder"
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id", default_ident)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
"id": id,
}
)
print("PDF inited")
self.register_specification(
ConsumerSpec(
name = "input",
schema = Document,
handler = self.on_message,
)
)
async def handle(self, msg):
self.register_specification(
ProducerSpec(
name = "output",
schema = TextDocument,
)
)
print("PDF message received")
print("PDF inited", flush=True)
async def on_message(self, msg, consumer, flow):
print("PDF message received", flush=True)
v = msg.value()
@ -59,24 +63,22 @@ class Processor(ConsumerProducer):
for ix, page in enumerate(pages):
print("page", ix, flush=True)
r = TextDocument(
metadata=v.metadata,
text=page.page_content.encode("utf-8"),
)
await self.send(r)
await flow("output").send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
FlowProcessor.add_args(parser)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -1,153 +0,0 @@
from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.doc_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
de_request_queue=None,
de_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if de_request_queue is None:
de_request_queue = document_embeddings_request_queue
if de_response_queue is None:
de_response_queue = document_embeddings_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
pulsar_api_key=pulsar_api_key,
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
pulsar_api_key=pulsar_api_key,
)
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
pulsar_api_key=pulsar_api_key,
)
if self.verbose:
print("Initialised", flush=True)
def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = q.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = self.lang.request_document_prompt(query, docs)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -6,61 +6,63 @@ Output is chunk plus embedding.
"""
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
from ... schema import chunk_ingest_queue
from ... schema import document_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... schema import EmbeddingsRequest, EmbeddingsResponse
module = ".".join(__name__.split(".")[1:-1])
from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec
from ... base import ProducerSpec
default_input_queue = chunk_ingest_queue
default_output_queue = document_embeddings_store_queue
default_subscriber = module
default_ident = "document-embeddings"
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
id = params.get("id")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": DocumentEmbeddings,
"id": id,
}
)
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
self.register_specification(
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message,
)
)
async def handle(self, msg):
self.register_specification(
RequestResponseSpec(
request_name = "embeddings-request",
request_schema = EmbeddingsRequest,
response_name = "embeddings-response",
response_schema = EmbeddingsResponse,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = DocumentEmbeddings
)
)
async def on_message(self, msg, consumer, flow):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
try:
vectors = self.embeddings.request(v.chunk)
resp = await flow("embeddings-request").request(
EmbeddingsRequest(
text = v.chunk
)
)
vectors = resp.vectors
embeds = [
ChunkEmbeddings(
@ -74,7 +76,7 @@ class Processor(ConsumerProducer):
chunks=embeds,
)
await self.send(r)
await flow("output").send(r)
except Exception as e:
print("Exception:", e, flush=True)
@ -87,24 +89,9 @@ class Processor(ConsumerProducer):
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
)
FlowProcessor.add_args(parser)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -1,81 +1,43 @@
"""
Embeddings service, applies an embeddings model selected from HuggingFace.
Embeddings service, applies an embeddings model using fastembed
Input is text, output is embeddings vector.
"""
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... base import EmbeddingsService
from fastembed import TextEmbedding
import os
module = ".".join(__name__.split(".")[1:-1])
default_ident = "embeddings"
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="sentence-transformers/all-MiniLM-L6-v2"
class Processor(ConsumerProducer):
class Processor(EmbeddingsService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
"model": model,
}
**params | { "model": model }
)
print("Get model...", flush=True)
self.embeddings = TextEmbedding(model_name = model)
async def handle(self, msg):
async def on_embeddings(self, text):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
text = v.text
vecs = self.embeddings.embed([text])
vecs = [
return [
v.tolist()
for v in vecs
]
print("Send response...", flush=True)
r = EmbeddingsResponse(
vectors=list(vecs),
error=None,
)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
EmbeddingsService.add_args(parser)
parser.add_argument(
'-m', '--model',
@ -85,5 +47,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -6,53 +6,48 @@ Output is entity plus embedding.
"""
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
from ... schema import entity_contexts_ingest_queue
from ... schema import graph_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... schema import EmbeddingsRequest, EmbeddingsResponse
module = ".".join(__name__.split(".")[1:-1])
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec
default_input_queue = entity_contexts_ingest_queue
default_output_queue = graph_embeddings_store_queue
default_subscriber = module
default_ident = "graph-embeddings"
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
id = params.get("id")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": EntityContexts,
"output_schema": GraphEmbeddings,
"id": id,
}
)
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
self.register_specification(
ConsumerSpec(
name = "input",
schema = EntityContexts,
handler = self.on_message,
)
)
async def handle(self, msg):
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = GraphEmbeddings
)
)
async def on_message(self, msg, consumer, flow):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
@ -63,7 +58,9 @@ class Processor(ConsumerProducer):
for entity in v.entities:
vectors = self.embeddings.request(entity.context)
vectors = await flow("embeddings-request").embed(
text = entity.context
)
entities.append(
EntityEmbeddings(
@ -77,7 +74,7 @@ class Processor(ConsumerProducer):
entities=entities,
)
await self.send(r)
await flow("output").send(r)
except Exception as e:
print("Exception:", e, flush=True)
@ -90,24 +87,9 @@ class Processor(ConsumerProducer):
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
)
FlowProcessor.add_args(parser)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -11,7 +11,7 @@ from ... base import ConsumerProducer
from ollama import Client
import os
module = ".".join(__name__.split(".")[1:-1])
module = "embeddings"
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue

View file

@ -11,7 +11,7 @@ from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
import requests
module = ".".join(__name__.split(".")[1:-1])
module = "wikipedia"
default_input_queue = encyclopedia_lookup_request_queue
default_output_queue = encyclopedia_lookup_response_queue

View file

@ -5,84 +5,62 @@ get entity definitions which are output as graph edges along with
entity/context definitions for embedding.
"""
import json
import urllib.parse
from pulsar.schema import JsonSchema
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import entity_contexts_ingest_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from .... base import ConsumerProducer
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_ident = "kg-extract-definitions"
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_entity_context_queue = entity_contexts_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
ec_queue = params.get(
"entity_context_queue",
default_entity_context_queue
)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
id = params.get("id")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"id": id,
}
)
self.ec_prod = self.client.create_producer(
topic=ec_queue,
schema=JsonSchema(EntityContexts),
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message
)
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"entity_context_queue": ec_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": Chunk.__name__,
"output_schema": Triples.__name__,
"vector_schema": EntityContexts.__name__,
})
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples
)
)
self.register_specification(
ProducerSpec(
name = "entity-contexts",
schema = EntityContexts
)
)
def to_uri(self, text):
@ -93,36 +71,47 @@ class Processor(ConsumerProducer):
return uri
def get_definitions(self, chunk):
return self.prompt.request_definitions(chunk)
async def emit_edges(self, metadata, triples):
async def emit_triples(self, pub, metadata, triples):
t = Triples(
metadata=metadata,
triples=triples,
)
await self.send(t)
await pub.send(t)
async def emit_ecs(self, metadata, entities):
async def emit_ecs(self, pub, metadata, entities):
t = EntityContexts(
metadata=metadata,
entities=entities,
)
self.ec_prod.send(t)
await pub.send(t)
async def handle(self, msg):
async def on_message(self, msg, consumer, flow):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
print(chunk, flush=True)
try:
defs = self.get_definitions(chunk)
try:
defs = await flow("prompt-request").extract_definitions(
text = chunk
)
print("Response", defs, flush=True)
if type(defs) != list:
raise RuntimeError("Expecting array in prompt response")
except Exception as e:
print("Prompt exception:", e, flush=True)
raise e
triples = []
entities = []
@ -134,8 +123,8 @@ class Processor(ConsumerProducer):
for defn in defs:
s = defn.name
o = defn.definition
s = defn["entity"]
o = defn["definition"]
if s == "": continue
if o == "": continue
@ -166,13 +155,13 @@ class Processor(ConsumerProducer):
ec = EntityContext(
entity=s_value,
context=defn.definition,
context=defn["definition"],
)
entities.append(ec)
await self.emit_edges(
await self.emit_triples(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
@ -183,6 +172,7 @@ class Processor(ConsumerProducer):
)
await self.emit_ecs(
flow("entity-contexts"),
Metadata(
id=v.metadata.id,
metadata=[],
@ -200,30 +190,9 @@ class Processor(ConsumerProducer):
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-e', '--entity-context-queue',
default=default_entity_context_queue,
help=f'Entity context queue (default: {default_entity_context_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
FlowProcessor.add_args(parser)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

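The definition records now come back from the prompt service as plain dicts rather than attribute objects, so the handler indexes by key. A hedged sketch of the reworked extraction step; the keys ("entity", "definition") and the extract_definitions call follow the new code above, while process_chunk itself is illustrative, not part of the API:

# Sketch of the new definition-extraction core.
async def process_chunk(flow, chunk_text):
    defs = await flow("prompt-request").extract_definitions(text=chunk_text)
    if not isinstance(defs, list):
        raise RuntimeError("Expecting array in prompt response")
    pairs = []
    for defn in defs:
        entity, definition = defn["entity"], defn["definition"]
        if entity == "" or definition == "":
            continue  # skip empty extractions, as the processor does
        pairs.append((entity, definition))
    return pairs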
View file

@ -5,59 +5,54 @@ relationship analysis to get entity relationship edges which are output as
graph edges.
"""
import json
import urllib.parse
from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... schema import PromptRequest, PromptResponse
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
from .... base import ConsumerProducer
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_ident = "kg-extract-relationships"
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
id = params.get("id")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"id": id,
}
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples
)
)
def to_uri(self, text):
@ -68,28 +63,39 @@ class Processor(ConsumerProducer):
return uri
def get_relationships(self, chunk):
return self.prompt.request_relationships(chunk)
async def emit_edges(self, metadata, triples):
async def emit_triples(self, pub, metadata, triples):
t = Triples(
metadata=metadata,
triples=triples,
)
await self.send(t)
await pub.send(t)
async def handle(self, msg):
async def on_message(self, msg, consumer, flow):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
print(chunk, flush=True)
try:
rels = self.get_relationships(chunk)
try:
rels = await flow("prompt-request").extract_relationships(
text = chunk
)
print("Response", rels, flush=True)
if type(rels) != list:
raise RuntimeError("Expecting array in prompt response")
except Exception as e:
print("Prompt exception:", e, flush=True)
raise e
triples = []
@ -100,9 +106,9 @@ class Processor(ConsumerProducer):
for rel in rels:
s = rel.s
p = rel.p
o = rel.o
s = rel["subject"]
p = rel["predicate"]
o = rel["object"]
if s == "": continue
if p == "": continue
@ -118,7 +124,7 @@ class Processor(ConsumerProducer):
p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
if rel.o_entity:
if rel["object-entity"]:
o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
else:
@ -144,7 +150,7 @@ class Processor(ConsumerProducer):
o=Value(value=str(p), is_uri=False)
))
if rel.o_entity:
if rel["object-entity"]:
# Label for o
triples.append(Triple(
s=o_value,
@ -159,7 +165,7 @@ class Processor(ConsumerProducer):
o=Value(value=v.metadata.id, is_uri=True)
))
if rel.o_entity:
if rel["object-entity"]:
# 'Subject of' for o
triples.append(Triple(
s=o_value,
@ -168,6 +174,7 @@ class Processor(ConsumerProducer):
))
await self.emit_triples(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
@ -185,24 +192,9 @@ class Processor(ConsumerProducer):
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
FlowProcessor.add_args(parser)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

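As in the definitions extractor, relationship records are now dicts. A sketch of the per-relationship value construction above; the record keys ("subject", "predicate", "object", "object-entity") follow the new code, and proc stands for any object providing the to_uri helper seen in this diff:

from trustgraph.schema import Value

# Sketch: build triple values from one extracted relationship record.
def to_values(proc, rel):
    s_value = Value(value=str(proc.to_uri(rel["subject"])), is_uri=True)
    p_value = Value(value=str(proc.to_uri(rel["predicate"])), is_uri=True)
    if rel["object-entity"]:
        # The object is itself an entity, so it becomes a URI node
        o_value = Value(value=str(proc.to_uri(rel["object"])), is_uri=True)
    else:
        # Otherwise the object is kept as a literal
        o_value = Value(value=str(rel["object"]), is_uri=False)
    return s_value, p_value, o_value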
View file

@ -18,7 +18,7 @@ from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
module = "kg-extract-topics"
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue

View file

@ -39,4 +39,3 @@ class AgentRequestor(ServiceRequestor):
# The 2nd boolean expression indicates whether we're done responding
return resp, (message.answer is not None)

View file

@ -1,6 +1,5 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@ -26,12 +25,12 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, document_embeddings_store_queue,
schema=JsonSchema(DocumentEmbeddings)
schema=DocumentEmbeddings
)
async def start(self):
self.publisher.start()
await self.publisher.start()
async def listener(self, ws, running):
@ -59,6 +58,6 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
],
)
self.publisher.send(None, elt)
await self.publisher.send(None, elt)
running.stop()

View file

@ -1,7 +1,6 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import DocumentEmbeddings
@ -27,7 +26,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(DocumentEmbeddings),
schema=DocumentEmbeddings,
)
async def listener(self, ws, running):
@ -44,17 +43,17 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self):
self.subscriber.start()
await self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = self.subscriber.subscribe_all(id)
q = await self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.to_thread(q.get, timeout=0.5)
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
@ -67,7 +66,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True)
break
self.subscriber.unsubscribe_all(id)
await self.subscriber.unsubscribe_all(id)
running.stop()

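The stream endpoints now poll an asyncio queue instead of blocking a thread with queue.get. A minimal sketch of the loop, where subscriber, ws, running and serialize stand in for the objects used above:

import asyncio
import uuid

# Sketch of the fully-async polling loop the stream endpoints now use.
async def stream_loop(subscriber, ws, running, serialize):
    id = str(uuid.uuid4())
    q = await subscriber.subscribe_all(id)
    while running.get():
        try:
            # Short timeout so the running flag is re-checked regularly
            resp = await asyncio.wait_for(q.get(), timeout=0.5)
            await ws.send_json(serialize(resp))
        except asyncio.TimeoutError:
            continue
    await subscriber.unsubscribe_all(id)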
View file

@ -1,13 +1,9 @@
import asyncio
from pulsar.schema import JsonSchema
from aiohttp import web
import uuid
import logging
from .. base import Publisher
from .. base import Subscriber
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)

View file

@ -1,6 +1,5 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@ -26,12 +25,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, graph_embeddings_store_queue,
schema=JsonSchema(GraphEmbeddings)
schema=GraphEmbeddings
)
async def start(self):
self.publisher.start()
await self.publisher.start()
async def listener(self, ws, running):
@ -60,6 +59,6 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
]
)
self.publisher.send(None, elt)
await self.publisher.send(None, elt)
running.stop()

View file

@ -1,7 +1,6 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import GraphEmbeddings
@ -26,7 +25,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(GraphEmbeddings)
schema=GraphEmbeddings
)
async def listener(self, ws, running):
@ -41,17 +40,17 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self):
self.subscriber.start()
await self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = self.subscriber.subscribe_all(id)
q = await self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.to_thread(q.get, timeout=0.5)
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
@ -64,7 +63,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True)
break
self.subscriber.unsubscribe_all(id)
await self.subscriber.unsubscribe_all(id)
running.stop()

View file

@ -7,7 +7,6 @@
import aiohttp
from aiohttp import web
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging

View file

@ -1,7 +1,6 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from aiohttp import web, WSMsgType

View file

@ -1,6 +1,5 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging
@ -23,21 +22,21 @@ class ServiceRequestor:
self.pub = Publisher(
pulsar_client, request_queue,
schema=JsonSchema(request_schema),
schema=request_schema,
)
self.sub = Subscriber(
pulsar_client, response_queue,
subscription, consumer_name,
JsonSchema(response_schema)
response_schema
)
self.timeout = timeout
async def start(self):
self.pub.start()
self.sub.start()
await self.pub.start()
await self.sub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
@ -51,18 +50,15 @@ class ServiceRequestor:
try:
q = self.sub.subscribe(id)
q = await self.sub.subscribe(id)
await asyncio.to_thread(
self.pub.send, id, self.to_request(request)
)
await self.pub.send(id, self.to_request(request))
while True:
try:
resp = await asyncio.to_thread(
q.get,
timeout=self.timeout
resp = await asyncio.wait_for(
q.get(), timeout=self.timeout
)
except Exception as e:
raise RuntimeError("Timeout")
@ -99,5 +95,5 @@ class ServiceRequestor:
return err
finally:
self.sub.unsubscribe(id)
await self.sub.unsubscribe(id)

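The same move from thread offloading to native async applies to request/response. A sketch of the round trip ServiceRequestor now performs, assuming the async Publisher/Subscriber API shown in this diff; request_response is illustrative:

import asyncio
import uuid

# Sketch of the correlation-id request/response round trip.
async def request_response(pub, sub, request, timeout=600):
    id = str(uuid.uuid4())
    q = await sub.subscribe(id)           # per-request response queue
    try:
        await pub.send(id, request)       # correlation id rides along
        return await asyncio.wait_for(q.get(), timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError("Timeout")
    finally:
        await sub.unsubscribe(id)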
View file

@ -2,7 +2,6 @@
# Like ServiceRequestor, but just fire-and-forget instead of request/response
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging
@ -21,12 +20,12 @@ class ServiceSender:
self.pub = Publisher(
pulsar_client, request_queue,
schema=JsonSchema(request_schema),
schema=request_schema,
)
async def start(self):
self.pub.start()
await self.pub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
@ -35,9 +34,7 @@ class ServiceSender:
try:
await asyncio.to_thread(
self.pub.send, None, self.to_request(request)
)
await self.pub.send(None, self.to_request(request))
if responder:
await responder({}, True)

View file

@ -3,7 +3,7 @@ API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = ".".join(__name__.split(".")[1:-1])
module = "api-gateway"
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
@ -19,7 +19,6 @@ import os
import base64
import pulsar
from pulsar.schema import JsonSchema
from prometheus_client import start_http_server
from .. log_level import LogLevel

View file

@ -1,6 +1,5 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@ -24,12 +23,12 @@ class TriplesLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, triples_store_queue,
schema=JsonSchema(Triples)
schema=Triples
)
async def start(self):
self.publisher.start()
await self.publisher.start()
async def listener(self, ws, running):
@ -51,7 +50,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
triples=to_subgraph(data["triples"]),
)
self.publisher.send(None, elt)
await self.publisher.send(None, elt)
running.stop()

View file

@ -1,7 +1,6 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import Triples
@ -24,7 +23,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(Triples)
schema=Triples
)
async def listener(self, ws, running):
@ -39,7 +38,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
async def start(self):
self.subscriber.start()
await self.subscriber.start()
async def async_thread(self, ws, running):

View file

@ -1,295 +0,0 @@
from . clients.graph_embeddings_client import GraphEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import graph_embeddings_request_queue
from . schema import graph_embeddings_response_queue
from . schema import triples_request_queue
from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = self.rag.ge_client.request(
user=self.user, collection=self.collection,
vectors=vectors, limit=self.entity_limit,
)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=e, p=LABEL, o=None, limit=1,
)
if len(res) == 0:
self.rag.label_cache[e] = e
return e
self.rag.label_cache[e] = res[0].o.value
return self.rag.label_cache[e]
def follow_edges(self, ent, subgraph, path_length):
# Not needed?
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=ent, p=None, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.o.value, subgraph, path_length-1)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=ent, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=None, o=ent,
limit=self.triple_limit,
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.s.value, subgraph, path_length-1)
def get_subgraph(self, query):
entities = self.get_entities(query)
if self.verbose:
print("Get subgraph...", flush=True)
subgraph = set()
for ent in entities:
self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
ge_request_queue=None,
ge_response_queue=None,
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if ge_request_queue is None:
ge_request_queue = graph_embeddings_request_queue
if ge_response_queue is None:
ge_response_queue = graph_embeddings_response_queue
if tpl_request_queue is None:
tpl_request_queue = triples_request_queue
if tpl_response_queue is None:
tpl_response_queue = triples_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.ge_client = GraphEmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
subscriber=module + "-ge",
input_queue=ge_request_queue,
output_queue=ge_response_queue,
)
self.triples_client = TriplesQueryClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
subscriber=module + "-tpl",
input_queue=tpl_request_queue,
output_queue=tpl_response_queue
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.label_cache = {}
self.prompt = PromptClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def query(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
kg = q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.prompt.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)
return resp

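For reference, the traversal this deleted module implemented can be condensed to a bounded expansion of triple lookups around each seed entity. A sketch under that reading, where query_triples stands in for the TriplesQueryClient request seen above:

# Condensed sketch of the deleted follow_edges logic: query triples with
# the entity in each position, recurse outward from the far end of each
# edge up to max_path_length, capped by subgraph size.
def follow_edges(query_triples, ent, subgraph, path_length,
                 triple_limit=30, max_subgraph_size=1000):
    if path_length <= 0 or len(subgraph) >= max_subgraph_size:
        return
    for position, recurse_on in (("s", "o"), ("p", None), ("o", "s")):
        spec = {"s": None, "p": None, "o": None, position: ent}
        for t in query_triples(limit=triple_limit, **spec):
            subgraph.add((t.s.value, t.p.value, t.o.value))
            if recurse_on and path_length > 1:
                follow_edges(query_triples, getattr(t, recurse_on).value,
                             subgraph, path_length - 1,
                             triple_limit, max_subgraph_size)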
View file

@ -35,7 +35,7 @@ from .. exceptions import RequestError
from . librarian import Librarian
module = ".".join(__name__.split(".")[1:-1])
module = "librarian"
default_input_queue = librarian_request_queue
default_output_queue = librarian_response_queue

View file

@ -10,12 +10,11 @@ from .. schema import text_completion_response_queue
from .. log_level import LogLevel
from .. base import Consumer
module = ".".join(__name__.split(".")[1:-1])
module = "metering"
default_input_queue = text_completion_response_queue
default_subscriber = module
class Processor(Consumer):
def __init__(self, **params):

View file

@ -27,7 +27,7 @@ from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_topics
from . prompts import to_kg_query, to_document_query, to_rows
module = ".".join(__name__.split(".")[1:-1])
module = "prompt"
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue

View file

@ -4,8 +4,6 @@ import json
from jsonschema import validate
import re
from trustgraph.clients.llm_client import LlmClient
class PromptConfiguration:
def __init__(self, system_template, global_terms={}, prompts={}):
self.system_template = system_template
@ -21,8 +19,7 @@ class Prompt:
class PromptManager:
def __init__(self, llm, config):
self.llm = llm
def __init__(self, config):
self.config = config
self.terms = config.global_terms
@ -54,7 +51,9 @@ class PromptManager:
return json.loads(json_str)
def invoke(self, id, input):
async def invoke(self, id, input, llm):
print("Invoke...", flush=True)
if id not in self.prompts:
raise RuntimeError("ID invalid")
@ -68,9 +67,7 @@ class PromptManager:
"prompt": self.templates[id].render(terms)
}
resp = self.llm.request(**prompt)
print(resp, flush=True)
resp = await llm(**prompt)
if resp_type == "text":
return resp
@ -81,13 +78,13 @@ class PromptManager:
try:
obj = self.parse_json(resp)
except Exception:
print("Parse fail:", resp, flush=True)
raise RuntimeError("JSON parse fail")
print(obj, flush=True)
if self.prompts[id].schema:
try:
print(self.prompts[id].schema)
validate(instance=obj, schema=self.prompts[id].schema)
print("Validated", flush=True)
except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}")

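The manager no longer owns an LLM client; callers inject an async callable per invocation. A runnable stand-in illustrating the pattern; TinyManager is illustrative only, the real class being the PromptManager above:

import asyncio

# Stand-in: the manager renders a template and awaits whatever async
# llm callable the caller supplies (in production, a flow client).
class TinyManager:
    def __init__(self, templates):
        self.templates = templates
    async def invoke(self, id, input, llm):
        prompt = self.templates[id].format(**input)
        return await llm(system="You are a helpful assistant.", prompt=prompt)

async def main():
    manager = TinyManager({"question": "Q: {question}"})
    async def llm(system, prompt):
        return f"(would go to text-completion) {prompt}"
    print(await manager.invoke("question", {"question": "Hi?"}, llm))

asyncio.run(main())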
View file

@ -3,6 +3,7 @@
Language service abstracts prompt engineering from LLM.
"""
import asyncio
import json
import re
@ -10,74 +11,59 @@ from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from .... base import FlowProcessor
from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
module = ".".join(__name__.split(".")[1:-1])
default_ident = "prompt"
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
)
id = params.get("id")
# Config key for prompts
self.config_key = params.get("config_type", "prompt")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": PromptRequest,
"output_schema": PromptResponse,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"id": id,
}
)
self.llm = LlmClient(
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = PromptRequest,
handler = self.on_request
)
)
# System prompt hack
class Llm:
def __init__(self, llm):
self.llm = llm
def request(self, system, prompt):
print(system)
print(prompt, flush=True)
return self.llm.request(system, prompt)
self.register_specification(
TextCompletionClientSpec(
request_name = "text-completion-request",
response_name = "text-completion-response",
)
)
self.llm = Llm(self.llm)
self.register_specification(
ProducerSpec(
name = "response",
schema = PromptResponse
)
)
self.register_config_handler(self.on_prompt_config)
# Null configuration, should reload quickly
self.manager = PromptManager(
llm = self.llm,
config = PromptConfiguration("", {}, {})
)
async def on_config(self, version, config):
async def on_prompt_config(self, config, version):
print("Loading configuration version", version)
@ -111,7 +97,6 @@ class Processor(ConsumerProducer):
)
self.manager = PromptManager(
self.llm,
PromptConfiguration(
system,
{},
@ -126,7 +111,7 @@ class Processor(ConsumerProducer):
print("Exception:", e, flush=True)
print("Configuration reload failed", flush=True)
async def handle(self, msg):
async def on_request(self, msg, consumer, flow):
v = msg.value()
@ -138,7 +123,7 @@ class Processor(ConsumerProducer):
try:
print(v.terms)
print(v.terms, flush=True)
input = {
k: json.loads(v)
@ -146,14 +131,33 @@ class Processor(ConsumerProducer):
}
print(f"Handling kind {kind}...", flush=True)
print(input, flush=True)
resp = self.manager.invoke(kind, input)
async def llm(system, prompt):
print(system, flush=True)
print(prompt, flush=True)
try:
resp = await flow("text-completion-request").text_completion(
system = system, prompt = prompt,
)
return resp
except Exception as e:
print("LLM Exception:", e, flush=True)
return None
try:
resp = await self.manager.invoke(kind, input, llm)
except Exception as e:
print("Invocation exception:", e, flush=True)
raise e
print(resp, flush=True)
if isinstance(resp, str):
print("Send text response...", flush=True)
print(resp, flush=True)
r = PromptResponse(
text=resp,
@ -161,7 +165,7 @@ class Processor(ConsumerProducer):
error=None,
)
await self.send(r, properties={"id": id})
await flow("response").send(r, properties={"id": id})
return
@ -176,13 +180,13 @@ class Processor(ConsumerProducer):
error=None,
)
await self.send(r, properties={"id": id})
await flow("response").send(r, properties={"id": id})
return
except Exception as e:
print(f"Exception: {e}")
print(f"Exception: {e}", flush=True)
print("Send error response...", flush=True)
@ -194,11 +198,11 @@ class Processor(ConsumerProducer):
response=None,
)
await self.send(r, properties={"id": id})
await flow("response").send(r, properties={"id": id})
except Exception as e:
print(f"Exception: {e}")
print(f"Exception: {e}", flush=True)
print("Send error response...", flush=True)
@ -215,22 +219,7 @@ class Processor(ConsumerProducer):
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
)
FlowProcessor.add_args(parser)
parser.add_argument(
'--config-type',
@ -240,5 +229,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -14,7 +14,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
module = "text-completion"
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -11,7 +11,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
module = "de-query"
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue

View file

@ -16,7 +16,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
module = "de-query"
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue

View file

@ -7,71 +7,51 @@ of chunks
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
from .... base import DocumentEmbeddingsQueryService
module = ".".join(__name__.split(".")[1:-1])
default_ident = "de-query"
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_store_uri = 'http://localhost:6333'
class Processor(ConsumerProducer):
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
# Optional API key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddingsRequest,
"output_schema": DocumentEmbeddingsResponse,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.client = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
async def handle(self, msg):
async def query_document_embeddings(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
chunks = []
for vec in v.vectors:
for vec in msg.vectors:
dim = len(vec)
collection = (
"d_" + v.user + "_" + v.collection + "_" +
"d_" + msg.user + "_" + msg.collection + "_" +
str(dim)
)
search_result = self.client.query_points(
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=v.limit,
limit=msg.limit,
with_payload=True,
).points
@ -79,37 +59,17 @@ class Processor(ConsumerProducer):
ent = r.payload["doc"]
chunks.append(ent)
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return chunks
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
documents=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
DocumentEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
@ -125,5 +85,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

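The query service keeps one Qdrant collection per (user, collection, dimension) triple. A sketch of the lookup above, assuming a Qdrant instance at localhost:6333 and the "d_<user>_<collection>_<dim>" naming used in this diff:

from qdrant_client import QdrantClient

# Sketch of the collection-per-dimension chunk lookup.
def query_chunks(user, collection, vectors, limit):
    qdrant = QdrantClient(url="http://localhost:6333")
    chunks = []
    for vec in vectors:
        name = f"d_{user}_{collection}_{len(vec)}"
        points = qdrant.query_points(
            collection_name=name, query=vec,
            limit=limit, with_payload=True,
        ).points
        # Each point's payload carries the original chunk under "doc"
        chunks.extend(p.payload["doc"] for p in points)
    return chunks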
Some files were not shown because too many files have changed in this diff.