Feature/configure flows (#345)

- Keeps processing in different flows separate so that data can go to different stores / collections etc.
- Potentially supports different processing flows
- Tidies the processing API with common base-classes for e.g. LLMs, and automatic configuration of 'clients' to use the right queue names in a flow
This commit is contained in:
cybermaggedon 2025-04-22 20:21:38 +01:00 committed by Cyber MacGeddon
parent dc0ce1041b
commit 31328317fd
125 changed files with 3751 additions and 2628 deletions

View file

@ -60,6 +60,14 @@ container: update-package-versions
${DOCKER} build -f containers/Containerfile.ocr \ ${DOCKER} build -f containers/Containerfile.ocr \
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
some-containers:
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.vertexai \
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
basic-containers: update-package-versions basic-containers: update-package-versions
${DOCKER} build -f containers/Containerfile.base \ ${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .

View file

@ -20,7 +20,11 @@ def output(text, prefix="> ", width=78):
) )
print(out) print(out)
p = AgentClient(pulsar_host="pulsar://localhost:6650") p = AgentClient(
pulsar_host="pulsar://pulsar:6650",
input_queue = "non-persistent://tg/request/agent:0000",
output_queue = "non-persistent://tg/response/agent:0000",
)
q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese." q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese."

View file

@ -3,7 +3,12 @@
import pulsar import pulsar
from trustgraph.clients.document_rag_client import DocumentRagClient from trustgraph.clients.document_rag_client import DocumentRagClient
rag = DocumentRagClient(pulsar_host="pulsar://localhost:6650") rag = DocumentRagClient(
pulsar_host="pulsar://localhost:6650",
subscriber="test1",
input_queue = "non-persistent://tg/request/document-rag:default",
output_queue = "non-persistent://tg/response/document-rag:default",
)
query=""" query="""
What was the cause of the space shuttle disaster?""" What was the cause of the space shuttle disaster?"""

View file

@ -3,7 +3,12 @@
import pulsar import pulsar
from trustgraph.clients.embeddings_client import EmbeddingsClient from trustgraph.clients.embeddings_client import EmbeddingsClient
embed = EmbeddingsClient(pulsar_host="pulsar://localhost:6650") embed = EmbeddingsClient(
pulsar_host="pulsar://pulsar:6650",
input_queue="non-persistent://tg/request/embeddings:default",
output_queue="non-persistent://tg/response/embeddings:default",
subscriber="test1",
)
prompt="Write a funny limerick about a llama" prompt="Write a funny limerick about a llama"
@ -11,5 +16,3 @@ resp = embed.request(prompt)
print(resp) print(resp)

View file

@ -3,11 +3,18 @@
import pulsar import pulsar
from trustgraph.clients.graph_rag_client import GraphRagClient from trustgraph.clients.graph_rag_client import GraphRagClient
rag = GraphRagClient(pulsar_host="pulsar://localhost:6650") rag = GraphRagClient(
pulsar_host="pulsar://localhost:6650",
subscriber="test1",
input_queue = "non-persistent://tg/request/graph-rag:default",
output_queue = "non-persistent://tg/response/graph-rag:default",
)
query=""" #query="""
This knowledge graph describes the Space Shuttle disaster. #This knowledge graph describes the Space Shuttle disaster.
Present 20 facts which are present in the knowledge graph.""" #Present 20 facts which are present in the knowledge graph."""
query = "How many cats does Mark have?"
resp = rag.request(query) resp = rag.request(query)

View file

@ -3,14 +3,17 @@
import pulsar import pulsar
from trustgraph.clients.llm_client import LlmClient from trustgraph.clients.llm_client import LlmClient
llm = LlmClient(pulsar_host="pulsar://localhost:6650") llm = LlmClient(
pulsar_host="pulsar://pulsar:6650",
input_queue="non-persistent://tg/request/text-completion:default",
output_queue="non-persistent://tg/response/text-completion:default",
subscriber="test1",
)
system = "You are a lovely assistant." system = "You are a lovely assistant."
prompt="Write a funny limerick about a llama" prompt="what is 2 + 2 == 5"
resp = llm.request(system, prompt) resp = llm.request(system, prompt)
print(resp) print(resp)

View file

@ -3,7 +3,12 @@
import json import json
from trustgraph.clients.prompt_client import PromptClient from trustgraph.clients.prompt_client import PromptClient
p = PromptClient(pulsar_host="pulsar://localhost:6650") p = PromptClient(
pulsar_host="pulsar://localhost:6650",
input_queue="non-persistent://tg/request/prompt:default",
output_queue="non-persistent://tg/response/prompt:default",
subscriber="test1",
)
chunk=""" chunk="""
The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including: The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including:
@ -31,8 +36,8 @@ The Space Shuttle's last mission was in 2011.
q = "Tell me some facts in the knowledge graph" q = "Tell me some facts in the knowledge graph"
resp = p.request( resp = p.request(
id="extract-definition", id="extract-definitions",
terms = { variables = {
"text": chunk, "text": chunk,
} }
) )
@ -40,7 +45,7 @@ resp = p.request(
print(resp) print(resp)
for fact in resp: for fact in resp:
print(fact["term"], "::") print(fact["entity"], "::")
print(fact["definition"]) print(fact["definition"])
print() print()

View file

@ -3,13 +3,18 @@
import pulsar import pulsar
from trustgraph.clients.prompt_client import PromptClient from trustgraph.clients.prompt_client import PromptClient
p = PromptClient(pulsar_host="pulsar://localhost:6650") p = PromptClient(
pulsar_host="pulsar://localhost:6650",
input_queue="non-persistent://tg/request/prompt:default",
output_queue="non-persistent://tg/response/prompt:default",
subscriber="test1",
)
question = """What is the square root of 16?""" question = """What is the square root of 16?"""
resp = p.request( resp = p.request(
id="question", id="question",
terms = { variables = {
"question": question "question": question
} }
) )

View file

@ -3,7 +3,9 @@
import pulsar import pulsar
from trustgraph.clients.triples_query_client import TriplesQueryClient from trustgraph.clients.triples_query_client import TriplesQueryClient
tq = TriplesQueryClient(pulsar_host="pulsar://localhost:6650") tq = TriplesQueryClient(
pulsar_host="pulsar://localhost:6650",
)
e = "http://trustgraph.ai/e/shuttle" e = "http://trustgraph.ai/e/shuttle"

View file

@ -1,8 +1,31 @@
from . base_processor import BaseProcessor from . pubsub import PulsarClient
from . async_processor import AsyncProcessor
from . consumer import Consumer from . consumer import Consumer
from . producer import Producer from . producer import Producer
from . consumer_producer import ConsumerProducer
from . publisher import Publisher from . publisher import Publisher
from . subscriber import Subscriber from . subscriber import Subscriber
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . setting_spec import SettingSpec
from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec
from . prompt_client import PromptClientSpec
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService
from . triples_query_service import TriplesQueryService
from . graph_embeddings_query_service import GraphEmbeddingsQueryService
from . document_embeddings_query_service import DocumentEmbeddingsQueryService
from . graph_embeddings_client import GraphEmbeddingsClientSpec
from . triples_client import TriplesClientSpec
from . document_embeddings_client import DocumentEmbeddingsClientSpec
from . agent_service import AgentService
from . graph_rag_client import GraphRagClientSpec

View file

@ -0,0 +1,39 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import AgentRequest, AgentResponse
from .. knowledge import Uri, Literal
class AgentClient(RequestResponse):
async def request(self, recipient, question, plan=None, state=None,
history=[], timeout=300):
resp = await self.request(
AgentRequest(
question = question,
plan = plan,
state = state,
history = history,
),
recipient=recipient,
timeout=timeout,
)
print(resp, flush=True)
if resp.error:
raise RuntimeError(resp.error.message)
return resp
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(GraphEmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = GraphEmbeddingsRequest,
response_name = response_name,
response_schema = GraphEmbeddingsResponse,
impl = GraphEmbeddingsClient,
)

View file

@ -0,0 +1,100 @@
"""
Agent manager service completion base class
"""
import time
from prometheus_client import Histogram
from .. schema import AgentRequest, AgentResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "agent-manager"
class AgentService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(AgentService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = AgentRequest,
handler = self.on_request
)
)
self.register_specification(
ProducerSpec(
name = "next",
schema = AgentRequest
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = AgentResponse
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
async def respond(resp):
await flow("response").send(
resp,
properties={"id": id}
)
async def next(resp):
await flow("next").send(
resp,
properties={"id": id}
)
await self.agent_request(
request = request, respond = respond, next = next,
flow = flow
)
except TooManyRequests as e:
raise e
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
print(f"on_request Exception: {e}")
print("Send error response...", flush=True)
await flow.producer["response"].send(
AgentResponse(
error=Error(
type = "agent-error",
message = str(e),
),
thought = None,
observation = None,
answer = None,
),
properties={"id": id}
)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@ -0,0 +1,254 @@
# Base class for processors. Implements:
# - Pulsar client, subscribe and consume basic
# - the async startup logic
# - Initialising metrics
import asyncio
import argparse
import _pulsar
import time
import uuid
from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
from .. exceptions import TooManyRequests
from . pubsub import PulsarClient
from . producer import Producer
from . consumer import Consumer
from . metrics import ProcessorMetrics
default_config_queue = config_push_queue
# Async processor
class AsyncProcessor:
def __init__(self, **params):
# Store the identity
self.id = params.get("id")
# Register a pulsar client
self.pulsar_client = PulsarClient(**params)
# Initialise metrics, records the parameters
ProcessorMetrics(id=self.id).info({
k: str(params[k])
for k in params
if k != "id"
})
# The processor runs all activity in a taskgroup, it's mandatory
# that this is provded
self.taskgroup = params.get("taskgroup")
if self.taskgroup is None:
raise RuntimeError("Essential taskgroup missing")
# Get the configuration topic
self.config_push_queue = params.get(
"config_push_queue", default_config_queue
)
# This records registered configuration handlers
self.config_handlers = []
# Create a random ID for this subscription to the configuration
# service
config_subscriber_id = str(uuid.uuid4())
# Subscribe to config queue
self.config_sub_task = Consumer(
taskgroup = self.taskgroup,
client = self.client,
subscriber = config_subscriber_id,
flow = None,
topic = self.config_push_queue,
schema = ConfigPush,
handler = self.on_config_change,
# This causes new subscriptions to view the entire history of
# configuration
start_of_messages = True
)
self.running = True
# This is called to start dynamic behaviour. An over-ride point for
# extra functionality
async def start(self):
await self.config_sub_task.start()
# This is called to stop all threads. An over-ride point for extra
# functionality
def stop(self):
self.client.close()
self.running = False
# Returns the pulsar host
@property
def pulsar_host(self): return self.client.pulsar_host
# Returns the pulsar client
@property
def client(self): return self.pulsar_client.client
# Register a new event handler for configuration change
def register_config_handler(self, handler):
self.config_handlers.append(handler)
# Called when a new configuration message push occurs
async def on_config_change(self, message, consumer):
# Get configuration data and version number
config = message.value().config
version = message.value().version
# Acknowledge the message
consumer.acknowledge(message)
# Invoke message handlers
print("Config change event", config, version, flush=True)
for ch in self.config_handlers:
await ch(config, version)
# This is the 'main' body of the handler. It is a point to override
# if needed. By default does nothing. Processors are implemented
# by adding consumer/producer functionality so maybe nothing is needed
# in the run() body
async def run(self):
while self.running:
await asyncio.sleep(2)
# Startup fabric. This runs in 'async' mode, creates a taskgroup and
# runs the producer.
@classmethod
async def launch_async(cls, args):
try:
# Create a taskgroup. This seems complicated, when an exception
# occurs, unhandled it looks like it cancels all threads in the
# taskgroup. Needs the exception to be caught in the right
# place.
async with asyncio.TaskGroup() as tg:
# Create a processor instance, and include the taskgroup
# as a paramter. A processor identity ident is used as
# - subscriber name
# - an identifier for flow configuration
p = cls(**args | { "taskgroup": tg })
# Start the processor
await p.start()
# Run the processor
task = tg.create_task(p.run())
# The taskgroup causes everything to wait until
# all threads have stopped
# This is here to output a debug message, shouldn't be needed.
except Exception as e:
print("Exception, closing taskgroup", flush=True)
raise e
# Startup fabric. launch calls launch_async in async mode.
@classmethod
def launch(cls, ident, doc):
# Start assembling CLI arguments
parser = argparse.ArgumentParser(
prog=ident,
description=doc
)
parser.add_argument(
'--id',
default=ident,
help=f'Configuration identity (default: {ident})',
)
# Invoke the class-specific add_args, which manages adding all the
# command-line arguments
cls.add_args(parser)
# Parse arguments
args = parser.parse_args()
args = vars(args)
# Debug
print(args, flush=True)
# Start the Prometheus metrics service if needed
if args["metrics"]:
start_http_server(args["metrics_port"])
# Loop forever, exception handler
while True:
print("Starting...", flush=True)
try:
# Launch the processor in an asyncio handler
asyncio.run(cls.launch_async(
args
))
except KeyboardInterrupt:
print("Keyboard interrupt.", flush=True)
return
except _pulsar.Interrupted:
print("Pulsar Interrupted.", flush=True)
return
# Exceptions from a taskgroup come in as an exception group
except ExceptionGroup as e:
print("Exception group:", flush=True)
for se in e.exceptions:
print(" Type:", type(se), flush=True)
print(f" Exception: {se}", flush=True)
except Exception as e:
print("Type:", type(e), flush=True)
print("Exception:", e, flush=True)
# Retry occurs here
print("Will retry...", flush=True)
time.sleep(4)
print("Retrying...", flush=True)
# The command-line arguments are built using a stack of add_args
# invocations
@staticmethod
def add_args(parser):
PulsarClient.add_args(parser)
parser.add_argument(
'--config-push-queue',
default=default_config_queue,
help=f'Config push queue {default_config_queue}',
)
parser.add_argument(
'--metrics',
action=argparse.BooleanOptionalAction,
default=True,
help=f'Metrics enabled (default: true)',
)
parser.add_argument(
'-P', '--metrics-port',
type=int,
default=8000,
help=f'Pulsar host (default: 8000)',
)

View file

@ -1,210 +0,0 @@
import asyncio
import os
import argparse
import pulsar
from pulsar.schema import JsonSchema
import _pulsar
import time
import uuid
from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
default_config_queue = config_push_queue
config_subscriber_id = str(uuid.uuid4())
class BaseProcessor:
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
def __init__(self, **params):
self.client = None
if not hasattr(__class__, "params_metric"):
__class__.params_metric = Info(
'params', 'Parameters configuration'
)
# FIXME: Maybe outputs information it should not
__class__.params_metric.info({
k: str(params[k])
for k in params
})
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
pulsar_listener = params.get("pulsar_listener", None)
pulsar_api_key = params.get("pulsar_api_key", None)
log_level = params.get("log_level", LogLevel.INFO)
self.config_push_queue = params.get(
"config_push_queue",
default_config_queue
)
self.pulsar_host = pulsar_host
self.pulsar_api_key = pulsar_api_key
if pulsar_api_key:
auth = pulsar.AuthenticationToken(pulsar_api_key)
self.client = pulsar.Client(
pulsar_host,
authentication=auth,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
else:
self.client = pulsar.Client(
pulsar_host,
listener_name=pulsar_listener,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.pulsar_listener = pulsar_listener
self.config_subscriber = self.client.subscribe(
self.config_push_queue, config_subscriber_id,
consumer_type=pulsar.ConsumerType.Shared,
initial_position=pulsar.InitialPosition.Earliest,
schema=JsonSchema(ConfigPush),
)
def __del__(self):
if hasattr(self, "client"):
if self.client:
self.client.close()
@staticmethod
def add_args(parser):
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=__class__.default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'--config-push-queue',
default=default_config_queue,
help=f'Config push queue {default_config_queue}',
)
parser.add_argument(
'--pulsar-listener',
help=f'Pulsar listener (default: none)',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'--metrics',
action=argparse.BooleanOptionalAction,
default=True,
help=f'Metrics enabled (default: true)',
)
parser.add_argument(
'-P', '--metrics-port',
type=int,
default=8000,
help=f'Pulsar host (default: 8000)',
)
async def start(self):
pass
async def run_config_queue(self):
if self.module == "config.service":
print("I am config-svc, not looking at config queue", flush=True)
return
print("Config thread running", flush=True)
while True:
try:
msg = await asyncio.to_thread(
self.config_subscriber.receive, timeout_millis=2000
)
except pulsar.Timeout:
continue
v = msg.value()
print("Got config version", v.version, flush=True)
await self.on_config(v.version, v.config)
async def on_config(self, version, config):
pass
async def run(self):
raise RuntimeError("Something should have implemented the run method")
@classmethod
async def launch_async(cls, args, prog):
p = cls(**args)
p.module = prog
await p.start()
task1 = asyncio.create_task(p.run_config_queue())
task2 = asyncio.create_task(p.run())
await asyncio.gather(task1, task2)
@classmethod
def launch(cls, prog, doc):
parser = argparse.ArgumentParser(
prog=prog,
description=doc
)
cls.add_args(parser)
args = parser.parse_args()
args = vars(args)
print(args)
if args["metrics"]:
start_http_server(args["metrics_port"])
while True:
try:
asyncio.run(cls.launch_async(args, prog))
except KeyboardInterrupt:
print("Keyboard interrupt.")
return
except _pulsar.Interrupted:
print("Pulsar Interrupted.")
return
except Exception as e:
print(type(e))
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(4)

View file

@ -1,93 +1,136 @@
import asyncio
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
import pulsar import pulsar
from prometheus_client import Histogram, Info, Counter, Enum import _pulsar
import asyncio
import time import time
from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests from .. exceptions import TooManyRequests
default_rate_limit_retry = 10 class Consumer:
default_rate_limit_timeout = 7200
class Consumer(BaseProcessor): def __init__(
self, taskgroup, flow, client, topic, subscriber, schema,
handler,
metrics = None,
start_of_messages=False,
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
reconnect_time = 5,
):
def __init__(self, **params): self.taskgroup = taskgroup
self.flow = flow
self.client = client
self.topic = topic
self.subscriber = subscriber
self.schema = schema
self.handler = handler
if not hasattr(__class__, "state_metric"): self.rate_limit_retry_time = rate_limit_retry_time
__class__.state_metric = Enum( self.rate_limit_timeout = rate_limit_timeout
'processor_state', 'Processor state',
states=['starting', 'running', 'stopped']
)
__class__.state_metric.state('starting')
__class__.state_metric.state('starting') self.reconnect_time = 5
super(Consumer, self).__init__(**params) self.start_of_messages = start_of_messages
self.input_queue = params.get("input_queue") self.running = True
self.subscriber = params.get("subscriber") self.task = None
self.input_schema = params.get("input_schema")
self.rate_limit_retry = params.get( self.metrics = metrics
"rate_limit_retry", default_rate_limit_retry
)
self.rate_limit_timeout = params.get(
"rate_limit_timeout", default_rate_limit_timeout
)
if self.input_schema == None: self.consumer = None
raise RuntimeError("input_schema must be specified")
if not hasattr(__class__, "request_metric"): def __del__(self):
__class__.request_metric = Histogram( self.running = False
'request_latency', 'Request latency (seconds)'
)
if not hasattr(__class__, "pubsub_metric"): if hasattr(self, "consumer"):
__class__.pubsub_metric = Info( if self.consumer:
'pubsub', 'Pub/sub configuration' self.consumer.close()
)
if not hasattr(__class__, "processing_metric"): async def stop(self):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
)
if not hasattr(__class__, "rate_limit_metric"): self.running = False
__class__.rate_limit_metric = Counter( await self.task
'rate_limit_count', 'Rate limit event count',
)
__class__.pubsub_metric.info({ async def start(self):
"input_queue": self.input_queue,
"subscriber": self.subscriber,
"input_schema": self.input_schema.__name__,
"rate_limit_retry": str(self.rate_limit_retry),
"rate_limit_timeout": str(self.rate_limit_timeout),
})
self.consumer = self.client.subscribe( self.running = True
self.input_queue, self.subscriber,
consumer_type=pulsar.ConsumerType.Shared,
schema=JsonSchema(self.input_schema),
)
print("Initialised consumer.", flush=True) # Puts it in the stopped state, the run thread should set running
if self.metrics:
self.metrics.state("stopped")
self.task = self.taskgroup.create_task(self.run())
async def run(self): async def run(self):
__class__.state_metric.state('running') while self.running:
while True: if self.metrics:
self.metrics.state("stopped")
msg = await asyncio.to_thread(self.consumer.receive) try:
print(self.topic, "subscribing...", flush=True)
if self.start_of_messages:
pos = pulsar.InitialPosition.Earliest
else:
pos = pulsar.InitialPosition.Latest
self.consumer = await asyncio.to_thread(
self.client.subscribe,
topic = self.topic,
subscription_name = self.subscriber,
schema = JsonSchema(self.schema),
initial_position = pos,
consumer_type = pulsar.ConsumerType.Shared,
)
except Exception as e:
print("consumer subs Exception:", e, flush=True)
await asyncio.sleep(self.reconnect_time)
continue
print(self.topic, "subscribed", flush=True)
if self.metrics:
self.metrics.state("running")
try:
await self.consume()
if self.metrics:
self.metrics.state("stopped")
except Exception as e:
print("consumer loop exception:", e, flush=True)
self.consumer.close()
self.consumer = None
await asyncio.sleep(self.reconnect_time)
continue
async def consume(self):
while self.running:
try:
msg = await asyncio.to_thread(
self.consumer.receive,
timeout_millis=2000
)
except _pulsar.Timeout:
continue
except Exception as e:
raise e
expiry = time.time() + self.rate_limit_timeout expiry = time.time() + self.rate_limit_timeout
# This loop is for retry on rate-limit / resource limits # This loop is for retry on rate-limit / resource limits
while True: while self.running:
if time.time() > expiry: if time.time() > expiry:
@ -97,20 +140,31 @@ class Consumer(BaseProcessor):
# be retried # be retried
self.consumer.negative_acknowledge(msg) self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc() if self.metrics:
self.metrics.process("error")
# Break out of retry loop, processes next message # Break out of retry loop, processes next message
break break
try: try:
with __class__.request_metric.time(): print("Handle...", flush=True)
await self.handle(msg)
if self.metrics:
with self.metrics.record_time():
await self.handler(msg, self, self.flow)
else:
await self.handler(msg, self.consumer)
print("Handled.", flush=True)
# Acknowledge successful processing of the message # Acknowledge successful processing of the message
self.consumer.acknowledge(msg) self.consumer.acknowledge(msg)
__class__.processing_metric.labels(status="success").inc() if self.metrics:
self.metrics.process("success")
# Break out of retry loop # Break out of retry loop
break break
@ -119,55 +173,25 @@ class Consumer(BaseProcessor):
print("TooManyRequests: will retry...", flush=True) print("TooManyRequests: will retry...", flush=True)
__class__.rate_limit_metric.inc() if self.metrics:
self.metrics.rate_limit()
# Sleep # Sleep
time.sleep(self.rate_limit_retry) await asyncio.sleep(self.rate_limit_retry_time)
# Contine from retry loop, just causes a reprocessing # Contine from retry loop, just causes a reprocessing
continue continue
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("consume exception:", e, flush=True)
# Message failed to be processed, this causes it to # Message failed to be processed, this causes it to
# be retried # be retried
self.consumer.negative_acknowledge(msg) self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc() if self.metrics:
self.metrics.process("error")
# Break out of retry loop, processes next message # Break out of retry loop, processes next message
break break
@staticmethod
def add_args(parser, default_input_queue, default_subscriber):
BaseProcessor.add_args(parser)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)
parser.add_argument(
'--rate-limit-retry',
type=int,
default=default_rate_limit_retry,
help=f'Rate limit retry (default: {default_rate_limit_retry})'
)
parser.add_argument(
'--rate-limit-timeout',
type=int,
default=default_rate_limit_timeout,
help=f'Rate limit timeout (default: {default_rate_limit_timeout})'
)

View file

@ -1,62 +0,0 @@
from pulsar.schema import JsonSchema
import pulsar
from prometheus_client import Histogram, Info, Counter, Enum
import time
from . consumer import Consumer
from .. exceptions import TooManyRequests
class ConsumerProducer(Consumer):
def __init__(self, **params):
super(ConsumerProducer, self).__init__(**params)
self.output_queue = params.get("output_queue")
self.output_schema = params.get("output_schema")
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
__class__.pubsub_metric.info({
"input_queue": self.input_queue,
"output_queue": self.output_queue,
"subscriber": self.subscriber,
"input_schema": self.input_schema.__name__,
"output_schema": self.output_schema.__name__,
"rate_limit_retry": str(self.rate_limit_retry),
"rate_limit_timeout": str(self.rate_limit_timeout),
})
if self.output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=self.output_queue,
schema=JsonSchema(self.output_schema),
chunking_enabled=True,
)
print("Initialised consumer/producer.")
async def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
Consumer.add_args(parser, default_input_queue, default_subscriber)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -0,0 +1,36 @@
from . metrics import ConsumerMetrics
from . consumer import Consumer
from . spec import Spec
class ConsumerSpec(Spec):
def __init__(self, name, schema, handler):
self.name = name
self.schema = schema
self.handler = handler
def add(self, flow, processor, definition):
consumer_metrics = ConsumerMetrics(
flow.id, f"{flow.name}-{self.name}"
)
consumer = Consumer(
taskgroup = processor.taskgroup,
flow = flow,
client = processor.client,
topic = definition[self.name],
subscriber = processor.id + "--" + self.name,
schema = self.schema,
handler = self.handler,
metrics = consumer_metrics,
)
# Consumer handle gets access to producers and other
# metadata
consumer.id = flow.id
consumer.name = self.name
consumer.flow = flow
flow.consumer[self.name] = consumer

View file

@ -0,0 +1,38 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. knowledge import Uri, Literal
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vectors = vectors,
limit = limit,
user = user,
collection = collection
),
timeout=timeout
)
print(resp, flush=True)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.documents
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(DocumentEmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = DocumentEmbeddingsRequest,
response_name = response_name,
response_schema = DocumentEmbeddingsResponse,
impl = DocumentEmbeddingsClient,
)

View file

@ -0,0 +1,84 @@
"""
Document embeddings query service. Input is vectors. Output is list of
embeddings.
"""
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. schema import Error, Value
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . producer_spec import ProducerSpec
default_ident = "ge-query"
class DocumentEmbeddingsQueryService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(DocumentEmbeddingsQueryService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = DocumentEmbeddingsRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = DocumentEmbeddingsResponse,
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
docs = await self.query_document_embeddings(request)
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=docs, error=None)
await flow("response").send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "document-embeddings-query-error",
message = str(e),
),
response=None,
)
await flow("response").send(r, properties={"id": id})
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,50 @@
"""
Document embeddings store base class
"""
from .. schema import DocumentEmbeddings
from .. base import FlowProcessor, ConsumerSpec
from .. exceptions import TooManyRequests
default_ident = "document-embeddings-write"
class DocumentEmbeddingsStoreService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(DocumentEmbeddingsStoreService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = DocumentEmbeddings,
handler = self.on_message
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
await self.store_document_embeddings(request)
except TooManyRequests as e:
raise e
except Exception as e:
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@ -0,0 +1,31 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import EmbeddingsRequest, EmbeddingsResponse
class EmbeddingsClient(RequestResponse):
async def embed(self, text, timeout=30):
resp = await self.request(
EmbeddingsRequest(
text = text
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.vectors
class EmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(EmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = EmbeddingsRequest,
response_name = response_name,
response_schema = EmbeddingsResponse,
impl = EmbeddingsClient,
)

View file

@ -0,0 +1,90 @@
"""
Embeddings resolution base class
"""
import time
from prometheus_client import Histogram
from .. schema import EmbeddingsRequest, EmbeddingsResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "embeddings"
class EmbeddingsService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(EmbeddingsService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = EmbeddingsRequest,
handler = self.on_request
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = EmbeddingsResponse
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print("Handling request", id, "...", flush=True)
vectors = await self.on_embeddings(request.text)
await flow("response").send(
EmbeddingsResponse(
error = None,
vectors = vectors,
),
properties={"id": id}
)
print("Handled.", flush=True)
except TooManyRequests as e:
raise e
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
print(f"Exception: {e}", flush=True)
print("Send error response...", flush=True)
await flow.producer["response"].send(
EmbeddingsResponse(
error=Error(
type = "embeddings-error",
message = str(e),
),
vectors=None,
),
properties={"id": id}
)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@ -0,0 +1,32 @@
import asyncio
class Flow:
def __init__(self, id, flow, processor, defn):
self.id = id
self.name = flow
self.producer = {}
# Consumers and publishers. Is this a bit untidy?
self.consumer = {}
self.setting = {}
for spec in processor.specifications:
spec.add(self, processor, defn)
async def start(self):
for c in self.consumer.values():
await c.start()
async def stop(self):
for c in self.consumer.values():
await c.stop()
def __call__(self, key):
if key in self.producer: return self.producer[key]
if key in self.consumer: return self.consumer[key]
if key in self.setting: return self.setting[key].value
return None

View file

@ -0,0 +1,115 @@
# Base class for processor with management of flows in & out which are managed
# by configuration. This is probably all processor types, except for the
# configuration service which can't manage itself.
import json
from pulsar.schema import JsonSchema
from .. schema import Error
from .. schema import config_request_queue, config_response_queue
from .. schema import config_push_queue
from .. log_level import LogLevel
from . async_processor import AsyncProcessor
from . flow import Flow
# Parent class for configurable processors, configured with flows by
# the config service
class FlowProcessor(AsyncProcessor):
def __init__(self, **params):
# Initialise base class
super(FlowProcessor, self).__init__(**params)
# Register configuration handler
self.register_config_handler(self.on_configure_flows)
# Initialise flow information state
self.flows = {}
# These can be overriden by a derived class:
# Array of specifications: ConsumerSpec, ProducerSpec, SettingSpec
self.specifications = []
print("Service initialised.")
# Register a configuration variable
def register_specification(self, spec):
self.specifications.append(spec)
# Start processing for a new flow
async def start_flow(self, flow, defn):
self.flows[flow] = Flow(self.id, flow, self, defn)
await self.flows[flow].start()
print("Started flow: ", flow)
# Stop processing for a new flow
async def stop_flow(self, flow):
if flow in self.flows:
await self.flows[flow].stop()
del self.flows[flow]
print("Stopped flow: ", flow, flush=True)
# Event handler - called for a configuration change
async def on_configure_flows(self, config, version):
print("Got config version", version, flush=True)
# Skip over invalid data
if "flows-active" not in config: return
# Check there's configuration information for me
if self.id in config["flows-active"]:
# Get my flow config
flow_config = json.loads(config["flows-active"][self.id])
else:
print("No configuration settings for me.", flush=True)
flow_config = {}
# Get list of flows which should be running and are currently
# running
wanted_flows = flow_config.keys()
current_flows = self.flows.keys()
# Start all the flows which arent currently running
for flow in wanted_flows:
if flow not in current_flows:
await self.start_flow(flow, flow_config[flow])
# Stop all the unwanted flows which are due to be stopped
for flow in current_flows:
if flow not in wanted_flows:
await self.stop_flow(flow)
print("Handled config update")
# Start threads, just call parent
async def start(self):
await super(FlowProcessor, self).start()
@staticmethod
def add_args(parser):
AsyncProcessor.add_args(parser)
# parser.add_argument(
# '--rate-limit-retry',
# type=int,
# default=default_rate_limit_retry,
# help=f'Rate limit retry (default: {default_rate_limit_retry})'
# )
# parser.add_argument(
# '--rate-limit-timeout',
# type=int,
# default=default_rate_limit_timeout,
# help=f'Rate limit timeout (default: {default_rate_limit_timeout})'
# )

View file

@ -0,0 +1,45 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. knowledge import Uri, Literal
def to_value(x):
if x.is_uri: return Uri(x.value)
return Literal(x.value)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vectors = vectors,
limit = limit,
user = user,
collection = collection
),
timeout=timeout
)
print(resp, flush=True)
if resp.error:
raise RuntimeError(resp.error.message)
return [
to_value(v)
for v in resp.entities
]
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(GraphEmbeddingsClientSpec, self).__init__(
request_name = request_name,
request_schema = GraphEmbeddingsRequest,
response_name = response_name,
response_schema = GraphEmbeddingsResponse,
impl = GraphEmbeddingsClient,
)

View file

@ -0,0 +1,84 @@
"""
Graph embeddings query service. Input is vectors. Output is list of
embeddings.
"""
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import Error, Value
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . producer_spec import ProducerSpec
default_ident = "ge-query"
class GraphEmbeddingsQueryService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(GraphEmbeddingsQueryService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = GraphEmbeddingsRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = GraphEmbeddingsResponse,
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
entities = await self.query_graph_embeddings(request)
print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities, error=None)
await flow("response").send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "graph-embeddings-query-error",
message = str(e),
),
response=None,
)
await flow("response").send(r, properties={"id": id})
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,50 @@
"""
Graph embeddings store base class
"""
from .. schema import GraphEmbeddings
from .. base import FlowProcessor, ConsumerSpec
from .. exceptions import TooManyRequests
default_ident = "graph-embeddings-write"
class GraphEmbeddingsStoreService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(GraphEmbeddingsStoreService, self).__init__(
**params | { "id": id }
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = GraphEmbeddings,
handler = self.on_message
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
await self.store_graph_embeddings(request)
except TooManyRequests as e:
raise e
except Exception as e:
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@ -0,0 +1,33 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import GraphRagQuery, GraphRagResponse
class GraphRagClient(RequestResponse):
async def rag(self, query, user="trustgraph", collection="default",
timeout=600):
resp = await self.request(
GraphRagQuery(
query = query,
user = user,
collection = collection,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.response
class GraphRagClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(GraphRagClientSpec, self).__init__(
request_name = request_name,
request_schema = GraphRagQuery,
response_name = response_name,
response_schema = GraphRagResponse,
impl = GraphRagClient,
)

View file

@ -0,0 +1,114 @@
"""
LLM text completion base class
"""
import time
from prometheus_client import Histogram
from .. schema import TextCompletionRequest, TextCompletionResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "text-completion"
class LlmResult:
__slots__ = ["text", "in_token", "out_token", "model"]
class LlmService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(LlmService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = TextCompletionRequest,
handler = self.on_request
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = TextCompletionResponse
)
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
'Text completion duration (seconds)',
["id", "flow"],
buckets=[
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
120.0
]
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
with __class__.text_completion_metric.labels(
id=self.id,
flow=f"{flow.name}-{consumer.name}",
).time():
response = await self.generate_content(
request.system, request.prompt
)
await flow("response").send(
TextCompletionResponse(
error=None,
response=response.text,
in_token=response.in_token,
out_token=response.out_token,
model=response.model
),
properties={"id": id}
)
except TooManyRequests as e:
raise e
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
print(f"Exception: {e}")
print("Send error response...", flush=True)
await flow.producer["response"].send(
TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
),
properties={"id": id}
)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@ -0,0 +1,82 @@
from prometheus_client import start_http_server, Info, Enum, Histogram
from prometheus_client import Counter
class ConsumerMetrics:
def __init__(self, id, flow=None):
self.id = id
self.flow = flow
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'consumer_state', 'Consumer state',
["id", "flow"],
states=['stopped', 'running']
)
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)',
["id", "flow"],
)
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count',
["id", "flow", "status"]
)
if not hasattr(__class__, "rate_limit_metric"):
__class__.rate_limit_metric = Counter(
'rate_limit_count', 'Rate limit event count',
["id", "flow"]
)
def process(self, status):
__class__.processing_metric.labels(
id=self.id, flow=self.flow, status=status
).inc()
def rate_limit(self):
__class__.rate_limit_metric.labels(
id=self.id, flow=self.flow
).inc()
def state(self, state):
__class__.state_metric.labels(
id=self.id, flow=self.flow
).state(state)
def record_time(self):
return __class__.request_metric.labels(
id=self.id, flow=self.flow
).time()
class ProducerMetrics:
def __init__(self, id, flow=None):
self.id = id
self.flow = flow
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created',
["id", "flow"]
)
def inc(self):
__class__.output_metric.labels(id=self.id, flow=self.flow).inc()
class ProcessorMetrics:
def __init__(self, id):
self.id = id
if not hasattr(__class__, "processor_metric"):
__class__.processor_metric = Info(
'processor', 'Processor configuration',
["id"]
)
def info(self, info):
__class__.processor_metric.labels(id=self.id).info(info)

View file

@ -1,56 +1,69 @@
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from prometheus_client import Info, Counter import asyncio
from . base_processor import BaseProcessor class Producer:
class Producer(BaseProcessor): def __init__(self, client, topic, schema, metrics=None):
self.client = client
self.topic = topic
self.schema = schema
def __init__(self, **params): self.metrics = metrics
output_queue = params.get("output_queue") self.running = True
output_schema = params.get("output_schema") self.producer = None
if not hasattr(__class__, "output_metric"): def __del__(self):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
if not hasattr(__class__, "pubsub_metric"): self.running = False
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
__class__.pubsub_metric.info({ if hasattr(self, "producer"):
"output_queue": output_queue, if self.producer:
"output_schema": output_schema.__name__, self.producer.close()
})
super(Producer, self).__init__(**params) async def start(self):
self.running = True
if output_schema == None: async def stop(self):
raise RuntimeError("output_schema must be specified") self.running = False
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
chunking_enabled=True,
)
async def send(self, msg, properties={}): async def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod if not self.running: return
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
BaseProcessor.add_args(parser) while self.running and self.producer is None:
try:
print("Connect publisher to", self.topic, "...", flush=True)
self.producer = self.client.create_producer(
topic = self.topic,
schema = JsonSchema(self.schema)
)
print("Connected to", self.topic, flush=True)
except Exception as e:
print("Exception:", e, flush=True)
await asyncio.sleep(2)
if not self.running: break
while self.running:
try:
await asyncio.to_thread(
self.producer.send,
msg, properties
)
if self.metrics:
self.metrics.inc()
# Delivery success, break out of loop
break
except Exception as e:
print("Exception:", e, flush=True)
self.producer.close()
self.producer = None
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -0,0 +1,25 @@
from . producer import Producer
from . metrics import ProducerMetrics
from . spec import Spec
class ProducerSpec(Spec):
def __init__(self, name, schema):
self.name = name
self.schema = schema
def add(self, flow, processor, definition):
producer_metrics = ProducerMetrics(
flow.id, f"{flow.name}-{self.name}"
)
producer = Producer(
client = processor.client,
topic = definition[self.name],
schema = self.schema,
metrics = producer_metrics,
)
flow.producer[self.name] = producer

View file

@ -0,0 +1,93 @@
import json
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600):
resp = await self.request(
PromptRequest(
id = id,
terms = {
k: json.dumps(v)
for k, v in variables.items()
}
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
return json.loads(resp.object)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
id = "extract-definitions",
variables = { "text": text },
timeout = timeout,
)
async def extract_relationships(self, text, timeout=600):
return await self.prompt(
id = "extract-relationships",
variables = { "text": text },
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600):
return await self.prompt(
id = "kg-prompt",
variables = {
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout = timeout,
)
async def document_prompt(self, query, documents, timeout=600):
return await self.prompt(
id = "document-prompt",
variables = {
"query": query,
"documents": documents,
},
timeout = timeout,
)
async def agent_react(self, variables, timeout=600):
return await self.prompt(
id = "agent-react",
variables = variables,
timeout = timeout,
)
async def question(self, question, timeout=600):
return await self.prompt(
id = "question",
variables = {
"question": question,
},
timeout = timeout,
)
class PromptClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(PromptClientSpec, self).__init__(
request_name = request_name,
request_schema = PromptRequest,
response_name = response_name,
response_schema = PromptResponse,
impl = PromptClient,
)

View file

@ -1,47 +1,52 @@
import queue from pulsar.schema import JsonSchema
import asyncio
import time import time
import pulsar import pulsar
import threading
class Publisher: class Publisher:
def __init__(self, pulsar_client, topic, schema=None, max_size=10, def __init__(self, client, topic, schema=None, max_size=10,
chunking_enabled=True): chunking_enabled=True):
self.client = pulsar_client self.client = client
self.topic = topic self.topic = topic
self.schema = schema self.schema = schema
self.q = queue.Queue(maxsize=max_size) self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled self.chunking_enabled = chunking_enabled
self.running = True self.running = True
def start(self): async def start(self):
self.task = threading.Thread(target=self.run) self.task = asyncio.create_task(self.run())
self.task.start()
def stop(self): async def stop(self):
self.running = False self.running = False
def join(self): async def join(self):
self.stop() await self.stop()
self.task.join() await self.task
def run(self): async def run(self):
while self.running: while self.running:
try: try:
producer = self.client.create_producer( producer = self.client.create_producer(
topic=self.topic, topic=self.topic,
schema=self.schema, schema=JsonSchema(self.schema),
chunking_enabled=self.chunking_enabled, chunking_enabled=self.chunking_enabled,
) )
while self.running: while self.running:
try: try:
id, item = self.q.get(timeout=0.5) id, item = await asyncio.wait_for(
except queue.Empty: self.q.get(),
timeout=0.5
)
except asyncio.TimeoutError:
continue
except asyncio.QueueEmpty:
continue continue
if id: if id:
@ -55,7 +60,6 @@ class Publisher:
# If handler drops out, sleep a retry # If handler drops out, sleep a retry
time.sleep(2) time.sleep(2)
def send(self, id, msg): async def send(self, id, item):
self.q.put((id, msg)) await self.q.put((id, item))

View file

@ -0,0 +1,80 @@
import os
import pulsar
import uuid
from pulsar.schema import JsonSchema
from .. log_level import LogLevel
class PulsarClient:
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
def __init__(self, **params):
self.client = None
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
pulsar_listener = params.get("pulsar_listener", None)
pulsar_api_key = params.get(
"pulsar_api_key",
self.default_pulsar_api_key
)
log_level = params.get("log_level", LogLevel.INFO)
self.pulsar_host = pulsar_host
self.pulsar_api_key = pulsar_api_key
if pulsar_api_key:
auth = pulsar.AuthenticationToken(pulsar_api_key)
self.client = pulsar.Client(
pulsar_host,
authentication=auth,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
else:
self.client = pulsar.Client(
pulsar_host,
listener_name=pulsar_listener,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.pulsar_listener = pulsar_listener
def close(self):
self.client.close()
def __del__(self):
if hasattr(self, "client"):
if self.client:
self.client.close()
@staticmethod
def add_args(parser):
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=__class__.default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'--pulsar-listener',
help=f'Pulsar listener (default: none)',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)

View file

@ -0,0 +1,136 @@
import uuid
import asyncio
from . subscriber import Subscriber
from . producer import Producer
from . spec import Spec
from . metrics import ConsumerMetrics, ProducerMetrics
class RequestResponse(Subscriber):
def __init__(
self, client, subscription, consumer_name,
request_topic, request_schema,
request_metrics,
response_topic, response_schema,
response_metrics,
):
super(RequestResponse, self).__init__(
client = client,
subscription = subscription,
consumer_name = consumer_name,
topic = response_topic,
schema = response_schema,
)
self.producer = Producer(
client = client,
topic = request_topic,
schema = request_schema,
metrics = request_metrics,
)
async def start(self):
await self.producer.start()
await super(RequestResponse, self).start()
async def stop(self):
await self.producer.stop()
await super(RequestResponse, self).stop()
async def request(self, req, timeout=300, recipient=None):
id = str(uuid.uuid4())
print("Request", id, "...", flush=True)
q = await self.subscribe(id)
try:
await self.producer.send(
req,
properties={"id": id}
)
except Exception as e:
print("Exception:", e)
raise e
try:
while True:
resp = await asyncio.wait_for(
q.get(),
timeout=timeout
)
print("Got response.", flush=True)
if recipient is None:
# If no recipient handler, just return the first
# response we get
return resp
else:
# Recipient handler gets to decide when we're done b
# returning a boolean
fin = await recipient(resp)
# If done, return the last result otherwise loop round for
# next response
if fin:
return resp
else:
continue
except Exception as e:
print("Exception:", e)
raise e
finally:
await self.unsubscribe(id)
# This deals with the request/response case. The caller needs to
# use another service in request/response mode. Uses two topics:
# - we send on the request topic as a producer
# - we receive on the response topic as a subscriber
class RequestResponseSpec(Spec):
def __init__(
self, request_name, request_schema, response_name,
response_schema, impl=RequestResponse
):
self.request_name = request_name
self.request_schema = request_schema
self.response_name = response_name
self.response_schema = response_schema
self.impl = impl
def add(self, flow, processor, definition):
producer_metrics = ProducerMetrics(
flow.id, f"{flow.name}-{self.response_name}"
)
rr = self.impl(
client = processor.client,
subscription = flow.id,
consumer_name = flow.id,
request_topic = definition[self.request_name],
request_schema = self.request_schema,
request_metrics = producer_metrics,
response_topic = definition[self.response_name],
response_schema = self.response_schema,
response_metrics = None,
)
flow.consumer[self.request_name] = rr

View file

@ -0,0 +1,19 @@
from . spec import Spec
class Setting:
def __init__(self, value):
self.value = value
async def start():
pass
async def stop():
pass
class SettingSpec(Spec):
def __init__(self, name):
self.name = name
def add(self, flow, processor, definition):
flow.config[self.name] = Setting(definition[self.name])

View file

@ -0,0 +1,4 @@
class Spec:
pass

View file

@ -1,14 +1,14 @@
import queue from pulsar.schema import JsonSchema
import pulsar import asyncio
import threading import _pulsar
import time import time
class Subscriber: class Subscriber:
def __init__(self, pulsar_client, topic, subscription, consumer_name, def __init__(self, client, topic, subscription, consumer_name,
schema=None, max_size=100): schema=None, max_size=100):
self.client = pulsar_client self.client = client
self.topic = topic self.topic = topic
self.subscription = subscription self.subscription = subscription
self.consumer_name = consumer_name self.consumer_name = consumer_name
@ -16,35 +16,50 @@ class Subscriber:
self.q = {} self.q = {}
self.full = {} self.full = {}
self.max_size = max_size self.max_size = max_size
self.lock = threading.Lock() self.lock = asyncio.Lock()
self.running = True self.running = True
def start(self): async def __del__(self):
self.task = threading.Thread(target=self.run)
self.task.start()
def stop(self):
self.running = False self.running = False
def join(self): async def start(self):
self.task.join() self.task = asyncio.create_task(self.run())
def run(self): async def stop(self):
self.running = False
async def join(self):
await self.stop()
await self.task
async def run(self):
while self.running: while self.running:
try: try:
consumer = self.client.subscribe( consumer = self.client.subscribe(
topic=self.topic, topic = self.topic,
subscription_name=self.subscription, subscription_name = self.subscription,
consumer_name=self.consumer_name, consumer_name = self.consumer_name,
schema=self.schema, schema = JsonSchema(self.schema),
) )
print("Subscriber running...", flush=True)
while self.running: while self.running:
msg = consumer.receive() try:
msg = await asyncio.to_thread(
consumer.receive,
timeout_millis=2000
)
except _pulsar.Timeout:
continue
except Exception as e:
print("Exception:", e, flush=True)
print(type(e))
raise e
# Acknowledge successful reception of the message # Acknowledge successful reception of the message
consumer.acknowledge(msg) consumer.acknowledge(msg)
@ -56,57 +71,68 @@ class Subscriber:
value = msg.value() value = msg.value()
with self.lock: async with self.lock:
# FIXME: Hard-coded timeouts
if id in self.q: if id in self.q:
try: try:
# FIXME: Timeout means data goes missing # FIXME: Timeout means data goes missing
self.q[id].put(value, timeout=0.5) await asyncio.wait_for(
except: self.q[id].put(value),
pass timeout=2
)
except Exception as e:
print("Q Put:", e, flush=True)
for q in self.full.values(): for q in self.full.values():
try: try:
# FIXME: Timeout means data goes missing # FIXME: Timeout means data goes missing
q.put(value, timeout=0.5) await asyncio.wait_for(
except: q.put(value),
pass timeout=2
)
except Exception as e:
print("Q Put:", e, flush=True)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Subscriber exception:", e, flush=True)
consumer.close()
# If handler drops out, sleep a retry # If handler drops out, sleep a retry
time.sleep(2) time.sleep(2)
def subscribe(self, id): async def subscribe(self, id):
with self.lock: async with self.lock:
q = queue.Queue(maxsize=self.max_size) q = asyncio.Queue(maxsize=self.max_size)
self.q[id] = q self.q[id] = q
return q return q
def unsubscribe(self, id): async def unsubscribe(self, id):
with self.lock: async with self.lock:
if id in self.q: if id in self.q:
# self.q[id].shutdown(immediate=True) # self.q[id].shutdown(immediate=True)
del self.q[id] del self.q[id]
def subscribe_all(self, id): async def subscribe_all(self, id):
with self.lock: async with self.lock:
q = queue.Queue(maxsize=self.max_size) q = asyncio.Queue(maxsize=self.max_size)
self.full[id] = q self.full[id] = q
return q return q
def unsubscribe_all(self, id): async def unsubscribe_all(self, id):
with self.lock: async with self.lock:
if id in self.full: if id in self.full:
# self.full[id].shutdown(immediate=True) # self.full[id].shutdown(immediate=True)

View file

@ -0,0 +1,30 @@
from . metrics import ConsumerMetrics
from . subscriber import Subscriber
from . spec import Spec
class SubscriberSpec(Spec):
def __init__(self, name, schema):
self.name = name
self.schema = schema
def add(self, flow, processor, definition):
# FIXME: Metrics not used
subscriber_metrics = ConsumerMetrics(
flow.id, f"{flow.name}-{self.name}"
)
subscriber = Subscriber(
client = processor.client,
topic = definition[self.name],
subscription = flow.id,
consumer_name = flow.id,
schema = self.schema,
)
# Put it in the consumer map, does that work?
# It means it gets start/stop call.
flow.consumer[self.name] = subscriber

View file

@ -0,0 +1,30 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, timeout=600):
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.response
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(TextCompletionClientSpec, self).__init__(
request_name = request_name,
request_schema = TextCompletionRequest,
response_name = response_name,
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)

View file

@ -0,0 +1,61 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. knowledge import Uri, Literal
class Triple:
def __init__(self, s, p, o):
self.s = s
self.p = p
self.o = o
def to_value(x):
if x.is_uri: return Uri(x.value)
return Literal(x.value)
def from_value(x):
if x is None: return None
if isinstance(x, Uri):
return Value(value=str(x), is_uri=True)
else:
return Value(value=str(x), is_uri=False)
class TriplesClient(RequestResponse):
async def query(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default",
timeout=30):
resp = await self.request(
TriplesQueryRequest(
s = from_value(s),
p = from_value(p),
o = from_value(o),
limit = limit,
user = user,
collection = collection,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
triples = [
Triple(to_value(v.s), to_value(v.p), to_value(v.o))
for v in resp.triples
]
return triples
class TriplesClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(TriplesClientSpec, self).__init__(
request_name = request_name,
request_schema = TriplesQueryRequest,
response_name = response_name,
response_schema = TriplesQueryResponse,
impl = TriplesClient,
)

View file

@ -0,0 +1,82 @@
"""
Triples query service. Input is a (s, p, o) triple, some values may be
null. Output is a list of triples.
"""
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .. schema import Value, Triple
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . producer_spec import ProducerSpec
default_ident = "triples-query"
class TriplesQueryService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(TriplesQueryService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "request",
schema = TriplesQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = TriplesQueryResponse,
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
triples = await self.query_triples(request)
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples, error=None)
await flow("response").send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error = Error(
type = "triples-query-error",
message = str(e),
),
triples = None,
)
await flow("response").send(r, properties={"id": id})
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,47 @@
"""
Triples store base class
"""
from .. schema import Triples
from .. base import FlowProcessor, ConsumerSpec
default_ident = "triples-write"
class TriplesStoreService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
super(TriplesStoreService, self).__init__(**params | { "id": id })
self.register_specification(
ConsumerSpec(
name = "input",
schema = Triples,
handler = self.on_message
)
)
async def on_message(self, msg, consumer, flow):
try:
request = msg.value()
await self.store_triples(request)
except TooManyRequests as e:
raise e
except Exception as e:
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)

View file

@ -26,12 +26,5 @@ class AgentResponse(Record):
thought = String() thought = String()
observation = String() observation = String()
agent_request_queue = topic(
'agent', kind='non-persistent', namespace='request'
)
agent_response_queue = topic(
'agent', kind='non-persistent', namespace='response'
)
############################################################################ ############################################################################

View file

@ -2,7 +2,7 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
from . topic import topic from . topic import topic
from . types import Error, RowSchema from . types import Error
############################################################################ ############################################################################

View file

@ -11,8 +11,6 @@ class Document(Record):
metadata = Metadata() metadata = Metadata()
data = Bytes() data = Bytes()
document_ingest_queue = topic('document-load')
############################################################################ ############################################################################
# Text documents / text from PDF # Text documents / text from PDF
@ -21,8 +19,6 @@ class TextDocument(Record):
metadata = Metadata() metadata = Metadata()
text = Bytes() text = Bytes()
text_ingest_queue = topic('text-document-load')
############################################################################ ############################################################################
# Chunks of text # Chunks of text
@ -31,8 +27,6 @@ class Chunk(Record):
metadata = Metadata() metadata = Metadata()
chunk = Bytes() chunk = Bytes()
chunk_ingest_queue = topic('chunk-load')
############################################################################ ############################################################################
# Document embeddings are embeddings associated with a chunk # Document embeddings are embeddings associated with a chunk
@ -46,8 +40,6 @@ class DocumentEmbeddings(Record):
metadata = Metadata() metadata = Metadata()
chunks = Array(ChunkEmbeddings()) chunks = Array(ChunkEmbeddings())
document_embeddings_store_queue = topic('document-embeddings-store')
############################################################################ ############################################################################
# Doc embeddings query # Doc embeddings query
@ -62,10 +54,3 @@ class DocumentEmbeddingsResponse(Record):
error = Error() error = Error()
documents = Array(Bytes()) documents = Array(Bytes())
document_embeddings_request_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='request'
)
document_embeddings_response_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='response',
)

View file

@ -0,0 +1,66 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
from . topic import topic
from . types import Error
############################################################################
# Flow service:
# list_classes() -> (classname[])
# get_class(classname) -> (class)
# put_class(class) -> (class)
# delete_class(classname) -> ()
#
# list_flows() -> (flowid[])
# get_flow(flowid) -> (flow)
# start_flow(flowid, classname) -> ()
# stop_flow(flowid) -> ()
# Prompt services, abstract the prompt generation
class FlowRequest(Record):
operation = String() # list_classes, get_class, put_class, delete_class
# list_flows, get_flow, start_flow, stop_flow
# get_class, put_class, delete_class, start_flow
class_name = String()
# put_class
class = String()
# start_flow
description = String()
# get_flow, start_flow, stop_flow
flow_id = String()
class FlowResponse(Record):
# list_classes
class_names = Array(String())
# list_flows
flow_ids = Array(String())
# get_class
class = String()
# get_flow
flow = String()
# get_flow
description = String()
# Everything
error = Error()
flow_request_queue = topic(
'flow', kind='non-persistent', namespace='request'
)
flow_response_queue = topic(
'flow', kind='non-persistent', namespace='response'
)
############################################################################

View file

@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -4,89 +4,37 @@ Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector. Input is text, output is embeddings vector.
""" """
from ... base import EmbeddingsService
from langchain_huggingface import HuggingFaceEmbeddings from langchain_huggingface import HuggingFaceEmbeddings
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error default_ident = "embeddings"
from trustgraph.schema import embeddings_request_queue
from trustgraph.schema import embeddings_response_queue
from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="all-MiniLM-L6-v2" default_model="all-MiniLM-L6-v2"
class Processor(ConsumerProducer): class Processor(EmbeddingsService):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model) model = params.get("model", default_model)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | { "model": model }
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
}
) )
print("Get model...", flush=True)
self.embeddings = HuggingFaceEmbeddings(model_name=model) self.embeddings = HuggingFaceEmbeddings(model_name=model)
async def handle(self, msg): async def on_embeddings(self, text):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
try:
text = v.text
embeds = self.embeddings.embed_documents([text])
print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=embeds, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = EmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
embeds = self.embeddings.embed_documents([text])
print("Done.", flush=True)
return embeds
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( EmbeddingsService.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument( parser.add_argument(
'-m', '--model', '-m', '--model',
@ -96,5 +44,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -8,12 +8,11 @@ logger = logging.getLogger(__name__)
class AgentManager: class AgentManager:
def __init__(self, context, tools, additional_context=None): def __init__(self, tools, additional_context=None):
self.context = context
self.tools = tools self.tools = tools
self.additional_context = additional_context self.additional_context = additional_context
def reason(self, question, history): async def reason(self, question, history, context):
tools = self.tools tools = self.tools
@ -56,10 +55,7 @@ class AgentManager:
logger.info(f"prompt: {variables}") logger.info(f"prompt: {variables}")
obj = self.context.prompt.request( obj = await context("prompt-request").agent_react(variables)
"agent-react",
variables
)
print(json.dumps(obj, indent=4), flush=True) print(json.dumps(obj, indent=4), flush=True)
@ -85,9 +81,13 @@ class AgentManager:
return a return a
async def react(self, question, history, think, observe): async def react(self, question, history, think, observe, context):
act = self.reason(question, history) act = await self.reason(
question = question,
history = history,
context = context,
)
logger.info(f"act: {act}") logger.info(f"act: {act}")
if isinstance(act, Final): if isinstance(act, Final):
@ -104,7 +104,12 @@ class AgentManager:
else: else:
raise RuntimeError(f"No action for {act.name}!") raise RuntimeError(f"No action for {act.name}!")
resp = action.implementation.invoke(**act.arguments) print("TOOL>>>", act)
resp = await action.implementation(context).invoke(
**act.arguments
)
print("RSETUL", resp)
resp = resp.strip() resp = resp.strip()

View file

@ -6,103 +6,68 @@ import json
import re import re
import sys import sys
from pulsar.schema import JsonSchema from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec
from ... base import ConsumerProducer from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... schema import Error
from ... schema import AgentRequest, AgentResponse, AgentStep
from ... schema import agent_request_queue, agent_response_queue
from ... schema import prompt_request_queue as pr_request_queue
from ... schema import prompt_response_queue as pr_response_queue
from ... schema import graph_rag_request_queue as gr_request_queue
from ... schema import graph_rag_response_queue as gr_response_queue
from ... clients.prompt_client import PromptClient
from ... clients.llm_client import LlmClient
from ... clients.graph_rag_client import GraphRagClient
from . tools import KnowledgeQueryImpl, TextCompletionImpl from . tools import KnowledgeQueryImpl, TextCompletionImpl
from . agent_manager import AgentManager from . agent_manager import AgentManager
from . types import Final, Action, Tool, Argument from . types import Final, Action, Tool, Argument
module = ".".join(__name__.split(".")[1:-1]) default_ident = "agent-manager"
default_max_iterations = 10
default_input_queue = agent_request_queue class Processor(AgentService):
default_output_queue = agent_response_queue
default_subscriber = module
default_max_iterations = 15
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
id = params.get("id")
self.max_iterations = int( self.max_iterations = int(
params.get("max_iterations", default_max_iterations) params.get("max_iterations", default_max_iterations)
) )
tools = {}
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
prompt_request_queue = params.get(
"prompt_request_queue", pr_request_queue
)
prompt_response_queue = params.get(
"prompt_response_queue", pr_response_queue
)
graph_rag_request_queue = params.get(
"graph_rag_request_queue", gr_request_queue
)
graph_rag_response_queue = params.get(
"graph_rag_response_queue", gr_response_queue
)
self.config_key = params.get("config_type", "agent") self.config_key = params.get("config_type", "agent")
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue, "max_iterations": self.max_iterations,
"subscriber": subscriber, "config_type": self.config_key,
"input_schema": AgentRequest,
"output_schema": AgentResponse,
"prompt_request_queue": prompt_request_queue,
"prompt_response_queue": prompt_response_queue,
"graph_rag_request_queue": gr_request_queue,
"graph_rag_response_queue": gr_response_queue,
} }
) )
self.prompt = PromptClient(
subscriber=subscriber,
input_queue=prompt_request_queue,
output_queue=prompt_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
self.graph_rag = GraphRagClient(
subscriber=subscriber,
input_queue=graph_rag_request_queue,
output_queue=graph_rag_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
# Need to be able to feed requests to myself
self.recursive_input = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(AgentRequest),
)
self.agent = AgentManager( self.agent = AgentManager(
context=self,
tools=[], tools=[],
additional_context="", additional_context="",
) )
async def on_config(self, version, config): self.config_handlers.append(self.on_tools_config)
self.register_specification(
TextCompletionClientSpec(
request_name = "text-completion-request",
response_name = "text-completion-response",
)
)
self.register_specification(
GraphRagClientSpec(
request_name = "graph-rag-request",
response_name = "graph-rag-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
async def on_tools_config(self, config, version):
print("Loading configuration version", version) print("Loading configuration version", version)
@ -138,9 +103,9 @@ class Processor(ConsumerProducer):
impl_id = data.get("type") impl_id = data.get("type")
if impl_id == "knowledge-query": if impl_id == "knowledge-query":
impl = KnowledgeQueryImpl(self) impl = KnowledgeQueryImpl
elif impl_id == "text-completion": elif impl_id == "text-completion":
impl = TextCompletionImpl(self) impl = TextCompletionImpl
else: else:
raise RuntimeError( raise RuntimeError(
f"Tool-kind {impl_id} not known" f"Tool-kind {impl_id} not known"
@ -155,7 +120,6 @@ class Processor(ConsumerProducer):
) )
self.agent = AgentManager( self.agent = AgentManager(
context=self,
tools=tools, tools=tools,
additional_context=additional additional_context=additional
) )
@ -164,19 +128,14 @@ class Processor(ConsumerProducer):
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("on_tools_config Exception:", e, flush=True)
print("Configuration reload failed", flush=True) print("Configuration reload failed", flush=True)
async def handle(self, msg): async def agent_request(self, request, respond, next, flow):
try: try:
v = msg.value() if request.history:
# Sender-produced ID
id = msg.properties()["id"]
if v.history:
history = [ history = [
Action( Action(
thought=h.thought, thought=h.thought,
@ -184,12 +143,12 @@ class Processor(ConsumerProducer):
arguments=h.arguments, arguments=h.arguments,
observation=h.observation observation=h.observation
) )
for h in v.history for h in request.history
] ]
else: else:
history = [] history = []
print(f"Question: {v.question}", flush=True) print(f"Question: {request.question}", flush=True)
if len(history) >= self.max_iterations: if len(history) >= self.max_iterations:
raise RuntimeError("Too many agent iterations") raise RuntimeError("Too many agent iterations")
@ -207,7 +166,7 @@ class Processor(ConsumerProducer):
observation=None, observation=None,
) )
await self.send(r, properties={"id": id}) await respond(r)
async def observe(x): async def observe(x):
@ -220,15 +179,21 @@ class Processor(ConsumerProducer):
observation=x, observation=x,
) )
await self.send(r, properties={"id": id}) await respond(r)
act = await self.agent.react(v.question, history, think, observe) act = await self.agent.react(
question = request.question,
history = history,
think = think,
observe = observe,
context = flow,
)
print(f"Action: {act}", flush=True) print(f"Action: {act}", flush=True)
print("Send response...", flush=True) if isinstance(act, Final):
if type(act) == Final: print("Send final response...", flush=True)
r = AgentResponse( r = AgentResponse(
answer=act.final, answer=act.final,
@ -236,18 +201,20 @@ class Processor(ConsumerProducer):
thought=None, thought=None,
) )
await self.send(r, properties={"id": id}) await respond(r)
print("Done.", flush=True) print("Done.", flush=True)
return return
print("Send next...", flush=True)
history.append(act) history.append(act)
r = AgentRequest( r = AgentRequest(
question=v.question, question=request.question,
plan=v.plan, plan=request.plan,
state=v.state, state=request.state,
history=[ history=[
AgentStep( AgentStep(
thought=h.thought, thought=h.thought,
@ -259,7 +226,7 @@ class Processor(ConsumerProducer):
] ]
) )
self.recursive_input.send(r, properties={"id": id}) await next(r)
print("Done.", flush=True) print("Done.", flush=True)
@ -267,7 +234,7 @@ class Processor(ConsumerProducer):
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"agent_request Exception: {e}")
print("Send error response...", flush=True) print("Send error response...", flush=True)
@ -279,39 +246,12 @@ class Processor(ConsumerProducer):
response=None, response=None,
) )
await self.send(r, properties={"id": id}) await respond(r)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( AgentService.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=pr_request_queue,
help=f'Prompt request queue (default: {pr_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=pr_response_queue,
help=f'Prompt response queue (default: {pr_response_queue})',
)
parser.add_argument(
'--graph-rag-request-queue',
default=gr_request_queue,
help=f'Graph RAG request queue (default: {gr_request_queue})',
)
parser.add_argument(
'--graph-rag-response-queue',
default=gr_response_queue,
help=f'Graph RAG response queue (default: {gr_response_queue})',
)
parser.add_argument( parser.add_argument(
'--max-iterations', '--max-iterations',
@ -327,5 +267,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -4,16 +4,22 @@
class KnowledgeQueryImpl: class KnowledgeQueryImpl:
def __init__(self, context): def __init__(self, context):
self.context = context self.context = context
def invoke(self, **arguments): async def invoke(self, **arguments):
return self.context.graph_rag.request(arguments.get("question")) client = self.context("graph-rag-request")
print("Graph RAG question...", flush=True)
return await client.rag(
arguments.get("question")
)
# This tool implementation knows how to do text completion. This uses # This tool implementation knows how to do text completion. This uses
# the prompt service, rather than talking to TextCompletion directly. # the prompt service, rather than talking to TextCompletion directly.
class TextCompletionImpl: class TextCompletionImpl:
def __init__(self, context): def __init__(self, context):
self.context = context self.context = context
def invoke(self, **arguments): async def invoke(self, **arguments):
return self.context.prompt.request( client = self.context("prompt-request")
"question", { "question": arguments.get("question") } print("Prompt question...", flush=True)
return await client.question(
arguments.get("question")
) )

View file

@ -7,40 +7,27 @@ as text as separate output objects.
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Metadata from ... schema import TextDocument, Chunk
from ... schema import text_ingest_queue, chunk_ingest_queue from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) default_ident = "chunker"
default_input_queue = text_ingest_queue class Processor(FlowProcessor):
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id", default_ident)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 2000) chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100) chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | { "id": id }
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
) )
if not hasattr(__class__, "chunk_metric"): if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram( __class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size', 'chunk_size', 'Chunk size',
["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600, buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000] 2500, 4000, 6400, 10000, 16000]
) )
@ -52,7 +39,24 @@ class Processor(ConsumerProducer):
is_separator_regex=False, is_separator_regex=False,
) )
async def handle(self, msg): self.register_specification(
ConsumerSpec(
name = "input",
schema = TextDocument,
handler = self.on_message,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = Chunk,
)
)
print("Chunker initialised", flush=True)
async def on_message(self, msg, consumer, flow):
v = msg.value() v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True) print(f"Chunking {v.metadata.id}...", flush=True)
@ -63,24 +67,25 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts): for ix, chunk in enumerate(texts):
print("Chunk", len(chunk.page_content), flush=True)
r = Chunk( r = Chunk(
metadata=v.metadata, metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"), chunk=chunk.page_content.encode("utf-8"),
) )
__class__.chunk_metric.observe(len(chunk.page_content)) __class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
await self.send(r) await flow("output").send(r)
print("Done.", flush=True) print("Done.", flush=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument( parser.add_argument(
'-z', '--chunk-size', '-z', '--chunk-size',
@ -98,5 +103,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -7,40 +7,27 @@ as text as separate output objects.
from langchain_text_splitters import TokenTextSplitter from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Metadata from ... schema import TextDocument, Chunk
from ... schema import text_ingest_queue, chunk_ingest_queue from ... base import FlowProcessor
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) default_ident = "chunker"
default_input_queue = text_ingest_queue class Processor(FlowProcessor):
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 250) chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15) chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | { "id": id }
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
) )
if not hasattr(__class__, "chunk_metric"): if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram( __class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size', 'chunk_size', 'Chunk size',
["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600, buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000] 2500, 4000, 6400, 10000, 16000]
) )
@ -51,7 +38,24 @@ class Processor(ConsumerProducer):
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
) )
async def handle(self, msg): self.register_specification(
ConsumerSpec(
name = "input",
schema = TextDocument,
handler = self.on_message,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = Chunk,
)
)
print("Chunker initialised", flush=True)
async def on_message(self, msg, consumer, flow):
v = msg.value() v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True) print(f"Chunking {v.metadata.id}...", flush=True)
@ -62,24 +66,25 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts): for ix, chunk in enumerate(texts):
print("Chunk", len(chunk.page_content), flush=True)
r = Chunk( r = Chunk(
metadata=v.metadata, metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"), chunk=chunk.page_content.encode("utf-8"),
) )
__class__.chunk_metric.observe(len(chunk.page_content)) __class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
await self.send(r) await flow("output").send(r)
print("Done.", flush=True) print("Done.", flush=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument( parser.add_argument(
'-z', '--chunk-size', '-z', '--chunk-size',
@ -97,5 +102,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,215 @@
from trustgraph.schema import ConfigResponse
from trustgraph.schema import ConfigValue, Error
# This behaves just like a dict, should be easier to add persistent storage
# later
class ConfigurationItems(dict):
pass
class Configuration(dict):
# FIXME: The state is held internally. This only works if there's
# one config service. Should be more than one, and use a
# back-end state store.
def __init__(self, push):
# Version counter
self.version = 0
# External function to respond to update
self.push = push
def __getitem__(self, key):
if key not in self:
self[key] = ConfigurationItems()
return dict.__getitem__(self, key)
async def handle_get(self, v):
for k in v.keys:
if k.type not in self or k.key not in self[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = k.type,
key = k.key,
value = self[k.type][k.key]
)
for k in v.keys
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_list(self, v):
if v.type not in self:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = "No such type",
),
)
return ConfigResponse(
version = self.version,
values = None,
directory = list(self[v.type].keys()),
config = None,
error = None,
)
async def handle_getvalues(self, v):
if v.type not in self:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = v.type,
key = k,
value = self[v.type][k],
)
for k in self[v.type]
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_delete(self, v):
for k in v.keys:
if k.type not in self or k.key not in self[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
for k in v.keys:
del self[k.type][k.key]
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
config = None,
error = None,
)
async def handle_put(self, v):
for k in v.values:
self[k.type][k.key] = k.value
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
error = None,
)
async def handle_config(self, v):
return ConfigResponse(
version = self.version,
value = None,
directory = None,
values = None,
config = self,
error = None,
)
async def handle(self, msg):
print("Handle message ", msg.operation)
if msg.operation == "get":
resp = await self.handle_get(msg)
elif msg.operation == "list":
resp = await self.handle_list(msg)
elif msg.operation == "getvalues":
resp = await self.handle_getvalues(msg)
elif msg.operation == "delete":
resp = await self.handle_delete(msg)
elif msg.operation == "put":
resp = await self.handle_put(msg)
elif msg.operation == "config":
resp = await self.handle_config(msg)
else:
resp = ConfigResponse(
value=None,
directory=None,
values=None,
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
return resp

View file

@ -1,284 +1,115 @@
""" """
Config service. Fetchs an extract from the Wikipedia page Config service. Manages system global configuration state
using the API.
""" """
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
from trustgraph.schema import ConfigValue, Error from trustgraph.schema import Error
from trustgraph.schema import config_request_queue, config_response_queue from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.schema import config_push_queue from trustgraph.schema import config_push_queue
from trustgraph.log_level import LogLevel from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer from trustgraph.base import AsyncProcessor, Consumer, Producer
module = ".".join(__name__.split(".")[1:-1]) from . config import Configuration
from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from ... base import Consumer, Producer
default_input_queue = config_request_queue default_ident = "config-svc"
default_output_queue = config_response_queue
default_request_queue = config_request_queue
default_response_queue = config_response_queue
default_push_queue = config_push_queue default_push_queue = config_push_queue
default_subscriber = module
# This behaves just like a dict, should be easier to add persistent storage class Processor(AsyncProcessor):
# later
class ConfigurationItems(dict):
pass
class Configuration(dict):
def __getitem__(self, key):
if key not in self:
self[key] = ConfigurationItems()
return dict.__getitem__(self, key)
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) request_queue = params.get("request_queue", default_request_queue)
output_queue = params.get("output_queue", default_output_queue) response_queue = params.get("response_queue", default_response_queue)
push_queue = params.get("push_queue", default_push_queue) push_queue = params.get("push_queue", default_push_queue)
subscriber = params.get("subscriber", default_subscriber) id = params.get("id")
request_schema = ConfigRequest
response_schema = ConfigResponse
push_schema = ConfigResponse
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "request_schema": request_schema.__name__,
"output_queue": output_queue, "response_schema": response_schema.__name__,
"push_queue": output_queue, "push_schema": push_schema.__name__,
"subscriber": subscriber,
"input_schema": ConfigRequest,
"output_schema": ConfigResponse,
"push_schema": ConfigPush,
} }
) )
self.push_prod = self.client.create_producer( request_metrics = ConsumerMetrics(id + "-request")
topic=push_queue, response_metrics = ProducerMetrics(id + "-response")
schema=JsonSchema(ConfigPush), push_metrics = ProducerMetrics(id + "-push")
self.push_pub = Producer(
client = self.client,
topic = push_queue,
schema = ConfigPush,
metrics = push_metrics,
) )
# FIXME: The state is held internally. This only works if there's self.response_pub = Producer(
# one config service. Should be more than one, and use a client = self.client,
# back-end state store. topic = response_queue,
self.config = Configuration() schema = ConfigResponse,
metrics = response_metrics,
)
# Version counter self.subs = Consumer(
self.version = 0 taskgroup = self.taskgroup,
client = self.client,
flow = None,
topic = request_queue,
subscriber = id,
schema = request_schema,
handler = self.on_message,
metrics = request_metrics,
)
self.config = Configuration(self.push)
print("Service initialised.")
async def start(self): async def start(self):
await self.push()
async def handle_get(self, v, id):
for k in v.keys:
if k.type not in self.config or k.key not in self.config[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = k.type,
key = k.key,
value = self.config[k.type][k.key]
)
for k in v.keys
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_list(self, v, id):
if v.type not in self.config:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = "No such type",
),
)
return ConfigResponse(
version = self.version,
values = None,
directory = list(self.config[v.type].keys()),
config = None,
error = None,
)
async def handle_getvalues(self, v, id):
if v.type not in self.config:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = v.type,
key = k,
value = self.config[v.type][k],
)
for k in self.config[v.type]
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_delete(self, v, id):
for k in v.keys:
if k.type not in self.config or k.key not in self.config[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
for k in v.keys:
del self.config[k.type][k.key]
self.version += 1
await self.push() await self.push()
await self.subs.start()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
config = None,
error = None,
)
async def handle_put(self, v, id):
for k in v.values:
self.config[k.type][k.key] = k.value
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
error = None,
)
async def handle_config(self, v, id):
return ConfigResponse(
version = self.version,
value = None,
directory = None,
values = None,
config = self.config,
error = None,
)
async def push(self): async def push(self):
resp = ConfigPush( resp = ConfigPush(
version = self.version, version = self.config.version,
value = None, value = None,
directory = None, directory = None,
values = None, values = None,
config = self.config, config = self.config,
error = None, error = None,
) )
self.push_prod.send(resp)
print("Pushed.")
async def handle(self, msg): await self.push_pub.send(resp)
v = msg.value() print("Pushed version ", self.config.version)
# Sender-produced ID async def on_message(self, msg, consumer, flow):
id = msg.properties()["id"]
print(f"Handling {id}...", flush=True)
try: try:
if v.operation == "get": v = msg.value()
resp = await self.handle_get(v, id) # Sender-produced ID
id = msg.properties()["id"]
elif v.operation == "list": print(f"Handling {id}...", flush=True)
resp = await self.handle_list(v, id) resp = await self.config.handle(v)
elif v.operation == "getvalues": await self.response_pub.send(resp, properties={"id": id})
resp = await self.handle_getvalues(v, id)
elif v.operation == "delete":
resp = await self.handle_delete(v, id)
elif v.operation == "put":
resp = await self.handle_put(v, id)
elif v.operation == "config":
resp = await self.handle_config(v, id)
else:
resp = ConfigResponse(
value=None,
directory=None,
values=None,
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
await self.send(resp, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e: except Exception as e:
@ -289,24 +120,33 @@ class Processor(ConsumerProducer):
), ),
text=None, text=None,
) )
await self.send(resp, properties={"id": id})
self.consumer.acknowledge(msg) await self.response_pub.send(resp, properties={"id": id})
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( AsyncProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue, parser.add_argument(
'-q', '--request-queue',
default=default_request_queue,
help=f'Request queue (default: {default_request_queue})'
) )
parser.add_argument( parser.add_argument(
'-q', '--push-queue', '-r', '--response-queue',
default=default_response_queue,
help=f'Response queue {default_response_queue}',
)
parser.add_argument(
'--push-queue',
default=default_push_queue, default=default_push_queue,
help=f'Config push queue (default: {default_push_queue})' help=f'Config push queue (default: {default_push_queue})'
) )
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -17,12 +17,10 @@ from mistralai.models import OCRResponse
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel from ... log_level import LogLevel
from ... base import ConsumerProducer from ... base import InputOutputProcessor
module = ".".join(__name__.split(".")[1:-1]) module = "ocr"
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module default_subscriber = module
default_api_key = os.getenv("MISTRAL_TOKEN") default_api_key = os.getenv("MISTRAL_TOKEN")
@ -71,19 +69,17 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str:
return "\n\n".join(markdowns) return "\n\n".join(markdowns)
class Processor(ConsumerProducer): class Processor(InputOutputProcessor):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber) subscriber = params.get("subscriber", default_subscriber)
api_key = params.get("api_key", default_api_key) api_key = params.get("api_key", default_api_key)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": Document, "input_schema": Document,
"output_schema": TextDocument, "output_schema": TextDocument,
@ -151,7 +147,7 @@ class Processor(ConsumerProducer):
return markdown return markdown
async def handle(self, msg): async def on_message(self, msg, consumer):
print("PDF message received") print("PDF message received")
@ -166,17 +162,14 @@ class Processor(ConsumerProducer):
text=markdown.encode("utf-8"), text=markdown.encode("utf-8"),
) )
await self.send(r) await consumer.q.output.send(r)
print("Done.", flush=True) print("Done.", flush=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( InputOutputProcessor.add_args(parser, default_subscriber)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument( parser.add_argument(
'-k', '--api-key', '-k', '--api-key',

View file

@ -9,39 +9,43 @@ import base64
from langchain_community.document_loaders import PyPDFLoader from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel from ... log_level import LogLevel
from ... base import ConsumerProducer from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
module = ".".join(__name__.split(".")[1:-1]) default_ident = "pdf-decoder"
default_input_queue = document_ingest_queue class Processor(FlowProcessor):
default_output_queue = text_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id", default_ident)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
} }
) )
print("PDF inited") self.register_specification(
ConsumerSpec(
name = "input",
schema = Document,
handler = self.on_message,
)
)
async def handle(self, msg): self.register_specification(
ProducerSpec(
name = "output",
schema = TextDocument,
)
)
print("PDF message received") print("PDF inited", flush=True)
async def on_message(self, msg, consumer, flow):
print("PDF message received", flush=True)
v = msg.value() v = msg.value()
@ -59,24 +63,22 @@ class Processor(ConsumerProducer):
for ix, page in enumerate(pages): for ix, page in enumerate(pages):
print("page", ix, flush=True)
r = TextDocument( r = TextDocument(
metadata=v.metadata, metadata=v.metadata,
text=page.page_content.encode("utf-8"), text=page.page_content.encode("utf-8"),
) )
await self.send(r) await flow("output").send(r)
print("Done.", flush=True) print("Done.", flush=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -1,153 +0,0 @@
from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.doc_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
de_request_queue=None,
de_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if de_request_queue is None:
de_request_queue = document_embeddings_request_queue
if de_response_queue is None:
de_response_queue = document_embeddings_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
pulsar_api_key=pulsar_api_key,
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
pulsar_api_key=pulsar_api_key,
)
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
pulsar_api_key=pulsar_api_key,
)
if self.verbose:
print("Initialised", flush=True)
def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = q.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = self.lang.request_document_prompt(query, docs)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -6,61 +6,63 @@ Output is chunk plus embedding.
""" """
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
from ... schema import chunk_ingest_queue from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import document_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec
from ... base import ProducerSpec
default_input_queue = chunk_ingest_queue default_ident = "document-embeddings"
default_output_queue = document_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer): class Processor(FlowProcessor):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": DocumentEmbeddings,
} }
) )
self.embeddings = EmbeddingsClient( self.register_specification(
pulsar_host=self.pulsar_host, ConsumerSpec(
pulsar_api_key=self.pulsar_api_key, name = "input",
input_queue=emb_request_queue, schema = Chunk,
output_queue=emb_response_queue, handler = self.on_message,
subscriber=module + "-emb", )
) )
async def handle(self, msg): self.register_specification(
RequestResponseSpec(
request_name = "embeddings-request",
request_schema = EmbeddingsRequest,
response_name = "embeddings-response",
response_schema = EmbeddingsResponse,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = DocumentEmbeddings
)
)
async def on_message(self, msg, consumer, flow):
v = msg.value() v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True) print(f"Indexing {v.metadata.id}...", flush=True)
try: try:
vectors = self.embeddings.request(v.chunk) resp = await flow("embeddings-request").request(
EmbeddingsRequest(
text = v.chunk
)
)
vectors = resp.vectors
embeds = [ embeds = [
ChunkEmbeddings( ChunkEmbeddings(
@ -74,7 +76,7 @@ class Processor(ConsumerProducer):
chunks=embeds, chunks=embeds,
) )
await self.send(r) await flow("output").send(r)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
@ -87,24 +89,9 @@ class Processor(ConsumerProducer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
)
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -1,81 +1,43 @@
""" """
Embeddings service, applies an embeddings model selected from HuggingFace. Embeddings service, applies an embeddings model using fastembed
Input is text, output is embeddings vector. Input is text, output is embeddings vector.
""" """
from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... base import EmbeddingsService
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from fastembed import TextEmbedding from fastembed import TextEmbedding
import os
module = ".".join(__name__.split(".")[1:-1]) default_ident = "embeddings"
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="sentence-transformers/all-MiniLM-L6-v2" default_model="sentence-transformers/all-MiniLM-L6-v2"
class Processor(ConsumerProducer): class Processor(EmbeddingsService):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model) model = params.get("model", default_model)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | { "model": model }
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
"model": model,
}
) )
print("Get model...", flush=True)
self.embeddings = TextEmbedding(model_name = model) self.embeddings = TextEmbedding(model_name = model)
async def handle(self, msg): async def on_embeddings(self, text):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
text = v.text
vecs = self.embeddings.embed([text]) vecs = self.embeddings.embed([text])
vecs = [ return [
v.tolist() v.tolist()
for v in vecs for v in vecs
] ]
print("Send response...", flush=True)
r = EmbeddingsResponse(
vectors=list(vecs),
error=None,
)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( EmbeddingsService.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument( parser.add_argument(
'-m', '--model', '-m', '--model',
@ -85,5 +47,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -6,53 +6,48 @@ Output is entity plus embedding.
""" """
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
from ... schema import entity_contexts_ingest_queue from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import graph_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec
default_input_queue = entity_contexts_ingest_queue default_ident = "graph-embeddings"
default_output_queue = graph_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer): class Processor(FlowProcessor):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": EntityContexts,
"output_schema": GraphEmbeddings,
} }
) )
self.embeddings = EmbeddingsClient( self.register_specification(
pulsar_host=self.pulsar_host, ConsumerSpec(
input_queue=emb_request_queue, name = "input",
output_queue=emb_response_queue, schema = EntityContexts,
subscriber=module + "-emb", handler = self.on_message,
)
) )
async def handle(self, msg): self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = GraphEmbeddings
)
)
async def on_message(self, msg, consumer, flow):
v = msg.value() v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True) print(f"Indexing {v.metadata.id}...", flush=True)
@ -63,7 +58,9 @@ class Processor(ConsumerProducer):
for entity in v.entities: for entity in v.entities:
vectors = self.embeddings.request(entity.context) vectors = await flow("embeddings-request").embed(
text = entity.context
)
entities.append( entities.append(
EntityEmbeddings( EntityEmbeddings(
@ -77,7 +74,7 @@ class Processor(ConsumerProducer):
entities=entities, entities=entities,
) )
await self.send(r) await flow("output").send(r)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
@ -90,24 +87,9 @@ class Processor(ConsumerProducer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
)
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -11,7 +11,7 @@ from ... base import ConsumerProducer
from ollama import Client from ollama import Client
import os import os
module = ".".join(__name__.split(".")[1:-1]) module = "embeddings"
default_input_queue = embeddings_request_queue default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue default_output_queue = embeddings_response_queue

View file

@ -11,7 +11,7 @@ from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer from trustgraph.base import ConsumerProducer
import requests import requests
module = ".".join(__name__.split(".")[1:-1]) module = "wikipedia"
default_input_queue = encyclopedia_lookup_request_queue default_input_queue = encyclopedia_lookup_request_queue
default_output_queue = encyclopedia_lookup_response_queue default_output_queue = encyclopedia_lookup_response_queue

View file

@ -5,84 +5,62 @@ get entity definitions which are output as graph edges along with
entity/context definitions for embedding. entity/context definitions for embedding.
""" """
import json
import urllib.parse import urllib.parse
from pulsar.schema import JsonSchema
from .... schema import Chunk, Triple, Triples, Metadata, Value from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts from .... schema import EntityContext, EntityContexts
from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import PromptRequest, PromptResponse
from .... schema import entity_contexts_ingest_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from .... base import ConsumerProducer
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1]) default_ident = "kg-extract-definitions"
default_input_queue = chunk_ingest_queue class Processor(FlowProcessor):
default_output_queue = triples_store_queue
default_entity_context_queue = entity_contexts_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
ec_queue = params.get(
"entity_context_queue",
default_entity_context_queue
)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
} }
) )
self.ec_prod = self.client.create_producer( self.register_specification(
topic=ec_queue, ConsumerSpec(
schema=JsonSchema(EntityContexts), name = "input",
schema = Chunk,
handler = self.on_message
)
) )
__class__.pubsub_metric.info({ self.register_specification(
"input_queue": input_queue, PromptClientSpec(
"output_queue": output_queue, request_name = "prompt-request",
"entity_context_queue": ec_queue, response_name = "prompt-response",
"prompt_request_queue": pr_request_queue, )
"prompt_response_queue": pr_response_queue, )
"subscriber": subscriber,
"input_schema": Chunk.__name__,
"output_schema": Triples.__name__,
"vector_schema": EntityContexts.__name__,
})
self.prompt = PromptClient( self.register_specification(
pulsar_host=self.pulsar_host, ProducerSpec(
pulsar_api_key=self.pulsar_api_key, name = "triples",
input_queue=pr_request_queue, schema = Triples
output_queue=pr_response_queue, )
subscriber = module + "-prompt", )
self.register_specification(
ProducerSpec(
name = "entity-contexts",
schema = EntityContexts
)
) )
def to_uri(self, text): def to_uri(self, text):
@ -93,36 +71,47 @@ class Processor(ConsumerProducer):
return uri return uri
def get_definitions(self, chunk): async def emit_triples(self, pub, metadata, triples):
return self.prompt.request_definitions(chunk)
async def emit_edges(self, metadata, triples):
t = Triples( t = Triples(
metadata=metadata, metadata=metadata,
triples=triples, triples=triples,
) )
await self.send(t) await pub.send(t)
async def emit_ecs(self, metadata, entities): async def emit_ecs(self, pub, metadata, entities):
t = EntityContexts( t = EntityContexts(
metadata=metadata, metadata=metadata,
entities=entities, entities=entities,
) )
self.ec_prod.send(t) await pub.send(t)
async def handle(self, msg): async def on_message(self, msg, consumer, flow):
v = msg.value() v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True) print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8") chunk = v.chunk.decode("utf-8")
print(chunk, flush=True)
try: try:
defs = self.get_definitions(chunk) try:
defs = await flow("prompt-request").extract_definitions(
text = chunk
)
print("Response", defs, flush=True)
if type(defs) != list:
raise RuntimeError("Expecting array in prompt response")
except Exception as e:
print("Prompt exception:", e, flush=True)
raise e
triples = [] triples = []
entities = [] entities = []
@ -134,8 +123,8 @@ class Processor(ConsumerProducer):
for defn in defs: for defn in defs:
s = defn.name s = defn["entity"]
o = defn.definition o = defn["definition"]
if s == "": continue if s == "": continue
if o == "": continue if o == "": continue
@ -166,13 +155,13 @@ class Processor(ConsumerProducer):
ec = EntityContext( ec = EntityContext(
entity=s_value, entity=s_value,
context=defn.definition, context=defn["definition"],
) )
entities.append(ec) entities.append(ec)
await self.emit_triples(
await self.emit_edges( flow("triples"),
Metadata( Metadata(
id=v.metadata.id, id=v.metadata.id,
metadata=[], metadata=[],
@ -183,6 +172,7 @@ class Processor(ConsumerProducer):
) )
await self.emit_ecs( await self.emit_ecs(
flow("entity-contexts"),
Metadata( Metadata(
id=v.metadata.id, id=v.metadata.id,
metadata=[], metadata=[],
@ -200,30 +190,9 @@ class Processor(ConsumerProducer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-e', '--entity-context-queue',
default=default_entity_context_queue,
help=f'Entity context queue (default: {default_entity_context_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-completion-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -5,59 +5,54 @@ relationship analysis to get entity relationship edges which are output as
graph edges. graph edges.
""" """
import json
import urllib.parse import urllib.parse
from .... schema import Chunk, Triple, Triples from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value from .... schema import Metadata, Value
from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import PromptRequest, PromptResponse
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
from .... base import ConsumerProducer
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1]) default_ident = "kg-extract-relationships"
default_input_queue = chunk_ingest_queue class Processor(FlowProcessor):
default_output_queue = triples_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
} }
) )
self.prompt = PromptClient( self.register_specification(
pulsar_host=self.pulsar_host, ConsumerSpec(
pulsar_api_key=self.pulsar_api_key, name = "input",
input_queue=pr_request_queue, schema = Chunk,
output_queue=pr_response_queue, handler = self.on_message
subscriber = module + "-prompt", )
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples
)
) )
def to_uri(self, text): def to_uri(self, text):
@ -68,28 +63,39 @@ class Processor(ConsumerProducer):
return uri return uri
def get_relationships(self, chunk): async def emit_triples(self, pub, metadata, triples):
return self.prompt.request_relationships(chunk)
async def emit_edges(self, metadata, triples):
t = Triples( t = Triples(
metadata=metadata, metadata=metadata,
triples=triples, triples=triples,
) )
await self.send(t) await pub.send(t)
async def handle(self, msg): async def on_message(self, msg, consumer, flow):
v = msg.value() v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True) print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8") chunk = v.chunk.decode("utf-8")
print(chunk, flush=True)
try: try:
rels = self.get_relationships(chunk) try:
rels = await flow("prompt-request").extract_relationships(
text = chunk
)
print("Response", rels, flush=True)
if type(rels) != list:
raise RuntimeError("Expecting array in prompt response")
except Exception as e:
print("Prompt exception:", e, flush=True)
raise e
triples = [] triples = []
@ -100,9 +106,9 @@ class Processor(ConsumerProducer):
for rel in rels: for rel in rels:
s = rel.s s = rel["subject"]
p = rel.p p = rel["predicate"]
o = rel.o o = rel["object"]
if s == "": continue if s == "": continue
if p == "": continue if p == "": continue
@ -118,7 +124,7 @@ class Processor(ConsumerProducer):
p_uri = self.to_uri(p) p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True) p_value = Value(value=str(p_uri), is_uri=True)
if rel.o_entity: if rel["object-entity"]:
o_uri = self.to_uri(o) o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True) o_value = Value(value=str(o_uri), is_uri=True)
else: else:
@ -144,7 +150,7 @@ class Processor(ConsumerProducer):
o=Value(value=str(p), is_uri=False) o=Value(value=str(p), is_uri=False)
)) ))
if rel.o_entity: if rel["object-entity"]:
# Label for o # Label for o
triples.append(Triple( triples.append(Triple(
s=o_value, s=o_value,
@ -159,7 +165,7 @@ class Processor(ConsumerProducer):
o=Value(value=v.metadata.id, is_uri=True) o=Value(value=v.metadata.id, is_uri=True)
)) ))
if rel.o_entity: if rel["object-entity"]:
# 'Subject of' for o # 'Subject of' for o
triples.append(Triple( triples.append(Triple(
s=o_value, s=o_value,
@ -168,6 +174,7 @@ class Processor(ConsumerProducer):
)) ))
await self.emit_edges( await self.emit_edges(
flow("triples"),
Metadata( Metadata(
id=v.metadata.id, id=v.metadata.id,
metadata=[], metadata=[],
@ -185,24 +192,9 @@ class Processor(ConsumerProducer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -18,7 +18,7 @@ from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = ".".join(__name__.split(".")[1:-1]) module = "kg-extract-topics"
default_input_queue = chunk_ingest_queue default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue default_output_queue = triples_store_queue

View file

@ -39,4 +39,3 @@ class AgentRequestor(ServiceRequestor):
# The 2nd boolean expression indicates whether we're done responding # The 2nd boolean expression indicates whether we're done responding
return resp, (message.answer is not None) return resp, (message.answer is not None)

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
from pulsar.schema import JsonSchema
import uuid import uuid
from aiohttp import WSMsgType from aiohttp import WSMsgType
@ -26,12 +25,12 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher( self.publisher = Publisher(
self.pulsar_client, document_embeddings_store_queue, self.pulsar_client, document_embeddings_store_queue,
schema=JsonSchema(DocumentEmbeddings) schema=DocumentEmbeddings
) )
async def start(self): async def start(self):
self.publisher.start() await self.publisher.start()
async def listener(self, ws, running): async def listener(self, ws, running):
@ -59,6 +58,6 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
], ],
) )
self.publisher.send(None, elt) await self.publisher.send(None, elt)
running.stop() running.stop()

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import queue import queue
from pulsar.schema import JsonSchema
import uuid import uuid
from .. schema import DocumentEmbeddings from .. schema import DocumentEmbeddings
@ -27,7 +26,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber( self.subscriber = Subscriber(
self.pulsar_client, document_embeddings_store_queue, self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway", "api-gateway", "api-gateway",
schema=JsonSchema(DocumentEmbeddings), schema=DocumentEmbeddings,
) )
async def listener(self, ws, running): async def listener(self, ws, running):
@ -44,17 +43,17 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self): async def start(self):
self.subscriber.start() await self.subscriber.start()
async def async_thread(self, ws, running): async def async_thread(self, ws, running):
id = str(uuid.uuid4()) id = str(uuid.uuid4())
q = self.subscriber.subscribe_all(id) q = await self.subscriber.subscribe_all(id)
while running.get(): while running.get():
try: try:
resp = await asyncio.to_thread(q.get, timeout=0.5) resp = await asyncio.wait_for(q.get(), timeout=0.5)
await ws.send_json(serialize_document_embeddings(resp)) await ws.send_json(serialize_document_embeddings(resp))
except TimeoutError: except TimeoutError:
@ -67,7 +66,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True) print(f"Exception: {str(e)}", flush=True)
break break
self.subscriber.unsubscribe_all(id) await self.subscriber.unsubscribe_all(id)
running.stop() running.stop()

View file

@ -1,13 +1,9 @@
import asyncio import asyncio
from pulsar.schema import JsonSchema
from aiohttp import web from aiohttp import web
import uuid import uuid
import logging import logging
from .. base import Publisher
from .. base import Subscriber
logger = logging.getLogger("endpoint") logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
from pulsar.schema import JsonSchema
import uuid import uuid
from aiohttp import WSMsgType from aiohttp import WSMsgType
@ -26,12 +25,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher( self.publisher = Publisher(
self.pulsar_client, graph_embeddings_store_queue, self.pulsar_client, graph_embeddings_store_queue,
schema=JsonSchema(GraphEmbeddings) schema=GraphEmbeddings
) )
async def start(self): async def start(self):
self.publisher.start() await self.publisher.start()
async def listener(self, ws, running): async def listener(self, ws, running):
@ -60,6 +59,6 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
] ]
) )
self.publisher.send(None, elt) await self.publisher.send(None, elt)
running.stop() running.stop()

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import queue import queue
from pulsar.schema import JsonSchema
import uuid import uuid
from .. schema import GraphEmbeddings from .. schema import GraphEmbeddings
@ -26,7 +25,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber( self.subscriber = Subscriber(
self.pulsar_client, graph_embeddings_store_queue, self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway", "api-gateway", "api-gateway",
schema=JsonSchema(GraphEmbeddings) schema=GraphEmbeddings
) )
async def listener(self, ws, running): async def listener(self, ws, running):
@ -41,17 +40,17 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self): async def start(self):
self.subscriber.start() await self.subscriber.start()
async def async_thread(self, ws, running): async def async_thread(self, ws, running):
id = str(uuid.uuid4()) id = str(uuid.uuid4())
q = self.subscriber.subscribe_all(id) q = await self.subscriber.subscribe_all(id)
while running.get(): while running.get():
try: try:
resp = await asyncio.to_thread(q.get, timeout=0.5) resp = await asyncio.wait_for(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp)) await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError: except TimeoutError:
@ -64,7 +63,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True) print(f"Exception: {str(e)}", flush=True)
break break
self.subscriber.unsubscribe_all(id) await self.subscriber.unsubscribe_all(id)
running.stop() running.stop()

View file

@ -7,7 +7,6 @@
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
import asyncio import asyncio
from pulsar.schema import JsonSchema
import uuid import uuid
import logging import logging

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import queue import queue
from pulsar.schema import JsonSchema
import uuid import uuid
from aiohttp import web, WSMsgType from aiohttp import web, WSMsgType

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
from pulsar.schema import JsonSchema
import uuid import uuid
import logging import logging
@ -23,21 +22,21 @@ class ServiceRequestor:
self.pub = Publisher( self.pub = Publisher(
pulsar_client, request_queue, pulsar_client, request_queue,
schema=JsonSchema(request_schema), schema=request_schema,
) )
self.sub = Subscriber( self.sub = Subscriber(
pulsar_client, response_queue, pulsar_client, response_queue,
subscription, consumer_name, subscription, consumer_name,
JsonSchema(response_schema) response_schema
) )
self.timeout = timeout self.timeout = timeout
async def start(self): async def start(self):
self.pub.start() await self.pub.start()
self.sub.start() await self.sub.start()
def to_request(self, request): def to_request(self, request):
raise RuntimeError("Not defined") raise RuntimeError("Not defined")
@ -51,18 +50,15 @@ class ServiceRequestor:
try: try:
q = self.sub.subscribe(id) q = await self.sub.subscribe(id)
await asyncio.to_thread( await self.pub.send(id, self.to_request(request))
self.pub.send, id, self.to_request(request)
)
while True: while True:
try: try:
resp = await asyncio.to_thread( resp = await asyncio.wait_for(
q.get, q.get(), timeout=self.timeout
timeout=self.timeout
) )
except Exception as e: except Exception as e:
raise RuntimeError("Timeout") raise RuntimeError("Timeout")
@ -99,5 +95,5 @@ class ServiceRequestor:
return err return err
finally: finally:
self.sub.unsubscribe(id) await self.sub.unsubscribe(id)

View file

@ -2,7 +2,6 @@
# Like ServiceRequestor, but just fire-and-forget instead of request/response # Like ServiceRequestor, but just fire-and-forget instead of request/response
import asyncio import asyncio
from pulsar.schema import JsonSchema
import uuid import uuid
import logging import logging
@ -21,12 +20,12 @@ class ServiceSender:
self.pub = Publisher( self.pub = Publisher(
pulsar_client, request_queue, pulsar_client, request_queue,
schema=JsonSchema(request_schema), schema=request_schema,
) )
async def start(self): async def start(self):
self.pub.start() await self.pub.start()
def to_request(self, request): def to_request(self, request):
raise RuntimeError("Not defined") raise RuntimeError("Not defined")
@ -35,9 +34,7 @@ class ServiceSender:
try: try:
await asyncio.to_thread( await self.pub.send(None, self.to_request(request))
self.pub.send, None, self.to_request(request)
)
if responder: if responder:
await responder({}, True) await responder({}, True)

View file

@ -3,7 +3,7 @@ API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus. Pulsar bus.
""" """
module = ".".join(__name__.split(".")[1:-1]) module = "api-gateway"
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there # FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners # are active listeners
@ -19,7 +19,6 @@ import os
import base64 import base64
import pulsar import pulsar
from pulsar.schema import JsonSchema
from prometheus_client import start_http_server from prometheus_client import start_http_server
from .. log_level import LogLevel from .. log_level import LogLevel

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
from pulsar.schema import JsonSchema
import uuid import uuid
from aiohttp import WSMsgType from aiohttp import WSMsgType
@ -24,12 +23,12 @@ class TriplesLoadEndpoint(SocketEndpoint):
self.publisher = Publisher( self.publisher = Publisher(
self.pulsar_client, triples_store_queue, self.pulsar_client, triples_store_queue,
schema=JsonSchema(Triples) schema=Triples
) )
async def start(self): async def start(self):
self.publisher.start() await self.publisher.start()
async def listener(self, ws, running): async def listener(self, ws, running):
@ -51,7 +50,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
triples=to_subgraph(data["triples"]), triples=to_subgraph(data["triples"]),
) )
self.publisher.send(None, elt) await self.publisher.send(None, elt)
running.stop() running.stop()

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import queue import queue
from pulsar.schema import JsonSchema
import uuid import uuid
from .. schema import Triples from .. schema import Triples
@ -24,7 +23,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber( self.subscriber = Subscriber(
self.pulsar_client, triples_store_queue, self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway", "api-gateway", "api-gateway",
schema=JsonSchema(Triples) schema=Triples
) )
async def listener(self, ws, running): async def listener(self, ws, running):
@ -39,7 +38,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
async def start(self): async def start(self):
self.subscriber.start() await self.subscriber.start()
async def async_thread(self, ws, running): async def async_thread(self, ws, running):

View file

@ -1,295 +0,0 @@
from . clients.graph_embeddings_client import GraphEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import graph_embeddings_request_queue
from . schema import graph_embeddings_response_queue
from . schema import triples_request_queue
from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = self.rag.ge_client.request(
user=self.user, collection=self.collection,
vectors=vectors, limit=self.entity_limit,
)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=e, p=LABEL, o=None, limit=1,
)
if len(res) == 0:
self.rag.label_cache[e] = e
return e
self.rag.label_cache[e] = res[0].o.value
return self.rag.label_cache[e]
def follow_edges(self, ent, subgraph, path_length):
# Not needed?
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=ent, p=None, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.o.value, subgraph, path_length-1)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=ent, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=None, o=ent,
limit=self.triple_limit,
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.s.value, subgraph, path_length-1)
def get_subgraph(self, query):
entities = self.get_entities(query)
if self.verbose:
print("Get subgraph...", flush=True)
subgraph = set()
for ent in entities:
self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
ge_request_queue=None,
ge_response_queue=None,
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if ge_request_queue is None:
ge_request_queue = graph_embeddings_request_queue
if ge_response_queue is None:
ge_response_queue = graph_embeddings_response_queue
if tpl_request_queue is None:
tpl_request_queue = triples_request_queue
if tpl_response_queue is None:
tpl_response_queue = triples_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.ge_client = GraphEmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
subscriber=module + "-ge",
input_queue=ge_request_queue,
output_queue=ge_response_queue,
)
self.triples_client = TriplesQueryClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
subscriber=module + "-tpl",
input_queue=tpl_request_queue,
output_queue=tpl_response_queue
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.label_cache = {}
self.prompt = PromptClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def query(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
kg = q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.prompt.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -35,7 +35,7 @@ from .. exceptions import RequestError
from . librarian import Librarian from . librarian import Librarian
module = ".".join(__name__.split(".")[1:-1]) module = "librarian"
default_input_queue = librarian_request_queue default_input_queue = librarian_request_queue
default_output_queue = librarian_response_queue default_output_queue = librarian_response_queue

View file

@ -10,12 +10,11 @@ from .. schema import text_completion_response_queue
from .. log_level import LogLevel from .. log_level import LogLevel
from .. base import Consumer from .. base import Consumer
module = ".".join(__name__.split(".")[1:-1]) module = "metering"
default_input_queue = text_completion_response_queue default_input_queue = text_completion_response_queue
default_subscriber = module default_subscriber = module
class Processor(Consumer): class Processor(Consumer):
def __init__(self, **params): def __init__(self, **params):

View file

@ -27,7 +27,7 @@ from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_topics from . prompts import to_definitions, to_relationships, to_topics
from . prompts import to_kg_query, to_document_query, to_rows from . prompts import to_kg_query, to_document_query, to_rows
module = ".".join(__name__.split(".")[1:-1]) module = "prompt"
default_input_queue = prompt_request_queue default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue default_output_queue = prompt_response_queue

View file

@ -4,8 +4,6 @@ import json
from jsonschema import validate from jsonschema import validate
import re import re
from trustgraph.clients.llm_client import LlmClient
class PromptConfiguration: class PromptConfiguration:
def __init__(self, system_template, global_terms={}, prompts={}): def __init__(self, system_template, global_terms={}, prompts={}):
self.system_template = system_template self.system_template = system_template
@ -21,8 +19,7 @@ class Prompt:
class PromptManager: class PromptManager:
def __init__(self, llm, config): def __init__(self, config):
self.llm = llm
self.config = config self.config = config
self.terms = config.global_terms self.terms = config.global_terms
@ -54,7 +51,9 @@ class PromptManager:
return json.loads(json_str) return json.loads(json_str)
def invoke(self, id, input): async def invoke(self, id, input, llm):
print("Invoke...", flush=True)
if id not in self.prompts: if id not in self.prompts:
raise RuntimeError("ID invalid") raise RuntimeError("ID invalid")
@ -68,9 +67,7 @@ class PromptManager:
"prompt": self.templates[id].render(terms) "prompt": self.templates[id].render(terms)
} }
resp = self.llm.request(**prompt) resp = await llm(**prompt)
print(resp, flush=True)
if resp_type == "text": if resp_type == "text":
return resp return resp
@ -81,13 +78,13 @@ class PromptManager:
try: try:
obj = self.parse_json(resp) obj = self.parse_json(resp)
except: except:
print("Parse fail:", resp, flush=True)
raise RuntimeError("JSON parse fail") raise RuntimeError("JSON parse fail")
print(obj, flush=True)
if self.prompts[id].schema: if self.prompts[id].schema:
try: try:
print(self.prompts[id].schema)
validate(instance=obj, schema=self.prompts[id].schema) validate(instance=obj, schema=self.prompts[id].schema)
print("Validated", flush=True)
except Exception as e: except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}") raise RuntimeError(f"Schema validation fail: {e}")

View file

@ -3,6 +3,7 @@
Language service abstracts prompt engineering from LLM. Language service abstracts prompt engineering from LLM.
""" """
import asyncio
import json import json
import re import re
@ -10,74 +11,59 @@ from .... schema import Definition, Relationship, Triple
from .... schema import Topic from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue from .... base import FlowProcessor
from .... schema import prompt_request_queue, prompt_response_queue from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompt_manager import PromptConfiguration, Prompt, PromptManager from . prompt_manager import PromptConfiguration, Prompt, PromptManager
module = ".".join(__name__.split(".")[1:-1]) default_ident = "prompt"
default_input_queue = prompt_request_queue class Processor(FlowProcessor):
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue) id = params.get("id")
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
)
# Config key for prompts
self.config_key = params.get("config_type", "prompt") self.config_key = params.get("config_type", "prompt")
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue, "id": id,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": PromptRequest,
"output_schema": PromptResponse,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
} }
) )
self.llm = LlmClient( self.register_specification(
subscriber=subscriber, ConsumerSpec(
input_queue=tc_request_queue, name = "request",
output_queue=tc_response_queue, schema = PromptRequest,
pulsar_host = self.pulsar_host, handler = self.on_request
pulsar_api_key=self.pulsar_api_key, )
) )
# System prompt hack self.register_specification(
class Llm: TextCompletionClientSpec(
def __init__(self, llm): request_name = "text-completion-request",
self.llm = llm response_name = "text-completion-response",
def request(self, system, prompt): )
print(system) )
print(prompt, flush=True)
return self.llm.request(system, prompt)
self.llm = Llm(self.llm) self.register_specification(
ProducerSpec(
name = "response",
schema = PromptResponse
)
)
self.register_config_handler(self.on_prompt_config)
# Null configuration, should reload quickly # Null configuration, should reload quickly
self.manager = PromptManager( self.manager = PromptManager(
llm = self.llm,
config = PromptConfiguration("", {}, {}) config = PromptConfiguration("", {}, {})
) )
async def on_config(self, version, config): async def on_prompt_config(self, config, version):
print("Loading configuration version", version) print("Loading configuration version", version)
@ -111,7 +97,6 @@ class Processor(ConsumerProducer):
) )
self.manager = PromptManager( self.manager = PromptManager(
self.llm,
PromptConfiguration( PromptConfiguration(
system, system,
{}, {},
@ -126,7 +111,7 @@ class Processor(ConsumerProducer):
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
print("Configuration reload failed", flush=True) print("Configuration reload failed", flush=True)
async def handle(self, msg): async def on_request(self, msg, consumer, flow):
v = msg.value() v = msg.value()
@ -138,7 +123,7 @@ class Processor(ConsumerProducer):
try: try:
print(v.terms) print(v.terms, flush=True)
input = { input = {
k: json.loads(v) k: json.loads(v)
@ -146,14 +131,33 @@ class Processor(ConsumerProducer):
} }
print(f"Handling kind {kind}...", flush=True) print(f"Handling kind {kind}...", flush=True)
print(input, flush=True)
resp = self.manager.invoke(kind, input) async def llm(system, prompt):
print(system, flush=True)
print(prompt, flush=True)
resp = await flow("text-completion-request").text_completion(
system = system, prompt = prompt,
)
try:
return resp
except Exception as e:
print("LLM Exception:", e, flush=True)
return None
try:
resp = await self.manager.invoke(kind, input, llm)
except Exception as e:
print("Invocation exception:", e, flush=True)
raise e
print(resp, flush=True)
if isinstance(resp, str): if isinstance(resp, str):
print("Send text response...", flush=True) print("Send text response...", flush=True)
print(resp, flush=True)
r = PromptResponse( r = PromptResponse(
text=resp, text=resp,
@ -161,7 +165,7 @@ class Processor(ConsumerProducer):
error=None, error=None,
) )
await self.send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
return return
@ -176,13 +180,13 @@ class Processor(ConsumerProducer):
error=None, error=None,
) )
await self.send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
return return
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}", flush=True)
print("Send error response...", flush=True) print("Send error response...", flush=True)
@ -194,11 +198,11 @@ class Processor(ConsumerProducer):
response=None, response=None,
) )
await self.send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}", flush=True)
print("Send error response...", flush=True) print("Send error response...", flush=True)
@ -215,22 +219,7 @@ class Processor(ConsumerProducer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( FlowProcessor.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
)
parser.add_argument( parser.add_argument(
'--config-type', '--config-type',
@ -240,5 +229,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -14,7 +14,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer from .... base import ConsumerProducer
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1]) module = "text-completion"
default_input_queue = text_completion_request_queue default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue default_output_queue = text_completion_response_queue

View file

@ -11,7 +11,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) module = "de-query"
default_input_queue = document_embeddings_request_queue default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue default_output_queue = document_embeddings_response_queue

View file

@ -16,7 +16,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) module = "de-query"
default_input_queue = document_embeddings_request_queue default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue default_output_queue = document_embeddings_response_queue

View file

@ -7,71 +7,51 @@ of chunks
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
import uuid
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse from .... schema import DocumentEmbeddingsResponse
from .... schema import Error, Value from .... schema import Error, Value
from .... schema import document_embeddings_request_queue from .... base import DocumentEmbeddingsQueryService
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) default_ident = "de-query"
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_store_uri = 'http://localhost:6333' default_store_uri = 'http://localhost:6333'
class Processor(ConsumerProducer): class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params): def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri) store_uri = params.get("store_uri", default_store_uri)
#optional api key #optional api key
api_key = params.get("api_key", None) api_key = params.get("api_key", None)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddingsRequest,
"output_schema": DocumentEmbeddingsResponse,
"store_uri": store_uri, "store_uri": store_uri,
"api_key": api_key, "api_key": api_key,
} }
) )
self.client = QdrantClient(url=store_uri, api_key=api_key) self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
async def handle(self, msg): async def query_document_embeddings(self, msg):
try: try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
chunks = [] chunks = []
for vec in v.vectors: for vec in msg.vectors:
dim = len(vec) dim = len(vec)
collection = ( collection = (
"d_" + v.user + "_" + v.collection + "_" + "d_" + msg.user + "_" + msg.collection + "_" +
str(dim) str(dim)
) )
search_result = self.client.query_points( search_result = self.qdrant.query_points(
collection_name=collection, collection_name=collection,
query=vec, query=vec,
limit=v.limit, limit=msg.limit,
with_payload=True, with_payload=True,
).points ).points
@ -79,37 +59,17 @@ class Processor(ConsumerProducer):
ent = r.payload["doc"] ent = r.payload["doc"]
chunks.append(ent) chunks.append(ent)
print("Send response...", flush=True) return chunks
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
raise e
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
documents=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
ConsumerProducer.add_args( DocumentEmbeddingsQueryService.add_args(parser)
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument( parser.add_argument(
'-t', '--store-uri', '-t', '--store-uri',
@ -125,5 +85,5 @@ class Processor(ConsumerProducer):
def run(): def run():
Processor.launch(module, __doc__) Processor.launch(default_ident, __doc__)

Some files were not shown because too many files have changed in this diff Show more