Revert "Feature/configure flows (#345)"

This reverts commit a9197d11ee.
Cyber MacGeddon 2025-04-25 19:02:08 +01:00
parent 3adb3cf59c
commit 1822ca395f
125 changed files with 2628 additions and 3751 deletions

View file

@ -8,11 +8,12 @@ logger = logging.getLogger(__name__)
class AgentManager:
def __init__(self, tools, additional_context=None):
def __init__(self, context, tools, additional_context=None):
self.context = context
self.tools = tools
self.additional_context = additional_context
async def reason(self, question, history, context):
def reason(self, question, history):
tools = self.tools
@ -55,7 +56,10 @@ class AgentManager:
logger.info(f"prompt: {variables}")
obj = await context("prompt-request").agent_react(variables)
obj = self.context.prompt.request(
"agent-react",
variables
)
print(json.dumps(obj, indent=4), flush=True)
@ -81,13 +85,9 @@ class AgentManager:
return a
async def react(self, question, history, think, observe, context):
async def react(self, question, history, think, observe):
act = await self.reason(
question = question,
history = history,
context = context,
)
act = self.reason(question, history)
logger.info(f"act: {act}")
if isinstance(act, Final):
@ -104,12 +104,7 @@ class AgentManager:
else:
raise RuntimeError(f"No action for {act.name}!")
print("TOOL>>>", act)
resp = await action.implementation(context).invoke(
**act.arguments
)
print("RSETUL", resp)
resp = action.implementation.invoke(**act.arguments)
resp = resp.strip()
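
For orientation, a minimal, self-contained sketch of the synchronous ReAct dispatch this file reverts to; Final, Action, and the tool objects here are hypothetical stand-ins for the types and tools modules shown further down.

# Stand-ins for the Final/Action types (assumed shape, not the real module).
class Final:
    def __init__(self, final):
        self.final = final

class Action:
    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments

def react_once(reason, tools, question, history):
    # One ReAct step: reason() yields either a Final answer or a tool Action.
    act = reason(question, history)
    if isinstance(act, Final):
        return act.final
    if act.name not in tools:
        raise RuntimeError(f"No action for {act.name}!")
    # Tool implementations expose a synchronous invoke(), as in the revert.
    observation = tools[act.name].implementation.invoke(**act.arguments).strip()
    history.append((act, observation))
    return None     # caller loops until a Final is returned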

View file

@ -6,68 +6,103 @@ import json
import re
import sys
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec
from pulsar.schema import JsonSchema
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... base import ConsumerProducer
from ... schema import Error
from ... schema import AgentRequest, AgentResponse, AgentStep
from ... schema import agent_request_queue, agent_response_queue
from ... schema import prompt_request_queue as pr_request_queue
from ... schema import prompt_response_queue as pr_response_queue
from ... schema import graph_rag_request_queue as gr_request_queue
from ... schema import graph_rag_response_queue as gr_response_queue
from ... clients.prompt_client import PromptClient
from ... clients.llm_client import LlmClient
from ... clients.graph_rag_client import GraphRagClient
from . tools import KnowledgeQueryImpl, TextCompletionImpl
from . agent_manager import AgentManager
from . types import Final, Action, Tool, Argument
default_ident = "agent-manager"
default_max_iterations = 10
module = ".".join(__name__.split(".")[1:-1])
class Processor(AgentService):
default_input_queue = agent_request_queue
default_output_queue = agent_response_queue
default_subscriber = module
default_max_iterations = 15
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
self.max_iterations = int(
params.get("max_iterations", default_max_iterations)
)
tools = {}
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
prompt_request_queue = params.get(
"prompt_request_queue", pr_request_queue
)
prompt_response_queue = params.get(
"prompt_response_queue", pr_response_queue
)
graph_rag_request_queue = params.get(
"graph_rag_request_queue", gr_request_queue
)
graph_rag_response_queue = params.get(
"graph_rag_response_queue", gr_response_queue
)
self.config_key = params.get("config_type", "agent")
super(Processor, self).__init__(
**params | {
"id": id,
"max_iterations": self.max_iterations,
"config_type": self.config_key,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": AgentRequest,
"output_schema": AgentResponse,
"prompt_request_queue": prompt_request_queue,
"prompt_response_queue": prompt_response_queue,
"graph_rag_request_queue": gr_request_queue,
"graph_rag_response_queue": gr_response_queue,
}
)
self.prompt = PromptClient(
subscriber=subscriber,
input_queue=prompt_request_queue,
output_queue=prompt_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
self.graph_rag = GraphRagClient(
subscriber=subscriber,
input_queue=graph_rag_request_queue,
output_queue=graph_rag_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
# Need to be able to feed requests to myself
self.recursive_input = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(AgentRequest),
)
self.agent = AgentManager(
context=self,
tools=[],
additional_context="",
)
self.config_handlers.append(self.on_tools_config)
self.register_specification(
TextCompletionClientSpec(
request_name = "text-completion-request",
response_name = "text-completion-response",
)
)
self.register_specification(
GraphRagClientSpec(
request_name = "graph-rag-request",
response_name = "graph-rag-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
async def on_tools_config(self, config, version):
async def on_config(self, version, config):
print("Loading configuration version", version)
@ -103,9 +138,9 @@ class Processor(AgentService):
impl_id = data.get("type")
if impl_id == "knowledge-query":
impl = KnowledgeQueryImpl
impl = KnowledgeQueryImpl(self)
elif impl_id == "text-completion":
impl = TextCompletionImpl
impl = TextCompletionImpl(self)
else:
raise RuntimeError(
f"Tool-kind {impl_id} not known"
@ -120,6 +155,7 @@ class Processor(AgentService):
)
self.agent = AgentManager(
context=self,
tools=tools,
additional_context=additional
)
@ -128,14 +164,19 @@ class Processor(AgentService):
except Exception as e:
print("on_tools_config Exception:", e, flush=True)
print("Exception:", e, flush=True)
print("Configuration reload failed", flush=True)
async def agent_request(self, request, respond, next, flow):
async def handle(self, msg):
try:
if request.history:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
if v.history:
history = [
Action(
thought=h.thought,
@ -143,12 +184,12 @@ class Processor(AgentService):
arguments=h.arguments,
observation=h.observation
)
for h in request.history
for h in v.history
]
else:
history = []
print(f"Question: {request.question}", flush=True)
print(f"Question: {v.question}", flush=True)
if len(history) >= self.max_iterations:
raise RuntimeError("Too many agent iterations")
@ -166,7 +207,7 @@ class Processor(AgentService):
observation=None,
)
await respond(r)
await self.send(r, properties={"id": id})
async def observe(x):
@ -179,21 +220,15 @@ class Processor(AgentService):
observation=x,
)
await respond(r)
await self.send(r, properties={"id": id})
act = await self.agent.react(
question = request.question,
history = history,
think = think,
observe = observe,
context = flow,
)
act = await self.agent.react(v.question, history, think, observe)
print(f"Action: {act}", flush=True)
if isinstance(act, Final):
print("Send response...", flush=True)
print("Send final response...", flush=True)
if type(act) == Final:
r = AgentResponse(
answer=act.final,
@ -201,20 +236,18 @@ class Processor(AgentService):
thought=None,
)
await respond(r)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return
print("Send next...", flush=True)
history.append(act)
r = AgentRequest(
question=request.question,
plan=request.plan,
state=request.state,
question=v.question,
plan=v.plan,
state=v.state,
history=[
AgentStep(
thought=h.thought,
@ -226,7 +259,7 @@ class Processor(AgentService):
]
)
await next(r)
self.recursive_input.send(r, properties={"id": id})
print("Done.", flush=True)
@ -234,7 +267,7 @@ class Processor(AgentService):
except Exception as e:
print(f"agent_request Exception: {e}")
print(f"Exception: {e}")
print("Send error response...", flush=True)
@ -246,12 +279,39 @@ class Processor(AgentService):
response=None,
)
await respond(r)
await self.send(r, properties={"id": id})
@staticmethod
def add_args(parser):
AgentService.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=pr_request_queue,
help=f'Prompt request queue (default: {pr_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=pr_response_queue,
help=f'Prompt response queue (default: {pr_response_queue})',
)
parser.add_argument(
'--graph-rag-request-queue',
default=gr_request_queue,
help=f'Graph RAG request queue (default: {gr_request_queue})',
)
parser.add_argument(
'--graph-rag-response-queue',
default=gr_response_queue,
help=f'Graph RAG response queue (default: {gr_response_queue})',
)
parser.add_argument(
'--max-iterations',
@ -267,5 +327,5 @@ class Processor(AgentService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
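
A hedged sketch of the self-feeding producer the constructor sets up, using a plain Pulsar client; the host, topic name, and trimmed-down AgentRequest record are placeholders, not the real schema module.

import pulsar
from pulsar.schema import JsonSchema, Record, String

class AgentRequest(Record):     # placeholder; the real schema has more fields
    question = String()

client = pulsar.Client("pulsar://pulsar:6650")      # placeholder host
recursive_input = client.create_producer(
    topic="agent-request",                          # placeholder topic
    schema=JsonSchema(AgentRequest),
)

# Each intermediate agent step is sent back to the service's own input
# queue, carrying the sender-produced id so responses can be correlated.
recursive_input.send(AgentRequest(question="..."), properties={"id": "req-1"})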

View file

@ -4,22 +4,16 @@
class KnowledgeQueryImpl:
def __init__(self, context):
self.context = context
async def invoke(self, **arguments):
client = self.context("graph-rag-request")
print("Graph RAG question...", flush=True)
return await client.rag(
arguments.get("question")
)
def invoke(self, **arguments):
return self.context.graph_rag.request(arguments.get("question"))
# This tool implementation knows how to do text completion. This uses
# the prompt service, rather than talking to TextCompletion directly.
class TextCompletionImpl:
def __init__(self, context):
self.context = context
async def invoke(self, **arguments):
client = self.context("prompt-request")
print("Prompt question...", flush=True)
return await client.question(
arguments.get("question")
def invoke(self, **arguments):
return self.context.prompt.request(
"question", { "question": arguments.get("question") }
)

View file

@ -7,27 +7,40 @@ as text as separate output objects.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
default_ident = "chunker"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id", default_ident)
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
**params | { "id": id }
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
@ -39,24 +52,7 @@ class Processor(FlowProcessor):
is_separator_regex=False,
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = TextDocument,
handler = self.on_message,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = Chunk,
)
)
print("Chunker initialised", flush=True)
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True)
@ -67,25 +63,24 @@ class Processor(FlowProcessor):
for ix, chunk in enumerate(texts):
print("Chunk", len(chunk.page_content), flush=True)
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
__class__.chunk_metric.observe(len(chunk.page_content))
await flow("output").send(r)
await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-z', '--chunk-size',
@ -103,5 +98,5 @@ class Processor(FlowProcessor):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
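
As a standalone usage sketch, the splitter configured above can be exercised directly; the input text is a placeholder.

from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=2000,           # default chunk_size above
    chunk_overlap=100,         # default chunk_overlap above
    is_separator_regex=False,
)

text = "..."                   # placeholder document text
for ix, chunk in enumerate(splitter.create_documents([text])):
    print("Chunk", ix, len(chunk.page_content))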

View file

@ -7,27 +7,40 @@ as text as separate output objects.
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
default_ident = "chunker"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__(
**params | { "id": id }
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
["id", "flow"],
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
@ -38,24 +51,7 @@ class Processor(FlowProcessor):
chunk_overlap=chunk_overlap,
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = TextDocument,
handler = self.on_message,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = Chunk,
)
)
print("Chunker initialised", flush=True)
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
print(f"Chunking {v.metadata.id}...", flush=True)
@ -66,25 +62,24 @@ class Processor(FlowProcessor):
for ix, chunk in enumerate(texts):
print("Chunk", len(chunk.page_content), flush=True)
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
__class__.chunk_metric.observe(len(chunk.page_content))
await flow("output").send(r)
await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-z', '--chunk-size',
@ -102,5 +97,5 @@ class Processor(FlowProcessor):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -1,215 +0,0 @@
from trustgraph.schema import ConfigResponse
from trustgraph.schema import ConfigValue, Error
# This behaves just like a dict; that should make it easier to add
# persistent storage later
class ConfigurationItems(dict):
pass
class Configuration(dict):
# FIXME: The state is held internally. This only works if there's
# one config service. Should be more than one, and use a
# back-end state store.
def __init__(self, push):
# Version counter
self.version = 0
# External function to respond to update
self.push = push
def __getitem__(self, key):
if key not in self:
self[key] = ConfigurationItems()
return dict.__getitem__(self, key)
async def handle_get(self, v):
for k in v.keys:
if k.type not in self or k.key not in self[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = k.type,
key = k.key,
value = self[k.type][k.key]
)
for k in v.keys
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_list(self, v):
if v.type not in self:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = "No such type",
),
)
return ConfigResponse(
version = self.version,
values = None,
directory = list(self[v.type].keys()),
config = None,
error = None,
)
async def handle_getvalues(self, v):
if v.type not in self:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = v.type,
key = k,
value = self[v.type][k],
)
for k in self[v.type]
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_delete(self, v):
for k in v.keys:
if k.type not in self or k.key not in self[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
for k in v.keys:
del self[k.type][k.key]
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
config = None,
error = None,
)
async def handle_put(self, v):
for k in v.values:
self[k.type][k.key] = k.value
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
error = None,
)
async def handle_config(self, v):
return ConfigResponse(
version = self.version,
value = None,
directory = None,
values = None,
config = self,
error = None,
)
async def handle(self, msg):
print("Handle message ", msg.operation)
if msg.operation == "get":
resp = await self.handle_get(msg)
elif msg.operation == "list":
resp = await self.handle_list(msg)
elif msg.operation == "getvalues":
resp = await self.handle_getvalues(msg)
elif msg.operation == "delete":
resp = await self.handle_delete(msg)
elif msg.operation == "put":
resp = await self.handle_put(msg)
elif msg.operation == "config":
resp = await self.handle_config(msg)
else:
resp = ConfigResponse(
value=None,
directory=None,
values=None,
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
return resp
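
The deleted Configuration class is a two-level auto-creating dict; this self-contained example shows the lookup behaviour its handlers rely on (the version counter and push callback are trimmed out).

class ConfigurationItems(dict):
    pass

class Configuration(dict):
    def __getitem__(self, key):
        if key not in self:
            self[key] = ConfigurationItems()   # first access creates the section
        return dict.__getitem__(self, key)

config = Configuration()
config["agent"]["tool-x"] = "{}"        # no KeyError on a fresh type
assert list(config["agent"].keys()) == ["tool-x"]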

View file

@ -1,90 +1,214 @@
"""
Config service. Manages system global configuration state
Config service. Fetches an extract from the Wikipedia page
using the API.
"""
from pulsar.schema import JsonSchema
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
from trustgraph.schema import Error
from trustgraph.schema import ConfigValue, Error
from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.schema import config_push_queue
from trustgraph.log_level import LogLevel
from trustgraph.base import AsyncProcessor, Consumer, Producer
from trustgraph.base import ConsumerProducer
from . config import Configuration
from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from ... base import Consumer, Producer
module = ".".join(__name__.split(".")[1:-1])
default_ident = "config-svc"
default_request_queue = config_request_queue
default_response_queue = config_response_queue
default_input_queue = config_request_queue
default_output_queue = config_response_queue
default_push_queue = config_push_queue
default_subscriber = module
class Processor(AsyncProcessor):
# This behaves just like a dict; that should make it easier to add
# persistent storage later
class ConfigurationItems(dict):
pass
class Configuration(dict):
def __getitem__(self, key):
if key not in self:
self[key] = ConfigurationItems()
return dict.__getitem__(self, key)
class Processor(ConsumerProducer):
def __init__(self, **params):
request_queue = params.get("request_queue", default_request_queue)
response_queue = params.get("response_queue", default_response_queue)
push_queue = params.get("push_queue", default_push_queue)
id = params.get("id")
request_schema = ConfigRequest
response_schema = ConfigResponse
push_schema = ConfigResponse
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
push_queue = params.get("push_queue", default_push_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
"request_schema": request_schema.__name__,
"response_schema": response_schema.__name__,
"push_schema": push_schema.__name__,
"input_queue": input_queue,
"output_queue": output_queue,
"push_queue": output_queue,
"subscriber": subscriber,
"input_schema": ConfigRequest,
"output_schema": ConfigResponse,
"push_schema": ConfigPush,
}
)
request_metrics = ConsumerMetrics(id + "-request")
response_metrics = ProducerMetrics(id + "-response")
push_metrics = ProducerMetrics(id + "-push")
self.push_pub = Producer(
client = self.client,
topic = push_queue,
schema = ConfigPush,
metrics = push_metrics,
self.push_prod = self.client.create_producer(
topic=push_queue,
schema=JsonSchema(ConfigPush),
)
self.response_pub = Producer(
client = self.client,
topic = response_queue,
schema = ConfigResponse,
metrics = response_metrics,
)
# FIXME: The state is held internally. This only works if there's
# one config service. Should be more than one, and use a
# back-end state store.
self.config = Configuration()
self.subs = Consumer(
taskgroup = self.taskgroup,
client = self.client,
flow = None,
topic = request_queue,
subscriber = id,
schema = request_schema,
handler = self.on_message,
metrics = request_metrics,
)
self.config = Configuration(self.push)
print("Service initialised.")
# Version counter
self.version = 0
async def start(self):
await self.push()
async def handle_get(self, v, id):
for k in v.keys:
if k.type not in self.config or k.key not in self.config[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = k.type,
key = k.key,
value = self.config[k.type][k.key]
)
for k in v.keys
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_list(self, v, id):
if v.type not in self.config:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = "No such type",
),
)
return ConfigResponse(
version = self.version,
values = None,
directory = list(self.config[v.type].keys()),
config = None,
error = None,
)
async def handle_getvalues(self, v, id):
if v.type not in self.config:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
values = [
ConfigValue(
type = v.type,
key = k,
value = self.config[v.type][k],
)
for k in self.config[v.type]
]
return ConfigResponse(
version = self.version,
values = values,
directory = None,
config = None,
error = None,
)
async def handle_delete(self, v, id):
for k in v.keys:
if k.type not in self.config or k.key not in self.config[k.type]:
return ConfigResponse(
version = None,
values = None,
directory = None,
config = None,
error = Error(
type = "key-error",
message = f"Key error"
)
)
for k in v.keys:
del self.config[k.type][k.key]
self.version += 1
await self.push()
await self.subs.start()
async def push(self):
resp = ConfigPush(
version = self.config.version,
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
config = None,
error = None,
)
async def handle_put(self, v, id):
for k in v.values:
self.config[k.type][k.key] = k.value
self.version += 1
await self.push()
return ConfigResponse(
version = None,
value = None,
directory = None,
values = None,
error = None,
)
async def handle_config(self, v, id):
return ConfigResponse(
version = self.version,
value = None,
directory = None,
values = None,
@ -92,27 +216,72 @@ class Processor(AsyncProcessor):
error = None,
)
await self.push_pub.send(resp)
async def push(self):
print("Pushed version ", self.config.version)
resp = ConfigPush(
version = self.version,
value = None,
directory = None,
values = None,
config = self.config,
error = None,
)
self.push_prod.send(resp)
print("Pushed.")
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling {id}...", flush=True)
try:
v = msg.value()
if v.operation == "get":
# Sender-produced ID
id = msg.properties()["id"]
resp = await self.handle_get(v, id)
print(f"Handling {id}...", flush=True)
elif v.operation == "list":
resp = await self.config.handle(v)
resp = await self.handle_list(v, id)
await self.response_pub.send(resp, properties={"id": id})
elif v.operation == "getvalues":
resp = await self.handle_getvalues(v, id)
elif v.operation == "delete":
resp = await self.handle_delete(v, id)
elif v.operation == "put":
resp = await self.handle_put(v, id)
elif v.operation == "config":
resp = await self.handle_config(v, id)
else:
resp = ConfigResponse(
value=None,
directory=None,
values=None,
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
await self.send(resp, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
resp = ConfigResponse(
error=Error(
type = "unexpected-error",
@ -120,33 +289,24 @@ class Processor(AsyncProcessor):
),
text=None,
)
await self.response_pub.send(resp, properties={"id": id})
await self.send(resp, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
AsyncProcessor.add_args(parser)
parser.add_argument(
'-q', '--request-queue',
default=default_request_queue,
help=f'Request queue (default: {default_request_queue})'
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-r', '--response-queue',
default=default_response_queue,
help=f'Response queue (default: {default_response_queue})',
)
parser.add_argument(
'--push-queue',
'-q', '--push-queue',
default=default_push_queue,
help=f'Config push queue (default: {default_push_queue})'
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -17,10 +17,12 @@ from mistralai.models import OCRResponse
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import InputOutputProcessor
from ... base import ConsumerProducer
module = "ocr"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
default_api_key = os.getenv("MISTRAL_TOKEN")
@ -69,17 +71,19 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str:
return "\n\n".join(markdowns)
class Processor(InputOutputProcessor):
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
api_key = params.get("api_key", default_api_key)
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
@ -147,7 +151,7 @@ class Processor(InputOutputProcessor):
return markdown
async def on_message(self, msg, consumer):
async def handle(self, msg):
print("PDF message received")
@ -162,14 +166,17 @@ class Processor(InputOutputProcessor):
text=markdown.encode("utf-8"),
)
await consumer.q.output.send(r)
await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
InputOutputProcessor.add_args(parser, default_subscriber)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-k', '--api-key',

View file

@ -9,43 +9,39 @@ import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import ConsumerProducer
default_ident = "pdf-decoder"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id", default_ident)
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = Document,
handler = self.on_message,
)
)
print("PDF inited")
self.register_specification(
ProducerSpec(
name = "output",
schema = TextDocument,
)
)
async def handle(self, msg):
print("PDF inited", flush=True)
async def on_message(self, msg, consumer, flow):
print("PDF message received", flush=True)
print("PDF message received")
v = msg.value()
@ -63,22 +59,24 @@ class Processor(FlowProcessor):
for ix, page in enumerate(pages):
print("page", ix, flush=True)
r = TextDocument(
metadata=v.metadata,
text=page.page_content.encode("utf-8"),
)
await flow("output").send(r)
await self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
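
A minimal sketch of the per-page decode this processor performs; the file path is a placeholder.

from langchain_community.document_loaders import PyPDFLoader

pages = PyPDFLoader("example.pdf").load()    # one document per PDF page
for ix, page in enumerate(pages):
    print("page", ix, len(page.page_content.encode("utf-8")))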

View file

@ -0,0 +1,153 @@
from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.doc_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
de_request_queue=None,
de_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if de_request_queue is None:
de_request_queue = document_embeddings_request_queue
if de_response_queue is None:
de_response_queue = document_embeddings_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
pulsar_api_key=pulsar_api_key,
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
pulsar_api_key=pulsar_api_key,
)
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
pulsar_api_key=pulsar_api_key,
)
if self.verbose:
print("Initialised", flush=True)
def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = q.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = self.lang.request_document_prompt(query, docs)
if self.verbose:
print("Done", flush=True)
return resp
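
Condensed, the query path above is: embed the question, retrieve the nearest document chunks, then hand both to the prompt service. A hedged outline using the same client calls:

def rag_query(rag, question, doc_limit=20):
    vectors = rag.embeddings.request(question)                # query embedding
    docs = rag.de_client.request(vectors, limit=doc_limit)    # nearest chunks
    return rag.lang.request_document_prompt(question, docs)   # grounded answer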

View file

@ -6,63 +6,61 @@ Output is chunk plus embedding.
"""
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import chunk_ingest_queue
from ... schema import document_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec
from ... base import ProducerSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "document-embeddings"
default_input_queue = chunk_ingest_queue
default_output_queue = document_embeddings_store_queue
default_subscriber = module
class Processor(FlowProcessor):
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": DocumentEmbeddings,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message,
)
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.register_specification(
RequestResponseSpec(
request_name = "embeddings-request",
request_schema = EmbeddingsRequest,
response_name = "embeddings-response",
response_schema = EmbeddingsResponse,
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = DocumentEmbeddings
)
)
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
try:
resp = await flow("embeddings-request").request(
EmbeddingsRequest(
text = v.chunk
)
)
vectors = resp.vectors
vectors = self.embeddings.request(v.chunk)
embeds = [
ChunkEmbeddings(
@ -76,7 +74,7 @@ class Processor(FlowProcessor):
chunks=embeds,
)
await flow("output").send(r)
await self.send(r)
except Exception as e:
print("Exception:", e, flush=True)
@ -89,9 +87,24 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -1,43 +1,81 @@
"""
Embeddings service, applies an embeddings model using fastembed
Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector.
"""
from ... base import EmbeddingsService
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
from fastembed import TextEmbedding
import os
default_ident = "embeddings"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="sentence-transformers/all-MiniLM-L6-v2"
class Processor(EmbeddingsService):
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | { "model": model }
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
"model": model,
}
)
print("Get model...", flush=True)
self.embeddings = TextEmbedding(model_name = model)
async def on_embeddings(self, text):
async def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
text = v.text
vecs = self.embeddings.embed([text])
return [
vecs = [
v.tolist()
for v in vecs
]
print("Send response...", flush=True)
r = EmbeddingsResponse(
vectors=list(vecs),
error=None,
)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@staticmethod
def add_args(parser):
EmbeddingsService.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-m', '--model',
@ -47,5 +85,5 @@ class Processor(EmbeddingsService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
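
The embedding call itself is small enough to try standalone; this sketch uses the default model named above (all-MiniLM-L6-v2 yields 384-dimensional vectors).

from fastembed import TextEmbedding

model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectors = [v.tolist() for v in model.embed(["What is a knowledge graph?"])]
print(len(vectors), len(vectors[0]))     # 1 vector, 384 dimensions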

View file

@ -6,48 +6,53 @@ Output is entity plus embedding.
"""
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import entity_contexts_ingest_queue
from ... schema import graph_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "graph-embeddings"
default_input_queue = entity_contexts_ingest_queue
default_output_queue = graph_embeddings_store_queue
default_subscriber = module
class Processor(FlowProcessor):
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": EntityContexts,
"output_schema": GraphEmbeddings,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = EntityContexts,
handler = self.on_message,
)
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
ProducerSpec(
name = "output",
schema = GraphEmbeddings
)
)
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
@ -58,9 +63,7 @@ class Processor(FlowProcessor):
for entity in v.entities:
vectors = await flow("embeddings-request").embed(
text = entity.context
)
vectors = self.embeddings.request(entity.context)
entities.append(
EntityEmbeddings(
@ -74,7 +77,7 @@ class Processor(FlowProcessor):
entities=entities,
)
await flow("output").send(r)
await self.send(r)
except Exception as e:
print("Exception:", e, flush=True)
@ -87,9 +90,24 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -11,7 +11,7 @@ from ... base import ConsumerProducer
from ollama import Client
import os
module = "embeddings"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue

View file

@ -11,7 +11,7 @@ from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
import requests
module = "wikipedia"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = encyclopedia_lookup_request_queue
default_output_queue = encyclopedia_lookup_response_queue

View file

@ -5,62 +5,84 @@ get entity definitions which are output as graph edges along with
entity/context definitions for embedding.
"""
import json
import urllib.parse
from pulsar.schema import JsonSchema
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
from .... schema import PromptRequest, PromptResponse
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import entity_contexts_ingest_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
default_ident = "kg-extract-definitions"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_entity_context_queue = entity_contexts_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
ec_queue = params.get(
"entity_context_queue",
default_entity_context_queue
)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message
)
self.ec_prod = self.client.create_producer(
topic=ec_queue,
schema=JsonSchema(EntityContexts),
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"entity_context_queue": ec_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": Chunk.__name__,
"output_schema": Triples.__name__,
"vector_schema": EntityContexts.__name__,
})
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples
)
)
self.register_specification(
ProducerSpec(
name = "entity-contexts",
schema = EntityContexts
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
@ -71,47 +93,36 @@ class Processor(FlowProcessor):
return uri
async def emit_triples(self, pub, metadata, triples):
def get_definitions(self, chunk):
return self.prompt.request_definitions(chunk)
async def emit_edges(self, metadata, triples):
t = Triples(
metadata=metadata,
triples=triples,
)
await pub.send(t)
await self.send(t)
async def emit_ecs(self, pub, metadata, entities):
async def emit_ecs(self, metadata, entities):
t = EntityContexts(
metadata=metadata,
entities=entities,
)
await pub.send(t)
self.ec_prod.send(t)
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
print(chunk, flush=True)
try:
try:
defs = await flow("prompt-request").extract_definitions(
text = chunk
)
print("Response", defs, flush=True)
if type(defs) != list:
raise RuntimeError("Expecting array in prompt response")
except Exception as e:
print("Prompt exception:", e, flush=True)
raise e
defs = self.get_definitions(chunk)
triples = []
entities = []
@ -123,8 +134,8 @@ class Processor(FlowProcessor):
for defn in defs:
s = defn["entity"]
o = defn["definition"]
s = defn.name
o = defn.definition
if s == "": continue
if o == "": continue
@ -155,13 +166,13 @@ class Processor(FlowProcessor):
ec = EntityContext(
entity=s_value,
context=defn["definition"],
context=defn.definition,
)
entities.append(ec)
await self.emit_triples(
flow("triples"),
await self.emit_edges(
Metadata(
id=v.metadata.id,
metadata=[],
@ -172,7 +183,6 @@ class Processor(FlowProcessor):
)
await self.emit_ecs(
flow("entity-contexts"),
Metadata(
id=v.metadata.id,
metadata=[],
@ -190,9 +200,30 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-e', '--entity-context-queue',
default=default_entity_context_queue,
help=f'Entity context queue (default: {default_entity_context_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-completion-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
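
For reference, each extracted definition becomes graph edges: a label and a definition on the entity URI, plus a subject-of link back to the source document. A sketch with an assumed entity prefix (the real one comes from the rdf module):

import urllib.parse

TRUSTGRAPH_ENTITIES = "http://trustgraph.ai/e/"    # assumed prefix
RDF_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION = "http://www.w3.org/2004/02/skos/core#definition"

def to_uri(text):
    return TRUSTGRAPH_ENTITIES + urllib.parse.quote(text)

s = to_uri("Photosynthesis")
edges = [
    (s, RDF_LABEL, "Photosynthesis"),
    (s, DEFINITION, "Conversion of light into chemical energy."),
]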

View file

@ -5,54 +5,59 @@ relationship analysis to get entity relationship edges which are output as
graph edges.
"""
import json
import urllib.parse
from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value
from .... schema import PromptRequest, PromptResponse
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... base import ConsumerProducer
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
default_ident = "kg-extract-relationships"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
@ -63,39 +68,28 @@ class Processor(FlowProcessor):
return uri
async def emit_triples(self, pub, metadata, triples):
def get_relationships(self, chunk):
return self.prompt.request_relationships(chunk)
async def emit_edges(self, metadata, triples):
t = Triples(
metadata=metadata,
triples=triples,
)
await pub.send(t)
await self.send(t)
async def on_message(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
print(chunk, flush=True)
try:
try:
rels = await flow("prompt-request").extract_relationships(
text = chunk
)
print("Response", rels, flush=True)
if type(rels) != list:
raise RuntimeError("Expecting array in prompt response")
except Exception as e:
print("Prompt exception:", e, flush=True)
raise e
rels = self.get_relationships(chunk)
triples = []
@ -106,9 +100,9 @@ class Processor(FlowProcessor):
for rel in rels:
s = rel["subject"]
p = rel["predicate"]
o = rel["object"]
s = rel.s
p = rel.p
o = rel.o
if s == "": continue
if p == "": continue
@ -124,7 +118,7 @@ class Processor(FlowProcessor):
p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
if rel["object-entity"]:
if rel.o_entity:
o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
else:
@ -150,7 +144,7 @@ class Processor(FlowProcessor):
o=Value(value=str(p), is_uri=False)
))
if rel["object-entity"]:
if rel.o_entity:
# Label for o
triples.append(Triple(
s=o_value,
@ -165,7 +159,7 @@ class Processor(FlowProcessor):
o=Value(value=v.metadata.id, is_uri=True)
))
if rel["object-entity"]:
if rel.o_entity:
# 'Subject of' for o
triples.append(Triple(
s=o_value,
@ -174,7 +168,6 @@ class Processor(FlowProcessor):
))
await self.emit_edges(
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
@ -192,9 +185,24 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -18,7 +18,7 @@ from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = "kg-extract-topics"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue

View file

@ -39,3 +39,4 @@ class AgentRequestor(ServiceRequestor):
# The 2nd boolean expression indicates whether we're done responding
return resp, (message.answer is not None)

View file

@ -1,5 +1,6 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@ -25,12 +26,12 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, document_embeddings_store_queue,
schema=DocumentEmbeddings
schema=JsonSchema(DocumentEmbeddings)
)
async def start(self):
await self.publisher.start()
self.publisher.start()
async def listener(self, ws, running):
@ -58,6 +59,6 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
],
)
await self.publisher.send(None, elt)
self.publisher.send(None, elt)
running.stop()

View file

@ -1,6 +1,7 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import DocumentEmbeddings
@ -26,7 +27,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=DocumentEmbeddings,
schema=JsonSchema(DocumentEmbeddings),
)
async def listener(self, ws, running):
@ -43,17 +44,17 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self):
await self.subscriber.start()
self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id)
q = self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
@ -66,7 +67,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True)
break
await self.subscriber.unsubscribe_all(id)
self.subscriber.unsubscribe_all(id)
running.stop()

View file

@ -1,9 +1,13 @@
import asyncio
from pulsar.schema import JsonSchema
from aiohttp import web
import uuid
import logging
from .. base import Publisher
from .. base import Subscriber
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)

View file

@ -1,5 +1,6 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@ -25,12 +26,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, graph_embeddings_store_queue,
schema=GraphEmbeddings
schema=JsonSchema(GraphEmbeddings)
)
async def start(self):
await self.publisher.start()
self.publisher.start()
async def listener(self, ws, running):
@ -59,6 +60,6 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
]
)
await self.publisher.send(None, elt)
self.publisher.send(None, elt)
running.stop()

View file

@ -1,6 +1,7 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import GraphEmbeddings
@ -25,7 +26,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=GraphEmbeddings
schema=JsonSchema(GraphEmbeddings)
)
async def listener(self, ws, running):
@ -40,17 +41,17 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
async def start(self):
await self.subscriber.start()
self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id)
q = self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
@ -63,7 +64,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
print(f"Exception: {str(e)}", flush=True)
break
await self.subscriber.unsubscribe_all(id)
self.subscriber.unsubscribe_all(id)
running.stop()

View file

@ -7,6 +7,7 @@
import aiohttp
from aiohttp import web
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging

View file

@ -1,6 +1,7 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from aiohttp import web, WSMsgType

View file

@ -1,5 +1,6 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging
@ -22,21 +23,21 @@ class ServiceRequestor:
self.pub = Publisher(
pulsar_client, request_queue,
schema=request_schema,
schema=JsonSchema(request_schema),
)
self.sub = Subscriber(
pulsar_client, response_queue,
subscription, consumer_name,
response_schema
JsonSchema(response_schema)
)
self.timeout = timeout
async def start(self):
await self.pub.start()
await self.sub.start()
self.pub.start()
self.sub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
@ -50,15 +51,18 @@ class ServiceRequestor:
try:
q = await self.sub.subscribe(id)
q = self.sub.subscribe(id)
await self.pub.send(id, self.to_request(request))
await asyncio.to_thread(
self.pub.send, id, self.to_request(request)
)
while True:
try:
resp = await asyncio.wait_for(
q.get(), timeout=self.timeout
resp = await asyncio.to_thread(
q.get,
timeout=self.timeout
)
except Exception as e:
raise RuntimeError("Timeout")
@ -95,5 +99,5 @@ class ServiceRequestor:
return err
finally:
await self.sub.unsubscribe(id)
self.sub.unsubscribe(id)
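
The requestor above follows a standard correlation-ID pattern: subscribe to the response topic keyed by a fresh UUID, publish the request carrying that ID, then block on the matching reply queue. A condensed sketch, mirroring the Publisher/Subscriber calls above (subscribe(id) returning a blocking queue, send(id, msg)):

import asyncio
import queue
import uuid

async def request_response(pub, sub, request, timeout=30):
    # Correlate request and response through a shared ID property.
    id = str(uuid.uuid4())
    q = sub.subscribe(id)              # per-ID blocking reply queue
    try:
        await asyncio.to_thread(pub.send, id, request)
        # q.get blocks, so run it on a worker thread; it raises
        # queue.Empty if no reply arrives within the timeout.
        return await asyncio.to_thread(q.get, timeout=timeout)
    except queue.Empty:
        raise RuntimeError("Timeout")
    finally:
        sub.unsubscribe(id)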

View file

@ -2,6 +2,7 @@
# Like ServiceRequestor, but just fire-and-forget instead of request/response
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging
@ -20,12 +21,12 @@ class ServiceSender:
self.pub = Publisher(
pulsar_client, request_queue,
schema=request_schema,
schema=JsonSchema(request_schema),
)
async def start(self):
await self.pub.start()
self.pub.start()
def to_request(self, request):
raise RuntimeError("Not defined")
@ -34,7 +35,9 @@ class ServiceSender:
try:
await self.pub.send(None, self.to_request(request))
await asyncio.to_thread(
self.pub.send, None, self.to_request(request)
)
if responder:
await responder({}, True)

View file

@ -3,7 +3,7 @@ API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = "api-gateway"
module = ".".join(__name__.split(".")[1:-1])
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
@ -19,6 +19,7 @@ import os
import base64
import pulsar
from pulsar.schema import JsonSchema
from prometheus_client import start_http_server
from .. log_level import LogLevel

View file

@ -1,5 +1,6 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
@ -23,12 +24,12 @@ class TriplesLoadEndpoint(SocketEndpoint):
self.publisher = Publisher(
self.pulsar_client, triples_store_queue,
schema=Triples
schema=JsonSchema(Triples)
)
async def start(self):
await self.publisher.start()
self.publisher.start()
async def listener(self, ws, running):
@ -50,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
triples=to_subgraph(data["triples"]),
)
await self.publisher.send(None, elt)
self.publisher.send(None, elt)
running.stop()

View file

@ -1,6 +1,7 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import Triples
@ -23,7 +24,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
self.subscriber = Subscriber(
self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway",
schema=Triples
schema=JsonSchema(Triples)
)
async def listener(self, ws, running):
@ -38,7 +39,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
async def start(self):
await self.subscriber.start()
self.subscriber.start()
async def async_thread(self, ws, running):

View file

@ -0,0 +1,295 @@
from . clients.graph_embeddings_client import GraphEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import graph_embeddings_request_queue
from . schema import graph_embeddings_response_queue
from . schema import triples_request_queue
from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = self.rag.ge_client.request(
user=self.user, collection=self.collection,
vectors=vectors, limit=self.entity_limit,
)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=e, p=LABEL, o=None, limit=1,
)
if len(res) == 0:
self.rag.label_cache[e] = e
return e
self.rag.label_cache[e] = res[0].o.value
return self.rag.label_cache[e]
def follow_edges(self, ent, subgraph, path_length):
# Not needed?
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=ent, p=None, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.o.value, subgraph, path_length-1)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=ent, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=None, o=ent,
limit=self.triple_limit,
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.s.value, subgraph, path_length-1)
def get_subgraph(self, query):
entities = self.get_entities(query)
if self.verbose:
print("Get subgraph...", flush=True)
subgraph = set()
for ent in entities:
self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
ge_request_queue=None,
ge_response_queue=None,
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if ge_request_queue is None:
ge_request_queue = graph_embeddings_request_queue
if ge_response_queue is None:
ge_response_queue = graph_embeddings_response_queue
if tpl_request_queue is None:
tpl_request_queue = triples_request_queue
if tpl_response_queue is None:
tpl_response_queue = triples_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.ge_client = GraphEmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
subscriber=module + "-ge",
input_queue=ge_request_queue,
output_queue=ge_response_queue,
)
self.triples_client = TriplesQueryClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
subscriber=module + "-tpl",
input_queue=tpl_request_queue,
output_queue=tpl_response_queue
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.label_cache = {}
self.prompt = PromptClient(
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def query(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
kg = q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.prompt.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)
return resp
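
By way of illustration, invoking the restored synchronous client might look like the following; the import path, Pulsar host, and question are placeholders rather than confirmed values:

from trustgraph.graph_rag import GraphRag  # assumed import path

rag = GraphRag(pulsar_host="pulsar://pulsar:6650", verbose=True)

# Embeds the question, pulls nearby entities, expands a bounded subgraph,
# swaps URIs for their labels, then hands the label graph to the prompt
# service for the final answer.
answer = rag.query(
    "What does the service do?",
    user="trustgraph", collection="default",
    entity_limit=50, triple_limit=30,
)
print(answer)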

View file

@ -35,7 +35,7 @@ from .. exceptions import RequestError
from . librarian import Librarian
module = "librarian"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = librarian_request_queue
default_output_queue = librarian_response_queue

View file

@ -10,11 +10,12 @@ from .. schema import text_completion_response_queue
from .. log_level import LogLevel
from .. base import Consumer
module = "metering"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_response_queue
default_subscriber = module
class Processor(Consumer):
def __init__(self, **params):

View file

@ -27,7 +27,7 @@ from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_topics
from . prompts import to_kg_query, to_document_query, to_rows
module = "prompt"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue

View file

@ -4,6 +4,8 @@ import json
from jsonschema import validate
import re
from trustgraph.clients.llm_client import LlmClient
class PromptConfiguration:
def __init__(self, system_template, global_terms={}, prompts={}):
self.system_template = system_template
@ -19,7 +21,8 @@ class Prompt:
class PromptManager:
def __init__(self, config):
def __init__(self, llm, config):
self.llm = llm
self.config = config
self.terms = config.global_terms
@ -51,9 +54,7 @@ class PromptManager:
return json.loads(json_str)
async def invoke(self, id, input, llm):
print("Invoke...", flush=True)
def invoke(self, id, input):
if id not in self.prompts:
raise RuntimeError("ID invalid")
@ -67,7 +68,9 @@ class PromptManager:
"prompt": self.templates[id].render(terms)
}
resp = await llm(**prompt)
resp = self.llm.request(**prompt)
print(resp, flush=True)
if resp_type == "text":
return resp
@ -78,13 +81,13 @@ class PromptManager:
try:
obj = self.parse_json(resp)
except:
print("Parse fail:", resp, flush=True)
raise RuntimeError("JSON parse fail")
print(obj, flush=True)
if self.prompts[id].schema:
try:
print(self.prompts[id].schema)
validate(instance=obj, schema=self.prompts[id].schema)
print("Validated", flush=True)
except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}")
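
The manager's parse-and-validate step reduces to: extract the first JSON object from the raw completion, parse it, then check it against the prompt's declared schema. A standalone sketch of that logic (the extraction regex is illustrative; the error messages follow the code above):

import json
import re
from jsonschema import validate, ValidationError

def parse_and_validate(resp, schema):
    # Strip any prose the model wrapped around the JSON object.
    m = re.search(r"\{.*\}", resp, re.DOTALL)
    if m is None:
        raise RuntimeError("JSON parse fail")
    obj = json.loads(m.group(0))
    if schema:
        try:
            validate(instance=obj, schema=schema)
        except ValidationError as e:
            raise RuntimeError(f"Schema validation fail: {e}")
    return obj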

View file

@ -3,7 +3,6 @@
Language service abstracts prompt engineering from LLM.
"""
import asyncio
import json
import re
@ -11,59 +10,74 @@ from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... base import FlowProcessor
from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
default_ident = "prompt"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id")
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
)
# Config key for prompts
self.config_key = params.get("config_type", "prompt")
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": PromptRequest,
"output_schema": PromptResponse,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = PromptRequest,
handler = self.on_request
)
self.llm = LlmClient(
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
)
self.register_specification(
TextCompletionClientSpec(
request_name = "text-completion-request",
response_name = "text-completion-response",
)
)
# System prompt hack
class Llm:
def __init__(self, llm):
self.llm = llm
def request(self, system, prompt):
print(system)
print(prompt, flush=True)
return self.llm.request(system, prompt)
self.register_specification(
ProducerSpec(
name = "response",
schema = PromptResponse
)
)
self.register_config_handler(self.on_prompt_config)
self.llm = Llm(self.llm)
# Null configuration, should reload quickly
self.manager = PromptManager(
llm = self.llm,
config = PromptConfiguration("", {}, {})
)
async def on_prompt_config(self, config, version):
async def on_config(self, version, config):
print("Loading configuration version", version)
@ -97,6 +111,7 @@ class Processor(FlowProcessor):
)
self.manager = PromptManager(
self.llm,
PromptConfiguration(
system,
{},
@ -111,7 +126,7 @@ class Processor(FlowProcessor):
print("Exception:", e, flush=True)
print("Configuration reload failed", flush=True)
async def on_request(self, msg, consumer, flow):
async def handle(self, msg):
v = msg.value()
@ -123,7 +138,7 @@ class Processor(FlowProcessor):
try:
print(v.terms, flush=True)
print(v.terms)
input = {
k: json.loads(v)
@ -131,33 +146,14 @@ class Processor(FlowProcessor):
}
print(f"Handling kind {kind}...", flush=True)
print(input, flush=True)
async def llm(system, prompt):
print(system, flush=True)
print(prompt, flush=True)
resp = await flow("text-completion-request").text_completion(
system = system, prompt = prompt,
)
try:
return resp
except Exception as e:
print("LLM Exception:", e, flush=True)
return None
try:
resp = await self.manager.invoke(kind, input, llm)
except Exception as e:
print("Invocation exception:", e, flush=True)
raise e
print(resp, flush=True)
resp = self.manager.invoke(kind, input)
if isinstance(resp, str):
print("Send text response...", flush=True)
print(resp, flush=True)
r = PromptResponse(
text=resp,
@ -165,7 +161,7 @@ class Processor(FlowProcessor):
error=None,
)
await flow("response").send(r, properties={"id": id})
await self.send(r, properties={"id": id})
return
@ -180,13 +176,13 @@ class Processor(FlowProcessor):
error=None,
)
await flow("response").send(r, properties={"id": id})
await self.send(r, properties={"id": id})
return
except Exception as e:
print(f"Exception: {e}", flush=True)
print(f"Exception: {e}")
print("Send error response...", flush=True)
@ -198,11 +194,11 @@ class Processor(FlowProcessor):
response=None,
)
await flow("response").send(r, properties={"id": id})
await self.send(r, properties={"id": id})
except Exception as e:
print(f"Exception: {e}", flush=True)
print(f"Exception: {e}")
print("Send error response...", flush=True)
@ -219,7 +215,22 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
)
parser.add_argument(
'--config-type',
@ -229,5 +240,5 @@ class Processor(FlowProcessor):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
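
Every restored ConsumerProducer handler shares the same shape: read the message value, echo back the sender's id property on the response, and send an error-bearing response on failure. A schematic version, with stand-in names for the service-specific parts (ExampleResponse, self.process, and the base-class send/consumer members are assumptions, not the project's exact API):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Error:
    type: str
    message: str

@dataclass
class ExampleResponse:               # stand-in for PromptResponse et al.
    response: Optional[str] = None
    error: Optional[Error] = None

class ExampleProcessor:              # stand-in for a ConsumerProducer subclass

    async def handle(self, msg):
        v = msg.value()
        id = msg.properties()["id"]  # sender-produced correlation ID
        try:
            result = self.process(v)          # assumed per-service logic
            r = ExampleResponse(response=result, error=None)
        except Exception as e:
            r = ExampleResponse(
                error=Error(type="llm-error", message=str(e)),
                response=None,
            )
        await self.send(r, properties={"id": id})
        self.consumer.acknowledge(msg)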

View file

@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -16,7 +16,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -17,7 +17,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -14,7 +14,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -15,7 +15,7 @@ from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = "text-completion"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue

View file

@ -11,7 +11,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = "de-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue

View file

@ -16,7 +16,7 @@ from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = "de-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue

View file

@ -7,51 +7,71 @@ of chunks
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... base import DocumentEmbeddingsQueryService
from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
default_ident = "de-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_store_uri = 'http://localhost:6333'
class Processor(DocumentEmbeddingsQueryService):
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
# Optional API key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddingsRequest,
"output_schema": DocumentEmbeddingsResponse,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.client = QdrantClient(url=store_uri, api_key=api_key)
async def query_document_embeddings(self, msg):
async def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
chunks = []
for vec in msg.vectors:
for vec in v.vectors:
dim = len(vec)
collection = (
"d_" + msg.user + "_" + msg.collection + "_" +
"d_" + v.user + "_" + v.collection + "_" +
str(dim)
)
search_result = self.qdrant.query_points(
search_result = self.client.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
limit=v.limit,
with_payload=True,
).points
@ -59,17 +79,37 @@ class Processor(DocumentEmbeddingsQueryService):
ent = r.payload["doc"]
chunks.append(ent)
return chunks
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
raise e
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
documents=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
DocumentEmbeddingsQueryService.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-t', '--store-uri',
@ -85,5 +125,5 @@ class Processor(DocumentEmbeddingsQueryService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
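
The query side derives a Qdrant collection name from user, collection, and vector dimension, then searches it with query_points. A minimal sketch against the qdrant-client API, with the same collection naming as the code above:

from qdrant_client import QdrantClient

def query_chunks(client, user, collection, vectors, limit):
    chunks = []
    for vec in vectors:
        # One Qdrant collection per (user, collection, dimension) triple.
        name = f"d_{user}_{collection}_{len(vec)}"
        points = client.query_points(
            collection_name=name,
            query=vec,
            limit=limit,
            with_payload=True,
        ).points
        chunks.extend(p.payload["doc"] for p in points)
    return chunks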

View file

@ -11,7 +11,7 @@ from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
module = "ge-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue

View file

@ -16,7 +16,7 @@ from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
module = "ge-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue

View file

@ -7,32 +7,44 @@ entities
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
from .... base import GraphEmbeddingsQueryService
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
default_ident = "ge-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
default_subscriber = module
default_store_uri = 'http://localhost:6333'
class Processor(GraphEmbeddingsQueryService):
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
# Optional API key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddingsRequest,
"output_schema": GraphEmbeddingsResponse,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.client = QdrantClient(url=store_uri, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -40,27 +52,34 @@ class Processor(GraphEmbeddingsQueryService):
else:
return Value(value=ent, is_uri=False)
async def query_graph_embeddings(self, msg):
async def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
entity_set = set()
entities = []
for vec in msg.vectors:
for vec in v.vectors:
dim = len(vec)
collection = (
"t_" + msg.user + "_" + msg.collection + "_" +
"t_" + v.user + "_" + v.collection + "_" +
str(dim)
)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(
search_result = self.client.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
limit=v.limit * 2,
with_payload=True,
).points
@ -73,10 +92,10 @@ class Processor(GraphEmbeddingsQueryService):
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
if len(entity_set) >= v.limit: break
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
if len(entity_set) >= v.limit: break
ents2 = []
@ -86,19 +105,36 @@ class Processor(GraphEmbeddingsQueryService):
entities = ents2
print("Send response...", flush=True)
return entities
r = GraphEmbeddingsResponse(entities=entities, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
raise e
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
entities=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
GraphEmbeddingsQueryService.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-t', '--store-uri',
@ -114,5 +150,5 @@ class Processor(GraphEmbeddingsQueryService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
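
The 2 * limit over-fetch above is a hedge against duplicate entities across query vectors: fetch twice as many points as needed, then deduplicate until the limit is reached. In isolation the heuristic looks like this (the payload key is an assumption for illustration):

def dedup_entities(search_results, limit):
    # Over-fetched results may repeat entities; keep first occurrences only.
    seen = set()
    entities = []
    for r in search_results:
        ent = r.payload["entity"]    # assumed payload shape
        if ent in seen:
            continue
        seen.add(ent)
        entities.append(ent)
        if len(entities) >= limit:
            break
    return entities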

View file

@ -7,24 +7,38 @@ null. Output is a list of triples.
from .... direct.cassandra import TrustGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... base import TriplesQueryService
from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
default_ident = "triples-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
default_subscriber = module
default_graph_host='localhost'
class Processor(TriplesQueryService):
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TriplesQueryRequest,
"output_schema": TriplesQueryResponse,
"graph_host": graph_host,
"graph_username": graph_username,
"graph_password": graph_password,
}
)
@ -39,85 +53,92 @@ class Processor(TriplesQueryService):
else:
return Value(value=ent, is_uri=False)
async def query_triples(self, query):
async def handle(self, msg):
try:
table = (query.user, query.collection)
v = msg.value()
table = (v.user, v.collection)
if table != self.table:
if self.username and self.password:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=query.user, table=query.collection,
keyspace=v.user, table=v.collection,
username=self.username, password=self.password
)
else:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=query.user, table=query.collection,
keyspace=v.user, table=v.collection,
)
self.table = table
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
triples = []
if query.s is not None:
if query.p is not None:
if query.o is not None:
if v.s is not None:
if v.p is not None:
if v.o is not None:
resp = self.tg.get_spo(
query.s.value, query.p.value, query.o.value,
limit=query.limit
v.s.value, v.p.value, v.o.value,
limit=v.limit
)
triples.append((query.s.value, query.p.value, query.o.value))
triples.append((v.s.value, v.p.value, v.o.value))
else:
resp = self.tg.get_sp(
query.s.value, query.p.value,
limit=query.limit
v.s.value, v.p.value,
limit=v.limit
)
for t in resp:
triples.append((query.s.value, query.p.value, t.o))
triples.append((v.s.value, v.p.value, t.o))
else:
if query.o is not None:
if v.o is not None:
resp = self.tg.get_os(
query.o.value, query.s.value,
limit=query.limit
v.o.value, v.s.value,
limit=v.limit
)
for t in resp:
triples.append((query.s.value, t.p, query.o.value))
triples.append((v.s.value, t.p, v.o.value))
else:
resp = self.tg.get_s(
query.s.value,
limit=query.limit
v.s.value,
limit=v.limit
)
for t in resp:
triples.append((query.s.value, t.p, t.o))
triples.append((v.s.value, t.p, t.o))
else:
if query.p is not None:
if query.o is not None:
if v.p is not None:
if v.o is not None:
resp = self.tg.get_po(
query.p.value, query.o.value,
limit=query.limit
v.p.value, v.o.value,
limit=v.limit
)
for t in resp:
triples.append((t.s, query.p.value, query.o.value))
triples.append((t.s, v.p.value, v.o.value))
else:
resp = self.tg.get_p(
query.p.value,
limit=query.limit
v.p.value,
limit=v.limit
)
for t in resp:
triples.append((t.s, query.p.value, t.o))
triples.append((t.s, v.p.value, t.o))
else:
if query.o is not None:
if v.o is not None:
resp = self.tg.get_o(
query.o.value,
limit=query.limit
v.o.value,
limit=v.limit
)
for t in resp:
triples.append((t.s, t.p, query.o.value))
triples.append((t.s, t.p, v.o.value))
else:
resp = self.tg.get_all(
limit=query.limit
limit=v.limit
)
for t in resp:
triples.append((t.s, t.p, t.o))
@ -131,17 +152,37 @@ class Processor(TriplesQueryService):
for t in triples
]
return triples
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
raise e
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
TriplesQueryService.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-g', '--graph-host',
@ -164,5 +205,5 @@ class Processor(TriplesQueryService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
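
The eight-way if/else above enumerates every bound/unbound combination of (s, p, o). An equivalent, more compact dispatch keys on which terms are bound; this is a sketch over the tg store interface shown above, with the fully bound case treated as an existence check:

def query_triples(tg, s, p, o, limit):
    # Key the dispatch on which of s, p, o are bound.
    key = (s is not None, p is not None, o is not None)
    if key == (True, True, True):
        return [(s, p, o)] if tg.get_spo(s, p, o, limit=limit) else []
    if key == (True, True, False):
        return [(s, p, t.o) for t in tg.get_sp(s, p, limit=limit)]
    if key == (True, False, True):
        return [(s, t.p, o) for t in tg.get_os(o, s, limit=limit)]
    if key == (True, False, False):
        return [(s, t.p, t.o) for t in tg.get_s(s, limit=limit)]
    if key == (False, True, True):
        return [(t.s, p, o) for t in tg.get_po(p, o, limit=limit)]
    if key == (False, True, False):
        return [(t.s, p, t.o) for t in tg.get_p(p, limit=limit)]
    if key == (False, False, True):
        return [(t.s, t.p, o) for t in tg.get_o(o, limit=limit)]
    return [(t.s, t.p, t.o) for t in tg.get_all(limit=limit)]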

View file

@ -13,7 +13,7 @@ from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
module = "triples-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue

View file

@ -13,7 +13,7 @@ from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
module = "triples-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue

View file

@ -13,7 +13,7 @@ from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
module = "triples-query"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue

View file

@ -1,94 +0,0 @@
import asyncio
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
async def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = await self.rag.embeddings_client.embed(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
async def get_docs(self, query):
vectors = await self.get_vector(query)
if self.verbose:
print("Get docs...", flush=True)
docs = await self.rag.doc_embeddings_client.query(
vectors, limit=self.doc_limit,
user=self.user, collection=self.collection,
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
class DocumentRag:
def __init__(
self, prompt_client, embeddings_client, doc_embeddings_client,
verbose=False,
):
self.verbose = verbose
self.prompt_client = prompt_client
self.embeddings_client = embeddings_client
self.doc_embeddings_client = doc_embeddings_client
if self.verbose:
print("Initialised", flush=True)
async def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = await q.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = await self.prompt_client.document_prompt(
query = query,
documents = docs
)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -5,77 +5,88 @@ Input is query, output is response.
"""
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec
from ... schema import document_rag_request_queue, document_rag_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import document_embeddings_request_queue
from ... schema import document_embeddings_response_queue
from ... log_level import LogLevel
from ... document_rag import DocumentRag
from ... base import ConsumerProducer
default_ident = "document-rag"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = document_rag_request_queue
default_output_queue = document_rag_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id", default_ident)
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
de_request_queue = params.get(
"document_embeddings_request_queue",
document_embeddings_request_queue
)
de_response_queue = params.get(
"document_embeddings_response_queue",
document_embeddings_response_queue
)
doc_limit = params.get("doc_limit", 5)
doc_limit = params.get("doc_limit", 10)
super(Processor, self).__init__(
**params | {
"id": id,
"doc_limit": doc_limit,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentRagQuery,
"output_schema": DocumentRagResponse,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"document_embeddings_request_queue": de_request_queue,
"document_embeddings_response_queue": de_response_queue,
}
)
self.rag = DocumentRag(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,
emb_response_queue=emb_response_queue,
de_request_queue=de_request_queue,
de_response_queue=de_response_queue,
verbose=True,
module=module,
)
self.doc_limit = doc_limit
self.register_specification(
ConsumerSpec(
name = "request",
schema = DocumentRagQuery,
handler = self.on_request,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
DocumentEmbeddingsClientSpec(
request_name = "document-embeddings-request",
response_name = "document-embeddings-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = DocumentRagResponse,
)
)
async def on_request(self, msg, consumer, flow):
async def handle(self, msg):
try:
self.rag = DocumentRag(
embeddings_client = flow("embeddings-request"),
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
verbose=True,
)
v = msg.value()
# Sender-produced ID
@ -88,15 +99,11 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
response = await self.rag.query(v.query, doc_limit=doc_limit)
response = self.rag.query(v.query, doc_limit=doc_limit)
await flow("response").send(
DocumentRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
print("Send response...", flush=True)
r = DocumentRagResponse(response = response, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@ -106,21 +113,25 @@ class Processor(FlowProcessor):
print("Send error response...", flush=True)
await flow("response").send(
DocumentRagResponse(
response = None,
error = Error(
type = "document-rag-error",
message = str(e),
),
r = DocumentRagResponse(
error=Error(
type = "llm-error",
message = str(e),
),
properties = {"id": id}
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-d', '--doc-limit',
@ -129,7 +140,43 @@ class Processor(FlowProcessor):
help=f'Default document fetch limit (default: 10)'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
parser.add_argument(
'--document-embeddings-request-queue',
default=document_embeddings_request_queue,
help=f'Document embeddings request queue (default: {document_embeddings_request_queue})',
)
parser.add_argument(
'--document-embeddings-response-queue',
default=document_embeddings_response_queue,
help=f'Document embeddings response queue (default: {document_embeddings_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -1,218 +0,0 @@
import asyncio
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
async def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = await self.rag.embeddings_client.embed(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
async def get_entities(self, query):
vectors = await self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
)
entities = [
str(e)
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
async def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
)
if len(res) == 0:
self.rag.label_cache[e] = e
return e
self.rag.label_cache[e] = str(res[0].o)
return self.rag.label_cache[e]
async def follow_edges(self, ent, subgraph, path_length):
# Not needed?
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = await self.rag.triples_client.query(
s=ent, p=None, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(str(triple.o), subgraph, path_length-1)
res = await self.rag.triples_client.query(
s=None, p=ent, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
res = await self.rag.triples_client.query(
s=None, p=None, o=ent,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(
str(triple.s), subgraph, path_length-1
)
async def get_subgraph(self, query):
entities = await self.get_entities(query)
if self.verbose:
print("Get subgraph...", flush=True)
subgraph = set()
for ent in entities:
await self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
return subgraph
async def get_labelgraph(self, query):
subgraph = await self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = await self.maybe_label(edge[0])
p = await self.maybe_label(edge[1])
o = await self.maybe_label(edge[2])
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
triples_client, verbose=False,
):
self.verbose = verbose
self.prompt_client = prompt_client
self.embeddings_client = embeddings_client
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.label_cache = {}
if self.verbose:
print("Initialised", flush=True)
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag = self, user = user, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
kg = await q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = await self.prompt_client.kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -5,18 +5,57 @@ Input is query, output is response.
"""
from ... schema import GraphRagQuery, GraphRagResponse, Error
from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import graph_embeddings_request_queue
from ... schema import graph_embeddings_response_queue
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from ... log_level import LogLevel
from ... graph_rag import GraphRag
from ... base import ConsumerProducer
default_ident = "graph-rag"
module = ".".join(__name__.split(".")[1:-1])
class Processor(FlowProcessor):
default_input_queue = graph_rag_request_queue
default_output_queue = graph_rag_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
id = params.get("id", default_ident)
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
ge_request_queue = params.get(
"graph_embeddings_request_queue", graph_embeddings_request_queue
)
ge_response_queue = params.get(
"graph_embeddings_response_queue", graph_embeddings_response_queue
)
tpl_request_queue = params.get(
"triples_request_queue", triples_request_queue
)
tpl_response_queue = params.get(
"triples_response_queue", triples_response_queue
)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
@ -25,74 +64,49 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphRagQuery,
"output_schema": GraphRagResponse,
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"graph_embeddings_request_queue": ge_request_queue,
"graph_embeddings_response_queue": ge_response_queue,
"triples_request_queue": triples_request_queue,
"triples_response_queue": triples_response_queue,
}
)
self.rag = GraphRag(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,
emb_response_queue=emb_response_queue,
ge_request_queue=ge_request_queue,
ge_response_queue=ge_response_queue,
tpl_request_queue=tpl_request_queue,
tpl_response_queue=tpl_response_queue,
verbose=True,
module=module,
)
self.default_entity_limit = entity_limit
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.register_specification(
ConsumerSpec(
name = "request",
schema = GraphRagQuery,
handler = self.on_request,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
GraphEmbeddingsClientSpec(
request_name = "graph-embeddings-request",
response_name = "graph-embeddings-response",
)
)
self.register_specification(
TriplesClientSpec(
request_name = "triples-request",
response_name = "triples-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = GraphRagResponse,
)
)
async def on_request(self, msg, consumer, flow):
async def handle(self, msg):
try:
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
verbose=True,
)
v = msg.value()
# Sender-produced ID
@ -120,20 +134,16 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
response = await self.rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
print("Send response...", flush=True)
r = GraphRagResponse(response=response, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@ -143,21 +153,25 @@ class Processor(FlowProcessor):
print("Send error response...", flush=True)
await flow("response").send(
GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
r = GraphRagResponse(
error=Error(
type = "llm-error",
message = str(e),
),
properties = {"id": id}
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-e', '--entity-limit',
@ -187,7 +201,55 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
parser.add_argument(
'--graph-embeddings-request-queue',
default=graph_embeddings_request_queue,
help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})',
)
parser.add_argument(
'--graph-embeddings-response-queue',
default=graph_embeddings_response_queue,
help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})',
)
parser.add_argument(
'--triples-request-queue',
default=triples_request_queue,
help=f'Triples request queue (default: {triples_request_queue})',
)
parser.add_argument(
'--triples-response-queue',
default=triples_response_queue,
help=f'Triples response queue (default: {triples_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)

View file

@ -10,7 +10,7 @@ from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "de-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_store_queue
default_subscriber = module

View file

@ -16,7 +16,7 @@ from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "de-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_store_queue
default_subscriber = module

View file

@ -8,21 +8,31 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
from .... base import DocumentEmbeddingsStoreService
from .... schema import DocumentEmbeddings
from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
default_ident = "de-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_store_queue
default_subscriber = module
default_store_uri = 'http://localhost:6333'
class Processor(DocumentEmbeddingsStoreService):
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddings,
"store_uri": store_uri,
"api_key": api_key,
}
@@ -30,11 +40,13 @@ class Processor(DocumentEmbeddingsStoreService):
self.last_collection = None
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.client = QdrantClient(url=store_uri)
async def store_document_embeddings(self, message):
async def handle(self, msg):
for emb in message.chunks:
v = msg.value()
for emb in v.chunks:
chunk = emb.chunk.decode("utf-8")
if chunk == "": return
@@ -43,17 +55,16 @@ class Processor(DocumentEmbeddingsStoreService):
dim = len(vec)
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection + "_" +
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
if not self.client.collection_exists(collection):
try:
self.qdrant.create_collection(
self.client.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
@@ -65,7 +76,7 @@ class Processor(DocumentEmbeddingsStoreService):
self.last_collection = collection
self.qdrant.upsert(
self.client.upsert(
collection_name=collection,
points=[
PointStruct(
@@ -81,7 +92,9 @@ class Processor(DocumentEmbeddingsStoreService):
@staticmethod
def add_args(parser):
DocumentEmbeddingsStoreService.add_args(parser)
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-t', '--store-uri',
@@ -97,5 +110,5 @@ class Processor(DocumentEmbeddingsStoreService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
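
For readers unfamiliar with the write path being restored: the processor names a Qdrant collection per user, collection and vector dimension, creates it on first sight, and upserts one point per chunk. A self-contained sketch of that flow using the public qdrant-client API (the payload key and function framing are assumptions; the naming scheme is the diff's):

    import uuid

    from qdrant_client import QdrantClient
    from qdrant_client.models import PointStruct, Distance, VectorParams

    client = QdrantClient(url="http://localhost:6333")

    def store_chunk(user, coll, vec, chunk):
        dim = len(vec)
        # Collection naming scheme taken from the diff: d_<user>_<coll>_<dim>
        name = "d_" + user + "_" + coll + "_" + str(dim)
        if not client.collection_exists(name):
            client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )
        client.upsert(
            collection_name=name,
            points=[
                PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vec,
                    payload={"doc": chunk},   # payload key is an assumption
                )
            ],
        )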

View file

@@ -9,7 +9,7 @@ from .... log_level import LogLevel
from .... direct.milvus_graph_embeddings import EntityVectors
from .... base import Consumer
module = "ge-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module

View file

@@ -15,7 +15,7 @@ from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "ge-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module

View file

@@ -8,21 +8,31 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
from .... base import GraphEmbeddingsStoreService
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
default_ident = "ge-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_store_uri = 'http://localhost:6333'
class Processor(GraphEmbeddingsStoreService):
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
"store_uri": store_uri,
"api_key": api_key,
}
@@ -30,7 +40,7 @@ class Processor(GraphEmbeddingsStoreService):
self.last_collection = None
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.client = QdrantClient(url=store_uri, api_key=api_key)
def get_collection(self, dim, user, collection):
@@ -40,10 +50,10 @@ class Processor(GraphEmbeddingsStoreService):
if cname != self.last_collection:
if not self.qdrant.collection_exists(cname):
if not self.client.collection_exists(cname):
try:
self.qdrant.create_collection(
self.client.create_collection(
collection_name=cname,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
@@ -57,9 +67,11 @@ class Processor(GraphEmbeddingsStoreService):
return cname
async def store_graph_embeddings(self, message):
async def handle(self, msg):
for entity in message.entities:
v = msg.value()
for entity in v.entities:
if entity.entity.value == "" or entity.entity.value is None: return
@@ -68,10 +80,10 @@ class Processor(GraphEmbeddingsStoreService):
dim = len(vec)
collection = self.get_collection(
dim, message.metadata.user, message.metadata.collection
dim, v.metadata.user, v.metadata.collection
)
self.qdrant.upsert(
self.client.upsert(
collection_name=collection,
points=[
PointStruct(
@@ -87,7 +99,9 @@ class Processor(GraphEmbeddingsStoreService):
@staticmethod
def add_args(parser):
GraphEmbeddingsStoreService.add_args(parser)
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-t', '--store-uri',
@@ -103,5 +117,5 @@ class Processor(GraphEmbeddingsStoreService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
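
The same revert pattern recurs here: the typed store_graph_embeddings(message) entry point becomes the older Consumer handle(msg) shape, which decodes the Pulsar message itself. A hedged sketch, assuming the base class dispatches raw messages to handle():

    class ExampleWriter(Consumer):

        async def handle(self, msg):
            v = msg.value()   # decode the payload to the schema object
            for entity in v.entities:
                # Skip empty entities, as in the diff.
                if entity.entity.value == "" or entity.entity.value is None:
                    return
                ...           # compute the collection and upsert the vector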

View file

@@ -9,7 +9,7 @@ from .... log_level import LogLevel
from .... direct.milvus_object_embeddings import ObjectVectors
from .... base import Consumer
module = "oe-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = object_embeddings_store_queue
default_subscriber = module

View file

@@ -17,7 +17,7 @@ from .... schema import rows_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "rows-write"
module = ".".join(__name__.split(".")[1:-1])
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
default_input_queue = rows_store_queue

View file

@@ -10,26 +10,35 @@ import argparse
import time
from .... direct.cassandra import TrustGraph
from .... base import TriplesStoreService
from .... schema import Triples
from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
default_ident = "triples-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module
default_graph_host='localhost'
class Processor(TriplesStoreService):
class Processor(Consumer):
def __init__(self, **params):
id = params.get("id", default_ident)
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
graph_username = params.get("graph_username", None)
graph_password = params.get("graph_password", None)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triples,
"graph_host": graph_host,
"graph_username": graph_username
"graph_username": graph_username,
"graph_password": graph_password,
}
)
@@ -38,9 +47,11 @@ class Processor(TriplesStoreService):
self.password = graph_password
self.table = None
async def store_triples(self, message):
async def handle(self, msg):
table = (message.metadata.user, message.metadata.collection)
v = msg.value()
table = (v.metadata.user, v.metadata.collection)
if self.table is None or self.table != table:
@@ -50,15 +61,13 @@ class Processor(TriplesStoreService):
if self.username and self.password:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=message.metadata.user,
table=message.metadata.collection,
keyspace=v.metadata.user, table=v.metadata.collection,
username=self.username, password=self.password
)
else:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=message.metadata.user,
table=message.metadata.collection,
keyspace=v.metadata.user, table=v.metadata.collection,
)
except Exception as e:
print("Exception", e, flush=True)
@@ -67,7 +76,7 @@ class Processor(TriplesStoreService):
self.table = table
for t in message.triples:
for t in v.triples:
self.tg.insert(
t.s.value,
t.p.value,
@@ -77,7 +86,9 @@ class Processor(TriplesStoreService):
@staticmethod
def add_args(parser):
TriplesStoreService.add_args(parser)
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-g', '--graph-host',
@@ -99,5 +110,5 @@ class Processor(TriplesStoreService):
def run():
Processor.launch(default_ident, __doc__)
Processor.launch(module, __doc__)
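
The triples writer follows the same cache-and-reconnect pattern: it remembers the last (user, collection) pair and only reopens the Cassandra session when it changes. A sketch of the reverted handler; TrustGraph and its constructor arguments are the ones named in the diff, while t.o.value and the method framing are inferred:

    async def handle(self, msg):
        v = msg.value()
        table = (v.metadata.user, v.metadata.collection)

        if self.table is None or self.table != table:
            # Reconnect against the new keyspace/table; the diff also shows
            # an unauthenticated branch that omits username/password.
            self.tg = TrustGraph(
                hosts=self.graph_host,
                keyspace=v.metadata.user, table=v.metadata.collection,
                username=self.username, password=self.password,
            )
            self.table = table

        for t in v.triples:
            self.tg.insert(t.s.value, t.p.value, t.o.value)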

View file

@@ -16,7 +16,7 @@ from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "triples-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module

View file

@@ -16,7 +16,7 @@ from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "triples-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module

View file

@@ -16,7 +16,7 @@ from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "triples-write"
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module