mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 18:06:21 +02:00
Revert "Feature/configure flows (#345)"
This reverts commit a9197d11ee.
This commit is contained in:
parent
3adb3cf59c
commit
1822ca395f
125 changed files with 2628 additions and 3751 deletions
|
|
@ -8,11 +8,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class AgentManager:
|
||||
|
||||
def __init__(self, tools, additional_context=None):
|
||||
def __init__(self, context, tools, additional_context=None):
|
||||
self.context = context
|
||||
self.tools = tools
|
||||
self.additional_context = additional_context
|
||||
|
||||
async def reason(self, question, history, context):
|
||||
def reason(self, question, history):
|
||||
|
||||
tools = self.tools
|
||||
|
||||
|
|
@ -55,7 +56,10 @@ class AgentManager:
|
|||
|
||||
logger.info(f"prompt: {variables}")
|
||||
|
||||
obj = await context("prompt-request").agent_react(variables)
|
||||
obj = self.context.prompt.request(
|
||||
"agent-react",
|
||||
variables
|
||||
)
|
||||
|
||||
print(json.dumps(obj, indent=4), flush=True)
|
||||
|
||||
|
|
@ -81,13 +85,9 @@ class AgentManager:
|
|||
|
||||
return a
|
||||
|
||||
async def react(self, question, history, think, observe, context):
|
||||
async def react(self, question, history, think, observe):
|
||||
|
||||
act = await self.reason(
|
||||
question = question,
|
||||
history = history,
|
||||
context = context,
|
||||
)
|
||||
act = self.reason(question, history)
|
||||
logger.info(f"act: {act}")
|
||||
|
||||
if isinstance(act, Final):
|
||||
|
|
@ -104,12 +104,7 @@ class AgentManager:
|
|||
else:
|
||||
raise RuntimeError(f"No action for {act.name}!")
|
||||
|
||||
print("TOOL>>>", act)
|
||||
resp = await action.implementation(context).invoke(
|
||||
**act.arguments
|
||||
)
|
||||
|
||||
print("RSETUL", resp)
|
||||
resp = action.implementation.invoke(**act.arguments)
|
||||
|
||||
resp = resp.strip()
|
||||
|
||||
|
|
|
|||
|
|
@ -6,68 +6,103 @@ import json
|
|||
import re
|
||||
import sys
|
||||
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ... base import ConsumerProducer
|
||||
from ... schema import Error
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep
|
||||
from ... schema import agent_request_queue, agent_response_queue
|
||||
from ... schema import prompt_request_queue as pr_request_queue
|
||||
from ... schema import prompt_response_queue as pr_response_queue
|
||||
from ... schema import graph_rag_request_queue as gr_request_queue
|
||||
from ... schema import graph_rag_response_queue as gr_response_queue
|
||||
from ... clients.prompt_client import PromptClient
|
||||
from ... clients.llm_client import LlmClient
|
||||
from ... clients.graph_rag_client import GraphRagClient
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl
|
||||
from . agent_manager import AgentManager
|
||||
|
||||
from . types import Final, Action, Tool, Argument
|
||||
|
||||
default_ident = "agent-manager"
|
||||
default_max_iterations = 10
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(AgentService):
|
||||
default_input_queue = agent_request_queue
|
||||
default_output_queue = agent_response_queue
|
||||
default_subscriber = module
|
||||
default_max_iterations = 15
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
|
||||
self.max_iterations = int(
|
||||
params.get("max_iterations", default_max_iterations)
|
||||
)
|
||||
|
||||
tools = {}
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
prompt_request_queue = params.get(
|
||||
"prompt_request_queue", pr_request_queue
|
||||
)
|
||||
prompt_response_queue = params.get(
|
||||
"prompt_response_queue", pr_response_queue
|
||||
)
|
||||
graph_rag_request_queue = params.get(
|
||||
"graph_rag_request_queue", gr_request_queue
|
||||
)
|
||||
graph_rag_response_queue = params.get(
|
||||
"graph_rag_response_queue", gr_response_queue
|
||||
)
|
||||
|
||||
self.config_key = params.get("config_type", "agent")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"max_iterations": self.max_iterations,
|
||||
"config_type": self.config_key,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": AgentRequest,
|
||||
"output_schema": AgentResponse,
|
||||
"prompt_request_queue": prompt_request_queue,
|
||||
"prompt_response_queue": prompt_response_queue,
|
||||
"graph_rag_request_queue": gr_request_queue,
|
||||
"graph_rag_response_queue": gr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=prompt_request_queue,
|
||||
output_queue=prompt_response_queue,
|
||||
pulsar_host = self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
)
|
||||
|
||||
self.graph_rag = GraphRagClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=graph_rag_request_queue,
|
||||
output_queue=graph_rag_response_queue,
|
||||
pulsar_host = self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
)
|
||||
|
||||
# Need to be able to feed requests to myself
|
||||
self.recursive_input = self.client.create_producer(
|
||||
topic=input_queue,
|
||||
schema=JsonSchema(AgentRequest),
|
||||
)
|
||||
|
||||
self.agent = AgentManager(
|
||||
context=self,
|
||||
tools=[],
|
||||
additional_context="",
|
||||
)
|
||||
|
||||
self.config_handlers.append(self.on_tools_config)
|
||||
|
||||
self.register_specification(
|
||||
TextCompletionClientSpec(
|
||||
request_name = "text-completion-request",
|
||||
response_name = "text-completion-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
GraphRagClientSpec(
|
||||
request_name = "graph-rag-request",
|
||||
response_name = "graph-rag-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
async def on_config(self, version, config):
|
||||
|
||||
print("Loading configuration version", version)
|
||||
|
||||
|
|
@ -103,9 +138,9 @@ class Processor(AgentService):
|
|||
impl_id = data.get("type")
|
||||
|
||||
if impl_id == "knowledge-query":
|
||||
impl = KnowledgeQueryImpl
|
||||
impl = KnowledgeQueryImpl(self)
|
||||
elif impl_id == "text-completion":
|
||||
impl = TextCompletionImpl
|
||||
impl = TextCompletionImpl(self)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tool-kind {impl_id} not known"
|
||||
|
|
@ -120,6 +155,7 @@ class Processor(AgentService):
|
|||
)
|
||||
|
||||
self.agent = AgentManager(
|
||||
context=self,
|
||||
tools=tools,
|
||||
additional_context=additional
|
||||
)
|
||||
|
|
@ -128,14 +164,19 @@ class Processor(AgentService):
|
|||
|
||||
except Exception as e:
|
||||
|
||||
print("on_tools_config Exception:", e, flush=True)
|
||||
print("Exception:", e, flush=True)
|
||||
print("Configuration reload failed", flush=True)
|
||||
|
||||
async def agent_request(self, request, respond, next, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
if request.history:
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
if v.history:
|
||||
history = [
|
||||
Action(
|
||||
thought=h.thought,
|
||||
|
|
@ -143,12 +184,12 @@ class Processor(AgentService):
|
|||
arguments=h.arguments,
|
||||
observation=h.observation
|
||||
)
|
||||
for h in request.history
|
||||
for h in v.history
|
||||
]
|
||||
else:
|
||||
history = []
|
||||
|
||||
print(f"Question: {request.question}", flush=True)
|
||||
print(f"Question: {v.question}", flush=True)
|
||||
|
||||
if len(history) >= self.max_iterations:
|
||||
raise RuntimeError("Too many agent iterations")
|
||||
|
|
@ -166,7 +207,7 @@ class Processor(AgentService):
|
|||
observation=None,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
async def observe(x):
|
||||
|
||||
|
|
@ -179,21 +220,15 @@ class Processor(AgentService):
|
|||
observation=x,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
act = await self.agent.react(
|
||||
question = request.question,
|
||||
history = history,
|
||||
think = think,
|
||||
observe = observe,
|
||||
context = flow,
|
||||
)
|
||||
act = await self.agent.react(v.question, history, think, observe)
|
||||
|
||||
print(f"Action: {act}", flush=True)
|
||||
|
||||
if isinstance(act, Final):
|
||||
print("Send response...", flush=True)
|
||||
|
||||
print("Send final response...", flush=True)
|
||||
if type(act) == Final:
|
||||
|
||||
r = AgentResponse(
|
||||
answer=act.final,
|
||||
|
|
@ -201,20 +236,18 @@ class Processor(AgentService):
|
|||
thought=None,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
return
|
||||
|
||||
print("Send next...", flush=True)
|
||||
|
||||
history.append(act)
|
||||
|
||||
r = AgentRequest(
|
||||
question=request.question,
|
||||
plan=request.plan,
|
||||
state=request.state,
|
||||
question=v.question,
|
||||
plan=v.plan,
|
||||
state=v.state,
|
||||
history=[
|
||||
AgentStep(
|
||||
thought=h.thought,
|
||||
|
|
@ -226,7 +259,7 @@ class Processor(AgentService):
|
|||
]
|
||||
)
|
||||
|
||||
await next(r)
|
||||
self.recursive_input.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
|
|
@ -234,7 +267,7 @@ class Processor(AgentService):
|
|||
|
||||
except Exception as e:
|
||||
|
||||
print(f"agent_request Exception: {e}")
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
|
|
@ -246,12 +279,39 @@ class Processor(AgentService):
|
|||
response=None,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
AgentService.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=pr_request_queue,
|
||||
help=f'Prompt request queue (default: {pr_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=pr_response_queue,
|
||||
help=f'Prompt response queue (default: {pr_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-rag-request-queue',
|
||||
default=gr_request_queue,
|
||||
help=f'Graph RAG request queue (default: {gr_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-rag-response-queue',
|
||||
default=gr_response_queue,
|
||||
help=f'Graph RAG response queue (default: {gr_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-iterations',
|
||||
|
|
@ -267,5 +327,5 @@ class Processor(AgentService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,22 +4,16 @@
|
|||
class KnowledgeQueryImpl:
|
||||
def __init__(self, context):
|
||||
self.context = context
|
||||
async def invoke(self, **arguments):
|
||||
client = self.context("graph-rag-request")
|
||||
print("Graph RAG question...", flush=True)
|
||||
return await client.rag(
|
||||
arguments.get("question")
|
||||
)
|
||||
def invoke(self, **arguments):
|
||||
return self.context.graph_rag.request(arguments.get("question"))
|
||||
|
||||
# This tool implementation knows how to do text completion. This uses
|
||||
# the prompt service, rather than talking to TextCompletion directly.
|
||||
class TextCompletionImpl:
|
||||
def __init__(self, context):
|
||||
self.context = context
|
||||
async def invoke(self, **arguments):
|
||||
client = self.context("prompt-request")
|
||||
print("Prompt question...", flush=True)
|
||||
return await client.question(
|
||||
arguments.get("question")
|
||||
def invoke(self, **arguments):
|
||||
return self.context.prompt.request(
|
||||
"question", { "question": arguments.get("question") }
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,27 +7,40 @@ as text as separate output objects.
|
|||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... schema import TextDocument, Chunk, Metadata
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
default_ident = "chunker"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = text_ingest_queue
|
||||
default_output_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
chunk_size = params.get("chunk_size", 2000)
|
||||
chunk_overlap = params.get("chunk_overlap", 100)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "id": id }
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextDocument,
|
||||
"output_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
["id", "flow"],
|
||||
buckets=[100, 160, 250, 400, 650, 1000, 1600,
|
||||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
|
@ -39,24 +52,7 @@ class Processor(FlowProcessor):
|
|||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = TextDocument,
|
||||
handler = self.on_message,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "output",
|
||||
schema = Chunk,
|
||||
)
|
||||
)
|
||||
|
||||
print("Chunker initialised", flush=True)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.metadata.id}...", flush=True)
|
||||
|
|
@ -67,25 +63,24 @@ class Processor(FlowProcessor):
|
|||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
print("Chunk", len(chunk.page_content), flush=True)
|
||||
|
||||
r = Chunk(
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
__class__.chunk_metric.labels(
|
||||
id=consumer.id, flow=consumer.flow
|
||||
).observe(len(chunk.page_content))
|
||||
__class__.chunk_metric.observe(len(chunk.page_content))
|
||||
|
||||
await flow("output").send(r)
|
||||
await self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
|
|
@ -103,5 +98,5 @@ class Processor(FlowProcessor):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,27 +7,40 @@ as text as separate output objects.
|
|||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... base import FlowProcessor
|
||||
from ... schema import TextDocument, Chunk, Metadata
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
default_ident = "chunker"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = text_ingest_queue
|
||||
default_output_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
chunk_size = params.get("chunk_size", 250)
|
||||
chunk_overlap = params.get("chunk_overlap", 15)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "id": id }
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextDocument,
|
||||
"output_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
["id", "flow"],
|
||||
buckets=[100, 160, 250, 400, 650, 1000, 1600,
|
||||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
|
@ -38,24 +51,7 @@ class Processor(FlowProcessor):
|
|||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = TextDocument,
|
||||
handler = self.on_message,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "output",
|
||||
schema = Chunk,
|
||||
)
|
||||
)
|
||||
|
||||
print("Chunker initialised", flush=True)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.metadata.id}...", flush=True)
|
||||
|
|
@ -66,25 +62,24 @@ class Processor(FlowProcessor):
|
|||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
print("Chunk", len(chunk.page_content), flush=True)
|
||||
|
||||
r = Chunk(
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
__class__.chunk_metric.labels(
|
||||
id=consumer.id, flow=consumer.flow
|
||||
).observe(len(chunk.page_content))
|
||||
__class__.chunk_metric.observe(len(chunk.page_content))
|
||||
|
||||
await flow("output").send(r)
|
||||
await self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
|
|
@ -102,5 +97,5 @@ class Processor(FlowProcessor):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,215 +0,0 @@
|
|||
|
||||
from trustgraph.schema import ConfigResponse
|
||||
from trustgraph.schema import ConfigValue, Error
|
||||
|
||||
# This behaves just like a dict, should be easier to add persistent storage
|
||||
# later
|
||||
class ConfigurationItems(dict):
    """Plain dict holding the key -> value entries of one configuration type.

    Kept as a distinct class (rather than using ``dict`` directly) so that a
    persistent storage back-end can be substituted later.
    """
|
||||
|
||||
class Configuration(dict):
    """In-memory configuration store, keyed by type and then by key.

    The object is itself a dict mapping a configuration type to a
    ``ConfigurationItems`` mapping of key -> value.  It also carries a
    ``version`` counter, incremented on every successful ``put`` or
    ``delete``, and a ``push`` coroutine function which is awaited after
    each such mutation.
    """

    # FIXME: The state is held internally.  This only works if there's
    # one config service.  Should be more than one, and use a
    # back-end state store.

    def __init__(self, push):
        """Create an empty store.

        Args:
            push: Zero-argument coroutine function awaited after every
                mutation (``put`` / ``delete``).
        """
        super().__init__()

        # Version counter
        self.version = 0

        # External function to respond to update
        self.push = push

    def __getitem__(self, key):
        """Return the items for ``key``, creating an empty entry if absent.

        Note that plain indexing therefore never raises ``KeyError``; use
        ``key in self`` to test for existence without creating the entry.
        """
        if key not in self:
            self[key] = ConfigurationItems()
        return dict.__getitem__(self, key)

    def _has_key(self, k):
        """Return True if ``k.type`` / ``k.key`` is stored; never creates entries."""
        return k.type in self and k.key in dict.__getitem__(self, k.type)

    def _all_present(self, keys):
        """Return True if every key reference in ``keys`` is stored."""
        return all(self._has_key(k) for k in keys)

    def _remove(self, keys):
        """Delete every key reference in ``keys``, tolerating duplicates.

        ``pop`` with a default is used rather than ``del`` so that a request
        naming the same key twice does not raise part-way through and leave
        the store half-updated.
        """
        for k in keys:
            dict.__getitem__(self, k.type).pop(k.key, None)

    @staticmethod
    def _error_response(kind, message):
        """Build a ConfigResponse carrying only an Error."""
        return ConfigResponse(
            version = None,
            values = None,
            directory = None,
            config = None,
            error = Error(
                type = kind,
                message = message,
            ),
        )

    async def handle_get(self, v):
        """Return the values for each type/key pair in ``v.keys``.

        If any requested pair is missing, a "key-error" response is
        returned and no values are included.
        """
        if not self._all_present(v.keys):
            return self._error_response("key-error", "Key error")

        values = [
            ConfigValue(
                type = k.type,
                key = k.key,
                value = self[k.type][k.key],
            )
            for k in v.keys
        ]

        return ConfigResponse(
            version = self.version,
            values = values,
            directory = None,
            config = None,
            error = None,
        )

    async def handle_list(self, v):
        """Return the list of keys stored under ``v.type``.

        An unknown type yields a "key-error" response.
        """
        if v.type not in self:
            return self._error_response("key-error", "No such type")

        return ConfigResponse(
            version = self.version,
            values = None,
            directory = list(self[v.type].keys()),
            config = None,
            error = None,
        )

    async def handle_getvalues(self, v):
        """Return every key/value pair stored under ``v.type``.

        An unknown type yields a "key-error" response.
        """
        if v.type not in self:
            return self._error_response("key-error", "Key error")

        values = [
            ConfigValue(
                type = v.type,
                key = k,
                value = self[v.type][k],
            )
            for k in self[v.type]
        ]

        return ConfigResponse(
            version = self.version,
            values = values,
            directory = None,
            config = None,
            error = None,
        )

    async def handle_delete(self, v):
        """Delete each type/key pair in ``v.keys``, bump version and push.

        All pairs are validated first; if any is missing a "key-error"
        response is returned and nothing is deleted.
        """
        if not self._all_present(v.keys):
            return self._error_response("key-error", "Key error")

        self._remove(v.keys)

        self.version += 1

        await self.push()

        # NOTE(review): `value` is passed here but not in the read handlers;
        # kept as in the original -- confirm whether it is a schema field.
        return ConfigResponse(
            version = None,
            value = None,
            directory = None,
            values = None,
            config = None,
            error = None,
        )

    async def handle_put(self, v):
        """Store each entry of ``v.values``, bump version and push."""
        for k in v.values:
            # Indexing creates the per-type mapping on first use
            self[k.type][k.key] = k.value

        self.version += 1

        await self.push()

        return ConfigResponse(
            version = None,
            value = None,
            directory = None,
            values = None,
            error = None,
        )

    async def handle_config(self, v):
        """Return the whole configuration along with the current version."""
        return ConfigResponse(
            version = self.version,
            value = None,
            directory = None,
            values = None,
            config = self,
            error = None,
        )

    async def handle(self, msg):
        """Dispatch a request on ``msg.operation`` and return the response.

        Recognised operations are "get", "list", "getvalues", "delete",
        "put" and "config".  Anything else yields a "bad-operation" error
        response.
        """
        print("Handle message ", msg.operation)

        handlers = {
            "get": self.handle_get,
            "list": self.handle_list,
            "getvalues": self.handle_getvalues,
            "delete": self.handle_delete,
            "put": self.handle_put,
            "config": self.handle_config,
        }

        handler = handlers.get(msg.operation)

        if handler is None:
            return ConfigResponse(
                value=None,
                directory=None,
                values=None,
                error=Error(
                    type = "bad-operation",
                    message = "Bad operation"
                )
            )

        return await handler(msg)
|
||||
|
|
@ -1,90 +1,214 @@
|
|||
|
||||
"""
|
||||
Config service. Manages system global configuration state
|
||||
Config service. Fetchs an extract from the Wikipedia page
|
||||
using the API.
|
||||
"""
|
||||
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
|
||||
from trustgraph.schema import Error
|
||||
from trustgraph.schema import ConfigValue, Error
|
||||
from trustgraph.schema import config_request_queue, config_response_queue
|
||||
from trustgraph.schema import config_push_queue
|
||||
from trustgraph.log_level import LogLevel
|
||||
from trustgraph.base import AsyncProcessor, Consumer, Producer
|
||||
from trustgraph.base import ConsumerProducer
|
||||
|
||||
from . config import Configuration
|
||||
from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
|
||||
from ... base import Consumer, Producer
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_ident = "config-svc"
|
||||
|
||||
default_request_queue = config_request_queue
|
||||
default_response_queue = config_response_queue
|
||||
default_input_queue = config_request_queue
|
||||
default_output_queue = config_response_queue
|
||||
default_push_queue = config_push_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(AsyncProcessor):
|
||||
# This behaves just like a dict, should be easier to add persistent storage
|
||||
# later
|
||||
|
||||
class ConfigurationItems(dict):
|
||||
pass
|
||||
|
||||
class Configuration(dict):
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in self:
|
||||
self[key] = ConfigurationItems()
|
||||
return dict.__getitem__(self, key)
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
request_queue = params.get("request_queue", default_request_queue)
|
||||
response_queue = params.get("response_queue", default_response_queue)
|
||||
push_queue = params.get("push_queue", default_push_queue)
|
||||
id = params.get("id")
|
||||
|
||||
request_schema = ConfigRequest
|
||||
response_schema = ConfigResponse
|
||||
push_schema = ConfigResponse
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
push_queue = params.get("push_queue", default_push_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"request_schema": request_schema.__name__,
|
||||
"response_schema": response_schema.__name__,
|
||||
"push_schema": push_schema.__name__,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"push_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ConfigRequest,
|
||||
"output_schema": ConfigResponse,
|
||||
"push_schema": ConfigPush,
|
||||
}
|
||||
)
|
||||
|
||||
request_metrics = ConsumerMetrics(id + "-request")
|
||||
response_metrics = ProducerMetrics(id + "-response")
|
||||
push_metrics = ProducerMetrics(id + "-push")
|
||||
|
||||
self.push_pub = Producer(
|
||||
client = self.client,
|
||||
topic = push_queue,
|
||||
schema = ConfigPush,
|
||||
metrics = push_metrics,
|
||||
self.push_prod = self.client.create_producer(
|
||||
topic=push_queue,
|
||||
schema=JsonSchema(ConfigPush),
|
||||
)
|
||||
|
||||
self.response_pub = Producer(
|
||||
client = self.client,
|
||||
topic = response_queue,
|
||||
schema = ConfigResponse,
|
||||
metrics = response_metrics,
|
||||
)
|
||||
# FIXME: The state is held internally. This only works if there's
|
||||
# one config service. Should be more than one, and use a
|
||||
# back-end state store.
|
||||
self.config = Configuration()
|
||||
|
||||
self.subs = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
client = self.client,
|
||||
flow = None,
|
||||
topic = request_queue,
|
||||
subscriber = id,
|
||||
schema = request_schema,
|
||||
handler = self.on_message,
|
||||
metrics = request_metrics,
|
||||
)
|
||||
|
||||
self.config = Configuration(self.push)
|
||||
|
||||
print("Service initialised.")
|
||||
# Version counter
|
||||
self.version = 0
|
||||
|
||||
async def start(self):
|
||||
await self.push()
|
||||
|
||||
async def handle_get(self, v, id):
|
||||
|
||||
for k in v.keys:
|
||||
if k.type not in self.config or k.key not in self.config[k.type]:
|
||||
return ConfigResponse(
|
||||
version = None,
|
||||
values = None,
|
||||
directory = None,
|
||||
config = None,
|
||||
error = Error(
|
||||
type = "key-error",
|
||||
message = f"Key error"
|
||||
)
|
||||
)
|
||||
|
||||
values = [
|
||||
ConfigValue(
|
||||
type = k.type,
|
||||
key = k.key,
|
||||
value = self.config[k.type][k.key]
|
||||
)
|
||||
for k in v.keys
|
||||
]
|
||||
|
||||
return ConfigResponse(
|
||||
version = self.version,
|
||||
values = values,
|
||||
directory = None,
|
||||
config = None,
|
||||
error = None,
|
||||
)
|
||||
|
||||
async def handle_list(self, v, id):
|
||||
|
||||
if v.type not in self.config:
|
||||
|
||||
return ConfigResponse(
|
||||
version = None,
|
||||
values = None,
|
||||
directory = None,
|
||||
config = None,
|
||||
error = Error(
|
||||
type = "key-error",
|
||||
message = "No such type",
|
||||
),
|
||||
)
|
||||
|
||||
return ConfigResponse(
|
||||
version = self.version,
|
||||
values = None,
|
||||
directory = list(self.config[v.type].keys()),
|
||||
config = None,
|
||||
error = None,
|
||||
)
|
||||
|
||||
async def handle_getvalues(self, v, id):
|
||||
|
||||
if v.type not in self.config:
|
||||
|
||||
return ConfigResponse(
|
||||
version = None,
|
||||
values = None,
|
||||
directory = None,
|
||||
config = None,
|
||||
error = Error(
|
||||
type = "key-error",
|
||||
message = f"Key error"
|
||||
)
|
||||
)
|
||||
|
||||
values = [
|
||||
ConfigValue(
|
||||
type = v.type,
|
||||
key = k,
|
||||
value = self.config[v.type][k],
|
||||
)
|
||||
for k in self.config[v.type]
|
||||
]
|
||||
|
||||
return ConfigResponse(
|
||||
version = self.version,
|
||||
values = values,
|
||||
directory = None,
|
||||
config = None,
|
||||
error = None,
|
||||
)
|
||||
|
||||
async def handle_delete(self, v, id):
|
||||
|
||||
for k in v.keys:
|
||||
if k.type not in self.config or k.key not in self.config[k.type]:
|
||||
return ConfigResponse(
|
||||
version = None,
|
||||
values = None,
|
||||
directory = None,
|
||||
config = None,
|
||||
error = Error(
|
||||
type = "key-error",
|
||||
message = f"Key error"
|
||||
)
|
||||
)
|
||||
|
||||
for k in v.keys:
|
||||
del self.config[k.type][k.key]
|
||||
|
||||
self.version += 1
|
||||
|
||||
await self.push()
|
||||
await self.subs.start()
|
||||
|
||||
async def push(self):
|
||||
|
||||
resp = ConfigPush(
|
||||
version = self.config.version,
|
||||
return ConfigResponse(
|
||||
version = None,
|
||||
value = None,
|
||||
directory = None,
|
||||
values = None,
|
||||
config = None,
|
||||
error = None,
|
||||
)
|
||||
|
||||
async def handle_put(self, v, id):
|
||||
|
||||
for k in v.values:
|
||||
self.config[k.type][k.key] = k.value
|
||||
|
||||
self.version += 1
|
||||
|
||||
await self.push()
|
||||
|
||||
return ConfigResponse(
|
||||
version = None,
|
||||
value = None,
|
||||
directory = None,
|
||||
values = None,
|
||||
error = None,
|
||||
)
|
||||
|
||||
async def handle_config(self, v, id):
|
||||
|
||||
return ConfigResponse(
|
||||
version = self.version,
|
||||
value = None,
|
||||
directory = None,
|
||||
values = None,
|
||||
|
|
@ -92,27 +216,72 @@ class Processor(AsyncProcessor):
|
|||
error = None,
|
||||
)
|
||||
|
||||
await self.push_pub.send(resp)
|
||||
async def push(self):
|
||||
|
||||
print("Pushed version ", self.config.version)
|
||||
resp = ConfigPush(
|
||||
version = self.version,
|
||||
value = None,
|
||||
directory = None,
|
||||
values = None,
|
||||
config = self.config,
|
||||
error = None,
|
||||
)
|
||||
self.push_prod.send(resp)
|
||||
print("Pushed.")
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling {id}...", flush=True)
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
if v.operation == "get":
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
resp = await self.handle_get(v, id)
|
||||
|
||||
print(f"Handling {id}...", flush=True)
|
||||
elif v.operation == "list":
|
||||
|
||||
resp = await self.config.handle(v)
|
||||
resp = await self.handle_list(v, id)
|
||||
|
||||
await self.response_pub.send(resp, properties={"id": id})
|
||||
elif v.operation == "getvalues":
|
||||
|
||||
resp = await self.handle_getvalues(v, id)
|
||||
|
||||
elif v.operation == "delete":
|
||||
|
||||
resp = await self.handle_delete(v, id)
|
||||
|
||||
elif v.operation == "put":
|
||||
|
||||
resp = await self.handle_put(v, id)
|
||||
|
||||
elif v.operation == "config":
|
||||
|
||||
resp = await self.handle_config(v, id)
|
||||
|
||||
else:
|
||||
|
||||
resp = ConfigResponse(
|
||||
value=None,
|
||||
directory=None,
|
||||
values=None,
|
||||
error=Error(
|
||||
type = "bad-operation",
|
||||
message = "Bad operation"
|
||||
)
|
||||
)
|
||||
|
||||
await self.send(resp, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
||||
resp = ConfigResponse(
|
||||
error=Error(
|
||||
type = "unexpected-error",
|
||||
|
|
@ -120,33 +289,24 @@ class Processor(AsyncProcessor):
|
|||
),
|
||||
text=None,
|
||||
)
|
||||
|
||||
await self.response_pub.send(resp, properties={"id": id})
|
||||
await self.send(resp, properties={"id": id})
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
AsyncProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-q', '--request-queue',
|
||||
default=default_request_queue,
|
||||
help=f'Request queue (default: {default_request_queue})'
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--response-queue',
|
||||
default=default_response_queue,
|
||||
help=f'Response queue {default_response_queue}',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--push-queue',
|
||||
'-q', '--push-queue',
|
||||
default=default_push_queue,
|
||||
help=f'Config push queue (default: {default_push_queue})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,10 +17,12 @@ from mistralai.models import OCRResponse
|
|||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import InputOutputProcessor
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = "ocr"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_ingest_queue
|
||||
default_output_queue = text_ingest_queue
|
||||
default_subscriber = module
|
||||
default_api_key = os.getenv("MISTRAL_TOKEN")
|
||||
|
||||
|
|
@ -69,17 +71,19 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str:
|
|||
|
||||
return "\n\n".join(markdowns)
|
||||
|
||||
class Processor(InputOutputProcessor):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
api_key = params.get("api_key", default_api_key)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Document,
|
||||
"output_schema": TextDocument,
|
||||
|
|
@ -147,7 +151,7 @@ class Processor(InputOutputProcessor):
|
|||
|
||||
return markdown
|
||||
|
||||
async def on_message(self, msg, consumer):
|
||||
async def handle(self, msg):
|
||||
|
||||
print("PDF message received")
|
||||
|
||||
|
|
@ -162,14 +166,17 @@ class Processor(InputOutputProcessor):
|
|||
text=markdown.encode("utf-8"),
|
||||
)
|
||||
|
||||
await consumer.q.output.send(r)
|
||||
await self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
InputOutputProcessor.add_args(parser, default_subscriber)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
|
|
|
|||
|
|
@ -9,43 +9,39 @@ import base64
|
|||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
default_ident = "pdf-decoder"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = document_ingest_queue
|
||||
default_output_queue = text_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Document,
|
||||
"output_schema": TextDocument,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Document,
|
||||
handler = self.on_message,
|
||||
)
|
||||
)
|
||||
print("PDF inited")
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "output",
|
||||
schema = TextDocument,
|
||||
)
|
||||
)
|
||||
async def handle(self, msg):
|
||||
|
||||
print("PDF inited", flush=True)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
||||
print("PDF message received", flush=True)
|
||||
print("PDF message received")
|
||||
|
||||
v = msg.value()
|
||||
|
||||
|
|
@ -63,22 +59,24 @@ class Processor(FlowProcessor):
|
|||
|
||||
for ix, page in enumerate(pages):
|
||||
|
||||
print("page", ix, flush=True)
|
||||
|
||||
r = TextDocument(
|
||||
metadata=v.metadata,
|
||||
text=page.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
await flow("output").send(r)
|
||||
await self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
153
trustgraph-flow/trustgraph/document_rag.py
Normal file
153
trustgraph-flow/trustgraph/document_rag.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
|
||||
from . clients.document_embeddings_client import DocumentEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import document_embeddings_request_queue
|
||||
from . schema import document_embeddings_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
self, rag, user, collection, verbose,
|
||||
doc_limit=20
|
||||
):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.rag.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_docs(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
docs = self.rag.de_client.request(
|
||||
vectors, limit=self.doc_limit
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Docs:", flush=True)
|
||||
for doc in docs:
|
||||
print(doc, flush=True)
|
||||
|
||||
return docs
|
||||
|
||||
class DocumentRag:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pulsar_api_key=None,
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
de_request_queue=None,
|
||||
de_response_queue=None,
|
||||
verbose=False,
|
||||
module="test",
|
||||
):
|
||||
|
||||
self.verbose=verbose
|
||||
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
||||
if emb_response_queue is None:
|
||||
emb_response_queue = embeddings_response_queue
|
||||
|
||||
if de_request_queue is None:
|
||||
de_request_queue = document_embeddings_request_queue
|
||||
|
||||
if de_response_queue is None:
|
||||
de_response_queue = document_embeddings_response_queue
|
||||
|
||||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
self.de_client = DocumentEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-de",
|
||||
input_queue=de_request_queue,
|
||||
output_queue=de_response_queue,
|
||||
pulsar_api_key=pulsar_api_key,
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
pulsar_api_key=pulsar_api_key,
|
||||
)
|
||||
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber=module + "-de-prompt",
|
||||
pulsar_api_key=pulsar_api_key,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20,
|
||||
):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
docs = q.get_docs(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(docs)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_document_prompt(query, docs)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
|
|
@ -6,63 +6,61 @@ Output is chunk plus embedding.
|
|||
"""
|
||||
|
||||
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import chunk_ingest_queue
|
||||
from ... schema import document_embeddings_store_queue
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... clients.embeddings_client import EmbeddingsClient
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec
|
||||
from ... base import ProducerSpec
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_ident = "document-embeddings"
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = document_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
)
|
||||
emb_response_queue = params.get(
|
||||
"embeddings_response_queue", embeddings_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": DocumentEmbeddings,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Chunk,
|
||||
handler = self.on_message,
|
||||
)
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RequestResponseSpec(
|
||||
request_name = "embeddings-request",
|
||||
request_schema = EmbeddingsRequest,
|
||||
response_name = "embeddings-response",
|
||||
response_schema = EmbeddingsResponse,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "output",
|
||||
schema = DocumentEmbeddings
|
||||
)
|
||||
)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
try:
|
||||
|
||||
resp = await flow("embeddings-request").request(
|
||||
EmbeddingsRequest(
|
||||
text = v.chunk
|
||||
)
|
||||
)
|
||||
|
||||
vectors = resp.vectors
|
||||
vectors = self.embeddings.request(v.chunk)
|
||||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
|
|
@ -76,7 +74,7 @@ class Processor(FlowProcessor):
|
|||
chunks=embeds,
|
||||
)
|
||||
|
||||
await flow("output").send(r)
|
||||
await self.send(r)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
|
@ -89,9 +87,24 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-request-queue',
|
||||
default=embeddings_request_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-response-queue',
|
||||
default=embeddings_response_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,43 +1,81 @@
|
|||
|
||||
"""
|
||||
Embeddings service, applies an embeddings model using fastembed
|
||||
Embeddings service, applies an embeddings model selected from HuggingFace.
|
||||
Input is text, output is embeddings vector.
|
||||
"""
|
||||
|
||||
from ... base import EmbeddingsService
|
||||
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
from fastembed import TextEmbedding
|
||||
import os
|
||||
|
||||
default_ident = "embeddings"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = embeddings_request_queue
|
||||
default_output_queue = embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_model="sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
class Processor(EmbeddingsService):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
model = params.get("model", default_model)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "model": model }
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": EmbeddingsRequest,
|
||||
"output_schema": EmbeddingsResponse,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
print("Get model...", flush=True)
|
||||
self.embeddings = TextEmbedding(model_name = model)
|
||||
|
||||
async def on_embeddings(self, text):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
text = v.text
|
||||
vecs = self.embeddings.embed([text])
|
||||
|
||||
return [
|
||||
vecs = [
|
||||
v.tolist()
|
||||
for v in vecs
|
||||
]
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = EmbeddingsResponse(
|
||||
vectors=list(vecs),
|
||||
error=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
EmbeddingsService.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
|
|
@ -47,5 +85,5 @@ class Processor(EmbeddingsService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,48 +6,53 @@ Output is entity plus embedding.
|
|||
"""
|
||||
|
||||
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import entity_contexts_ingest_queue
|
||||
from ... schema import graph_embeddings_store_queue
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... clients.embeddings_client import EmbeddingsClient
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
|
||||
from ... base import ProducerSpec
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_ident = "graph-embeddings"
|
||||
default_input_queue = entity_contexts_ingest_queue
|
||||
default_output_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
)
|
||||
emb_response_queue = params.get(
|
||||
"embeddings_response_queue", embeddings_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": EntityContexts,
|
||||
"output_schema": GraphEmbeddings,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = EntityContexts,
|
||||
handler = self.on_message,
|
||||
)
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name = "embeddings-request",
|
||||
response_name = "embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "output",
|
||||
schema = GraphEmbeddings
|
||||
)
|
||||
)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
|
@ -58,9 +63,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
for entity in v.entities:
|
||||
|
||||
vectors = await flow("embeddings-request").embed(
|
||||
text = entity.context
|
||||
)
|
||||
vectors = self.embeddings.request(entity.context)
|
||||
|
||||
entities.append(
|
||||
EntityEmbeddings(
|
||||
|
|
@ -74,7 +77,7 @@ class Processor(FlowProcessor):
|
|||
entities=entities,
|
||||
)
|
||||
|
||||
await flow("output").send(r)
|
||||
await self.send(r)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
|
@ -87,9 +90,24 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-request-queue',
|
||||
default=embeddings_request_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-response-queue',
|
||||
default=embeddings_response_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from ... base import ConsumerProducer
|
|||
from ollama import Client
|
||||
import os
|
||||
|
||||
module = "embeddings"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = embeddings_request_queue
|
||||
default_output_queue = embeddings_response_queue
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from trustgraph.log_level import LogLevel
|
|||
from trustgraph.base import ConsumerProducer
|
||||
import requests
|
||||
|
||||
module = "wikipedia"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = encyclopedia_lookup_request_queue
|
||||
default_output_queue = encyclopedia_lookup_response_queue
|
||||
|
|
|
|||
|
|
@ -5,62 +5,84 @@ get entity definitions which are output as graph edges along with
|
|||
entity/context definitions for embedding.
|
||||
"""
|
||||
|
||||
import json
|
||||
import urllib.parse
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import EntityContext, EntityContexts
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||
from .... schema import entity_contexts_ingest_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||
|
||||
default_ident = "kg-extract-definitions"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_entity_context_queue = entity_contexts_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
ec_queue = params.get(
|
||||
"entity_context_queue",
|
||||
default_entity_context_queue
|
||||
)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": Triples,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Chunk,
|
||||
handler = self.on_message
|
||||
)
|
||||
self.ec_prod = self.client.create_producer(
|
||||
topic=ec_queue,
|
||||
schema=JsonSchema(EntityContexts),
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
)
|
||||
)
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"entity_context_queue": ec_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk.__name__,
|
||||
"output_schema": Triples.__name__,
|
||||
"vector_schema": EntityContexts.__name__,
|
||||
})
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "entity-contexts",
|
||||
schema = EntityContexts
|
||||
)
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
|
@ -71,47 +93,36 @@ class Processor(FlowProcessor):
|
|||
|
||||
return uri
|
||||
|
||||
async def emit_triples(self, pub, metadata, triples):
|
||||
def get_definitions(self, chunk):
|
||||
|
||||
return self.prompt.request_definitions(chunk)
|
||||
|
||||
async def emit_edges(self, metadata, triples):
|
||||
|
||||
t = Triples(
|
||||
metadata=metadata,
|
||||
triples=triples,
|
||||
)
|
||||
await pub.send(t)
|
||||
await self.send(t)
|
||||
|
||||
async def emit_ecs(self, pub, metadata, entities):
|
||||
async def emit_ecs(self, metadata, entities):
|
||||
|
||||
t = EntityContexts(
|
||||
metadata=metadata,
|
||||
entities=entities,
|
||||
)
|
||||
await pub.send(t)
|
||||
self.ec_prod.send(t)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
print(chunk, flush=True)
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
|
||||
defs = await flow("prompt-request").extract_definitions(
|
||||
text = chunk
|
||||
)
|
||||
|
||||
print("Response", defs, flush=True)
|
||||
|
||||
if type(defs) != list:
|
||||
raise RuntimeError("Expecting array in prompt response")
|
||||
|
||||
except Exception as e:
|
||||
print("Prompt exception:", e, flush=True)
|
||||
raise e
|
||||
defs = self.get_definitions(chunk)
|
||||
|
||||
triples = []
|
||||
entities = []
|
||||
|
|
@ -123,8 +134,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
for defn in defs:
|
||||
|
||||
s = defn["entity"]
|
||||
o = defn["definition"]
|
||||
s = defn.name
|
||||
o = defn.definition
|
||||
|
||||
if s == "": continue
|
||||
if o == "": continue
|
||||
|
|
@ -155,13 +166,13 @@ class Processor(FlowProcessor):
|
|||
|
||||
ec = EntityContext(
|
||||
entity=s_value,
|
||||
context=defn["definition"],
|
||||
context=defn.definition,
|
||||
)
|
||||
|
||||
entities.append(ec)
|
||||
|
||||
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
await self.emit_edges(
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
|
|
@ -172,7 +183,6 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
|
||||
await self.emit_ecs(
|
||||
flow("entity-contexts"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
|
|
@ -190,9 +200,30 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--entity-context-queue',
|
||||
default=default_entity_context_queue,
|
||||
help=f'Entity context queue (default: {default_entity_context_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-completion-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,54 +5,59 @@ relationship analysis to get entity relationship edges which are output as
|
|||
graph edges.
|
||||
"""
|
||||
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
from .... schema import Chunk, Triple, Triples
|
||||
from .... schema import Metadata, Value
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
|
||||
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||
|
||||
default_ident = "kg-extract-relationships"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": Triples,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Chunk,
|
||||
handler = self.on_message
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples
|
||||
)
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
|
@ -63,39 +68,28 @@ class Processor(FlowProcessor):
|
|||
|
||||
return uri
|
||||
|
||||
async def emit_triples(self, pub, metadata, triples):
|
||||
def get_relationships(self, chunk):
|
||||
|
||||
return self.prompt.request_relationships(chunk)
|
||||
|
||||
async def emit_edges(self, metadata, triples):
|
||||
|
||||
t = Triples(
|
||||
metadata=metadata,
|
||||
triples=triples,
|
||||
)
|
||||
await pub.send(t)
|
||||
await self.send(t)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
print(chunk, flush=True)
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
|
||||
rels = await flow("prompt-request").extract_relationships(
|
||||
text = chunk
|
||||
)
|
||||
|
||||
print("Response", rels, flush=True)
|
||||
|
||||
if type(rels) != list:
|
||||
raise RuntimeError("Expecting array in prompt response")
|
||||
|
||||
except Exception as e:
|
||||
print("Prompt exception:", e, flush=True)
|
||||
raise e
|
||||
rels = self.get_relationships(chunk)
|
||||
|
||||
triples = []
|
||||
|
||||
|
|
@ -106,9 +100,9 @@ class Processor(FlowProcessor):
|
|||
|
||||
for rel in rels:
|
||||
|
||||
s = rel["subject"]
|
||||
p = rel["predicate"]
|
||||
o = rel["object"]
|
||||
s = rel.s
|
||||
p = rel.p
|
||||
o = rel.o
|
||||
|
||||
if s == "": continue
|
||||
if p == "": continue
|
||||
|
|
@ -124,7 +118,7 @@ class Processor(FlowProcessor):
|
|||
p_uri = self.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
|
||||
if rel["object-entity"]:
|
||||
if rel.o_entity:
|
||||
o_uri = self.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
else:
|
||||
|
|
@ -150,7 +144,7 @@ class Processor(FlowProcessor):
|
|||
o=Value(value=str(p), is_uri=False)
|
||||
))
|
||||
|
||||
if rel["object-entity"]:
|
||||
if rel.o_entity:
|
||||
# Label for o
|
||||
triples.append(Triple(
|
||||
s=o_value,
|
||||
|
|
@ -165,7 +159,7 @@ class Processor(FlowProcessor):
|
|||
o=Value(value=v.metadata.id, is_uri=True)
|
||||
))
|
||||
|
||||
if rel["object-entity"]:
|
||||
if rel.o_entity:
|
||||
# 'Subject of' for o
|
||||
triples.append(Triple(
|
||||
s=o_value,
|
||||
|
|
@ -174,7 +168,6 @@ class Processor(FlowProcessor):
|
|||
))
|
||||
|
||||
await self.emit_edges(
|
||||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
|
|
@ -192,9 +185,24 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from .... base import ConsumerProducer
|
|||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
|
||||
module = "kg-extract-topics"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
|
|
|
|||
|
|
@ -39,3 +39,4 @@ class AgentRequestor(ServiceRequestor):
|
|||
# The 2nd boolean expression indicates whether we're done responding
|
||||
return resp, (message.answer is not None)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
|
|
@ -25,12 +26,12 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_client, document_embeddings_store_queue,
|
||||
schema=DocumentEmbeddings
|
||||
schema=JsonSchema(DocumentEmbeddings)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.publisher.start()
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
|
|
@ -58,6 +59,6 @@ class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
],
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
running.stop()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
from .. schema import DocumentEmbeddings
|
||||
|
|
@ -26,7 +27,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
self.subscriber = Subscriber(
|
||||
self.pulsar_client, document_embeddings_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=DocumentEmbeddings,
|
||||
schema=JsonSchema(DocumentEmbeddings),
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
|
@ -43,17 +44,17 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
|
||||
async def start(self):
|
||||
|
||||
await self.subscriber.start()
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = await self.subscriber.subscribe_all(id)
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_document_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
|
|
@ -66,7 +67,7 @@ class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await self.subscriber.unsubscribe_all(id)
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
from aiohttp import web
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from .. base import Publisher
|
||||
from .. base import Subscriber
|
||||
|
||||
logger = logging.getLogger("endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
|
|
@ -25,12 +26,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_client, graph_embeddings_store_queue,
|
||||
schema=GraphEmbeddings
|
||||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.publisher.start()
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
|
|
@ -59,6 +60,6 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
]
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
running.stop()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
from .. schema import GraphEmbeddings
|
||||
|
|
@ -25,7 +26,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
self.subscriber = Subscriber(
|
||||
self.pulsar_client, graph_embeddings_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=GraphEmbeddings
|
||||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
|
@ -40,17 +41,17 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
|
||||
async def start(self):
|
||||
|
||||
await self.subscriber.start()
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = await self.subscriber.subscribe_all(id)
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get, timeout=0.5)
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
|
|
@ -63,7 +64,7 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await self.subscriber.unsubscribe_all(id)
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import aiohttp
|
||||
from aiohttp import web
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import web, WSMsgType
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
|
|
@ -22,21 +23,21 @@ class ServiceRequestor:
|
|||
|
||||
self.pub = Publisher(
|
||||
pulsar_client, request_queue,
|
||||
schema=request_schema,
|
||||
schema=JsonSchema(request_schema),
|
||||
)
|
||||
|
||||
self.sub = Subscriber(
|
||||
pulsar_client, response_queue,
|
||||
subscription, consumer_name,
|
||||
response_schema
|
||||
JsonSchema(response_schema)
|
||||
)
|
||||
|
||||
self.timeout = timeout
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.pub.start()
|
||||
await self.sub.start()
|
||||
self.pub.start()
|
||||
self.sub.start()
|
||||
|
||||
def to_request(self, request):
|
||||
raise RuntimeError("Not defined")
|
||||
|
|
@ -50,15 +51,18 @@ class ServiceRequestor:
|
|||
|
||||
try:
|
||||
|
||||
q = await self.sub.subscribe(id)
|
||||
q = self.sub.subscribe(id)
|
||||
|
||||
await self.pub.send(id, self.to_request(request))
|
||||
await asyncio.to_thread(
|
||||
self.pub.send, id, self.to_request(request)
|
||||
)
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
q.get(), timeout=self.timeout
|
||||
resp = await asyncio.to_thread(
|
||||
q.get,
|
||||
timeout=self.timeout
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout")
|
||||
|
|
@ -95,5 +99,5 @@ class ServiceRequestor:
|
|||
return err
|
||||
|
||||
finally:
|
||||
await self.sub.unsubscribe(id)
|
||||
self.sub.unsubscribe(id)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# Like ServiceRequestor, but just fire-and-forget instead of request/response
|
||||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
|
|
@ -20,12 +21,12 @@ class ServiceSender:
|
|||
|
||||
self.pub = Publisher(
|
||||
pulsar_client, request_queue,
|
||||
schema=request_schema,
|
||||
schema=JsonSchema(request_schema),
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.pub.start()
|
||||
self.pub.start()
|
||||
|
||||
def to_request(self, request):
|
||||
raise RuntimeError("Not defined")
|
||||
|
|
@ -34,7 +35,9 @@ class ServiceSender:
|
|||
|
||||
try:
|
||||
|
||||
await self.pub.send(None, self.to_request(request))
|
||||
await asyncio.to_thread(
|
||||
self.pub.send, None, self.to_request(request)
|
||||
)
|
||||
|
||||
if responder:
|
||||
await responder({}, True)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ API gateway. Offers HTTP services which are translated to interaction on the
|
|||
Pulsar bus.
|
||||
"""
|
||||
|
||||
module = "api-gateway"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
||||
# are active listeners
|
||||
|
|
@ -19,6 +19,7 @@ import os
|
|||
import base64
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from .. log_level import LogLevel
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
|
|
@ -23,12 +24,12 @@ class TriplesLoadEndpoint(SocketEndpoint):
|
|||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_client, triples_store_queue,
|
||||
schema=Triples
|
||||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.publisher.start()
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
|
|
@ -50,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
|
|||
triples=to_subgraph(data["triples"]),
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
|
||||
running.stop()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
from .. schema import Triples
|
||||
|
|
@ -23,7 +24,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
self.subscriber = Subscriber(
|
||||
self.pulsar_client, triples_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=Triples
|
||||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
|
@ -38,7 +39,7 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
|
||||
async def start(self):
|
||||
|
||||
await self.subscriber.start()
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
|
|
|
|||
295
trustgraph-flow/trustgraph/graph_rag.py
Normal file
295
trustgraph-flow/trustgraph/graph_rag.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
|
||||
from . clients.graph_embeddings_client import GraphEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import graph_embeddings_request_queue
|
||||
from . schema import graph_embeddings_response_queue
|
||||
from . schema import triples_request_queue
|
||||
from . schema import triples_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class Query:
    """A single graph-RAG retrieval pass for one user and collection.

    The object holds the per-query limits and uses the clients exposed by
    ``rag`` (``rag.embeddings``, ``rag.ge_client``, ``rag.triples_client``)
    together with the shared ``rag.label_cache`` dict to turn a text query
    into a list of ``(s, p, o)`` tuples whose members have been replaced by
    their labels where a label is stored.
    """

    def __init__(
            self, rag, user, collection, verbose,
            entity_limit=50, triple_limit=30, max_subgraph_size=1000,
            max_path_length=2,
    ):
        """Record the owning ``rag`` object, the scope and the limits.

        Args:
            rag: object providing ``embeddings``, ``ge_client``,
                ``triples_client`` and ``label_cache``.
            user: passed through to every graph-embeddings/triples request.
            collection: passed through to every request.
            verbose: when true, progress is printed to stdout.
            entity_limit: ``limit`` for the graph-embeddings request.
            triple_limit: ``limit`` for each edge-following triples request.
            max_subgraph_size: stop expanding once the subgraph holds this
                many edges; also caps the length of the labelled result.
            max_path_length: starting depth handed to ``follow_edges``.
        """
        self.rag = rag
        self.user = user
        self.collection = collection
        self.verbose = verbose
        self.entity_limit = entity_limit
        self.triple_limit = triple_limit
        self.max_subgraph_size = max_subgraph_size
        self.max_path_length = max_path_length

    def get_vector(self, query):
        """Return whatever ``rag.embeddings.request(query)`` produces."""

        if self.verbose:
            print("Compute embeddings...", flush=True)

        qembeds = self.rag.embeddings.request(query)

        if self.verbose:
            print("Done.", flush=True)

        return qembeds

    def get_entities(self, query):
        """Return the ``.value`` of each entity matched for ``query``.

        The query is embedded with ``get_vector`` and the result is sent to
        ``rag.ge_client`` with ``entity_limit`` as the limit.
        """

        vectors = self.get_vector(query)

        if self.verbose:
            print("Get entities...", flush=True)

        entities = self.rag.ge_client.request(
            user=self.user, collection=self.collection,
            vectors=vectors, limit=self.entity_limit,
        )

        entities = [e.value for e in entities]

        if self.verbose:
            print("Entities:", flush=True)
            for ent in entities:
                print("  ", ent, flush=True)

        return entities

    def maybe_label(self, e):
        """Return the stored label for ``e``, or ``e`` itself if none exists.

        Results (including the "no label" fallback) are memoised in
        ``rag.label_cache`` so each distinct value is looked up only once.
        """

        if e in self.rag.label_cache:
            return self.rag.label_cache[e]

        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=e, p=LABEL, o=None, limit=1,
        )

        if len(res) == 0:
            self.rag.label_cache[e] = e
            return e

        self.rag.label_cache[e] = res[0].o.value
        return self.rag.label_cache[e]

    def follow_edges(self, ent, subgraph, path_length):
        """Add the edges touching ``ent`` to ``subgraph`` (a set), in place.

        Three triple queries are issued, with ``ent`` as subject, as
        predicate and as object. For subject and object matches the walk
        recurses into the other end of the edge while ``path_length`` is
        greater than one. Nothing is returned.
        """

        # Not needed?
        if path_length <= 0:
            return

        # Stop spanning around if the subgraph is already maxed out
        if len(subgraph) >= self.max_subgraph_size:
            return

        # ent as subject: recurse into the object end
        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=ent, p=None, o=None,
            limit=self.triple_limit
        )

        for triple in res:
            subgraph.add(
                (triple.s.value, triple.p.value, triple.o.value)
            )
            if path_length > 1:
                self.follow_edges(triple.o.value, subgraph, path_length-1)

        # ent as predicate: collected but not followed further
        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=None, p=ent, o=None,
            limit=self.triple_limit
        )

        for triple in res:
            subgraph.add(
                (triple.s.value, triple.p.value, triple.o.value)
            )

        # ent as object: recurse into the subject end
        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=None, p=None, o=ent,
            limit=self.triple_limit,
        )

        for triple in res:
            subgraph.add(
                (triple.s.value, triple.p.value, triple.o.value)
            )
            if path_length > 1:
                self.follow_edges(triple.s.value, subgraph, path_length-1)

    def get_subgraph(self, query):
        """Return a list of ``(s, p, o)`` tuples around the matched entities.

        Edges are accumulated in a set, so duplicates are dropped and the
        order of the returned list is not defined.
        """

        entities = self.get_entities(query)

        if self.verbose:
            print("Get subgraph...", flush=True)

        subgraph = set()

        for ent in entities:
            self.follow_edges(ent, subgraph, self.max_path_length)

        subgraph = list(subgraph)

        return subgraph

    def get_labelgraph(self, query):
        """Return the subgraph for ``query`` with members swapped for labels.

        Edges whose predicate is the label predicate are dropped, and at most
        ``max_subgraph_size`` edges are returned.
        """

        subgraph = self.get_subgraph(query)

        limit = self.max_subgraph_size
        sg2 = []

        for edge in subgraph:

            # Once the cap is reached any further edge would be sliced off
            # below, so stop before spending label lookups on it. Negative
            # limits are left to the slice so their behaviour is unchanged.
            if 0 <= limit <= len(sg2):
                break

            if edge[1] == LABEL:
                continue

            s = self.maybe_label(edge[0])
            p = self.maybe_label(edge[1])
            o = self.maybe_label(edge[2])

            sg2.append((s, p, o))

        sg2 = sg2[0:limit]

        if self.verbose:
            print("Subgraph:", flush=True)
            for edge in sg2:
                print("  ", str(edge), flush=True)

        if self.verbose:
            print("Done.", flush=True)

        return sg2
|
||||
|
||||
class GraphRag:
    """Graph-RAG front end: owns the service clients and answers queries.

    Construction creates one client each for graph embeddings, triple
    queries, embeddings and prompts, plus an empty ``label_cache`` dict that
    ``Query`` objects share.
    """

    def __init__(
            self,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
            pr_request_queue=None,
            pr_response_queue=None,
            emb_request_queue=None,
            emb_response_queue=None,
            ge_request_queue=None,
            ge_response_queue=None,
            tpl_request_queue=None,
            tpl_response_queue=None,
            verbose=False,
            module="test",
    ):
        """Create the clients.

        Any queue argument left as ``None`` falls back to the matching
        queue name imported from the schema module. ``module`` is the prefix
        for each client's subscriber name (``-ge``, ``-tpl``, ``-emb``,
        ``-prompt`` are appended).
        """

        self.verbose = verbose

        # Fill in schema defaults for every queue the caller did not name.
        pr_request_queue = (
            prompt_request_queue
            if pr_request_queue is None else pr_request_queue
        )
        pr_response_queue = (
            prompt_response_queue
            if pr_response_queue is None else pr_response_queue
        )
        emb_request_queue = (
            embeddings_request_queue
            if emb_request_queue is None else emb_request_queue
        )
        emb_response_queue = (
            embeddings_response_queue
            if emb_response_queue is None else emb_response_queue
        )
        ge_request_queue = (
            graph_embeddings_request_queue
            if ge_request_queue is None else ge_request_queue
        )
        ge_response_queue = (
            graph_embeddings_response_queue
            if ge_response_queue is None else ge_response_queue
        )
        tpl_request_queue = (
            triples_request_queue
            if tpl_request_queue is None else tpl_request_queue
        )
        tpl_response_queue = (
            triples_response_queue
            if tpl_response_queue is None else tpl_response_queue
        )

        if self.verbose:
            print("Initialising...", flush=True)

        # Connection settings common to all four clients.
        connection = {
            "pulsar_host": pulsar_host,
            "pulsar_api_key": pulsar_api_key,
        }

        self.ge_client = GraphEmbeddingsClient(
            subscriber=module + "-ge",
            input_queue=ge_request_queue,
            output_queue=ge_response_queue,
            **connection,
        )

        self.triples_client = TriplesQueryClient(
            subscriber=module + "-tpl",
            input_queue=tpl_request_queue,
            output_queue=tpl_response_queue,
            **connection,
        )

        self.embeddings = EmbeddingsClient(
            input_queue=emb_request_queue,
            output_queue=emb_response_queue,
            subscriber=module + "-emb",
            **connection,
        )

        # Maps an entity/predicate value to its label; filled by Query.
        self.label_cache = {}

        self.prompt = PromptClient(
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber=module + "-prompt",
            **connection,
        )

        if self.verbose:
            print("Initialised", flush=True)

    def query(
            self, query, user="trustgraph", collection="default",
            entity_limit=50, triple_limit=30, max_subgraph_size=1000,
            max_path_length=2,
    ):
        """Answer ``query`` from the knowledge graph.

        Builds a ``Query`` with the given limits, fetches its labelled
        subgraph and returns the result of
        ``self.prompt.request_kg_prompt(query, kg)``.
        """

        chatty = self.verbose

        if chatty:
            print("Construct prompt...", flush=True)

        retrieval = Query(
            rag=self, user=user, collection=collection, verbose=chatty,
            entity_limit=entity_limit, triple_limit=triple_limit,
            max_subgraph_size=max_subgraph_size,
            max_path_length=max_path_length,
        )

        kg = retrieval.get_labelgraph(query)

        if chatty:
            print("Invoke LLM...", flush=True)
            print(kg)
            print(query)

        answer = self.prompt.request_kg_prompt(query, kg)

        if chatty:
            print("Done", flush=True)

        return answer
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ from .. exceptions import RequestError
|
|||
|
||||
from . librarian import Librarian
|
||||
|
||||
module = "librarian"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = librarian_request_queue
|
||||
default_output_queue = librarian_response_queue
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ from .. schema import text_completion_response_queue
|
|||
from .. log_level import LogLevel
|
||||
from .. base import Consumer
|
||||
|
||||
module = "metering"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from .... clients.llm_client import LlmClient
|
|||
from . prompts import to_definitions, to_relationships, to_topics
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = "prompt"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import json
|
|||
from jsonschema import validate
|
||||
import re
|
||||
|
||||
from trustgraph.clients.llm_client import LlmClient
|
||||
|
||||
class PromptConfiguration:
|
||||
def __init__(self, system_template, global_terms={}, prompts={}):
|
||||
self.system_template = system_template
|
||||
|
|
@ -19,7 +21,8 @@ class Prompt:
|
|||
|
||||
class PromptManager:
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, llm, config):
|
||||
self.llm = llm
|
||||
self.config = config
|
||||
self.terms = config.global_terms
|
||||
|
||||
|
|
@ -51,9 +54,7 @@ class PromptManager:
|
|||
|
||||
return json.loads(json_str)
|
||||
|
||||
async def invoke(self, id, input, llm):
|
||||
|
||||
print("Invoke...", flush=True)
|
||||
def invoke(self, id, input):
|
||||
|
||||
if id not in self.prompts:
|
||||
raise RuntimeError("ID invalid")
|
||||
|
|
@ -67,7 +68,9 @@ class PromptManager:
|
|||
"prompt": self.templates[id].render(terms)
|
||||
}
|
||||
|
||||
resp = await llm(**prompt)
|
||||
resp = self.llm.request(**prompt)
|
||||
|
||||
print(resp, flush=True)
|
||||
|
||||
if resp_type == "text":
|
||||
return resp
|
||||
|
|
@ -78,13 +81,13 @@ class PromptManager:
|
|||
try:
|
||||
obj = self.parse_json(resp)
|
||||
except:
|
||||
print("Parse fail:", resp, flush=True)
|
||||
raise RuntimeError("JSON parse fail")
|
||||
|
||||
print(obj, flush=True)
|
||||
if self.prompts[id].schema:
|
||||
try:
|
||||
print(self.prompts[id].schema)
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
print("Validated", flush=True)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Schema validation fail: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
|
|
@ -11,59 +10,74 @@ from .... schema import Definition, Relationship, Triple
|
|||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
from .... base import FlowProcessor
|
||||
from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
|
||||
|
||||
default_ident = "prompt"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
|
||||
# Config key for prompts
|
||||
self.config_key = params.get("config_type", "prompt")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = PromptRequest,
|
||||
handler = self.on_request
|
||||
)
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
TextCompletionClientSpec(
|
||||
request_name = "text-completion-request",
|
||||
response_name = "text-completion-response",
|
||||
)
|
||||
)
|
||||
# System prompt hack
|
||||
class Llm:
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
def request(self, system, prompt):
|
||||
print(system)
|
||||
print(prompt, flush=True)
|
||||
return self.llm.request(system, prompt)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = PromptResponse
|
||||
)
|
||||
)
|
||||
|
||||
self.register_config_handler(self.on_prompt_config)
|
||||
self.llm = Llm(self.llm)
|
||||
|
||||
# Null configuration, should reload quickly
|
||||
self.manager = PromptManager(
|
||||
llm = self.llm,
|
||||
config = PromptConfiguration("", {}, {})
|
||||
)
|
||||
|
||||
async def on_prompt_config(self, config, version):
|
||||
async def on_config(self, version, config):
|
||||
|
||||
print("Loading configuration version", version)
|
||||
|
||||
|
|
@ -97,6 +111,7 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
|
||||
self.manager = PromptManager(
|
||||
self.llm,
|
||||
PromptConfiguration(
|
||||
system,
|
||||
{},
|
||||
|
|
@ -111,7 +126,7 @@ class Processor(FlowProcessor):
|
|||
print("Exception:", e, flush=True)
|
||||
print("Configuration reload failed", flush=True)
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
|
|
@ -123,7 +138,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
print(v.terms, flush=True)
|
||||
print(v.terms)
|
||||
|
||||
input = {
|
||||
k: json.loads(v)
|
||||
|
|
@ -131,33 +146,14 @@ class Processor(FlowProcessor):
|
|||
}
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
print(input, flush=True)
|
||||
|
||||
async def llm(system, prompt):
|
||||
|
||||
print(system, flush=True)
|
||||
print(prompt, flush=True)
|
||||
|
||||
resp = await flow("text-completion-request").text_completion(
|
||||
system = system, prompt = prompt,
|
||||
)
|
||||
|
||||
try:
|
||||
return resp
|
||||
except Exception as e:
|
||||
print("LLM Exception:", e, flush=True)
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self.manager.invoke(kind, input, llm)
|
||||
except Exception as e:
|
||||
print("Invocation exception:", e, flush=True)
|
||||
raise e
|
||||
|
||||
print(resp, flush=True)
|
||||
resp = self.manager.invoke(kind, input)
|
||||
|
||||
if isinstance(resp, str):
|
||||
|
||||
print("Send text response...", flush=True)
|
||||
print(resp, flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
text=resp,
|
||||
|
|
@ -165,7 +161,7 @@ class Processor(FlowProcessor):
|
|||
error=None,
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
return
|
||||
|
||||
|
|
@ -180,13 +176,13 @@ class Processor(FlowProcessor):
|
|||
error=None,
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}", flush=True)
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
|
|
@ -198,11 +194,11 @@ class Processor(FlowProcessor):
|
|||
response=None,
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}", flush=True)
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
|
|
@ -219,7 +215,22 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
|
|
@ -229,5 +240,5 @@ class Processor(FlowProcessor):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... log_level import LogLevel
|
|||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = "text-completion"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from .... schema import document_embeddings_request_queue
|
|||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "de-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... schema import document_embeddings_request_queue
|
|||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "de-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
|
|
|
|||
|
|
@ -7,51 +7,71 @@ of chunks
|
|||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
import uuid
|
||||
|
||||
from .... schema import DocumentEmbeddingsResponse
|
||||
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
from .... schema import document_embeddings_request_queue
|
||||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
default_ident = "de-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
class Processor(DocumentEmbeddingsQueryService):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
#optional api key
|
||||
api_key = params.get("api_key", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddingsRequest,
|
||||
"output_schema": DocumentEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
"api_key": api_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
self.client = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
async def query_document_embeddings(self, msg):
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + msg.user + "_" + msg.collection + "_" +
|
||||
"d_" + v.user + "_" + v.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
search_result = self.qdrant.query_points(
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=msg.limit,
|
||||
limit=v.limit,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
|
|
@ -59,17 +79,37 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
ent = r.payload["doc"]
|
||||
chunks.append(ent)
|
||||
|
||||
return chunks
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
raise e
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = DocumentEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
documents=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
DocumentEmbeddingsQueryService.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -85,5 +125,5 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from .... schema import graph_embeddings_request_queue
|
|||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "ge-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... schema import graph_embeddings_request_queue
|
|||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "ge-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
|
|
|
|||
|
|
@ -7,32 +7,44 @@ entities
|
|||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
import uuid
|
||||
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
from .... schema import graph_embeddings_request_queue
|
||||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
default_ident = "ge-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
class Processor(GraphEmbeddingsQueryService):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
#optional api key
|
||||
api_key = params.get("api_key", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddingsRequest,
|
||||
"output_schema": GraphEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
"api_key": api_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
self.client = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
|
|
@ -40,27 +52,34 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def query_graph_embeddings(self, msg):
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"t_" + msg.user + "_" + msg.collection + "_" +
|
||||
"t_" + v.user + "_" + v.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
search_result = self.qdrant.query_points(
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=msg.limit * 2,
|
||||
limit=v.limit * 2,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
|
|
@ -73,10 +92,10 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
entities.append(ent)
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
if len(entity_set) >= v.limit: break
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
if len(entity_set) >= v.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
|
|
@ -86,19 +105,36 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
entities = ents2
|
||||
|
||||
print("Send response...", flush=True)
|
||||
return entities
|
||||
r = GraphEmbeddingsResponse(entities=entities, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
raise e
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = GraphEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
entities=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
GraphEmbeddingsQueryService.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -114,5 +150,5 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,24 +7,38 @@ null. Output is a list of triples.
|
|||
from .... direct.cassandra import TrustGraph
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... base import TriplesQueryService
|
||||
from .... schema import triples_request_queue
|
||||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
default_ident = "triples-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(TriplesQueryService):
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
graph_username = params.get("graph_username", None)
|
||||
graph_password = params.get("graph_password", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TriplesQueryRequest,
|
||||
"output_schema": TriplesQueryResponse,
|
||||
"graph_host": graph_host,
|
||||
"graph_username": graph_username,
|
||||
"graph_password": graph_password,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -39,85 +53,92 @@ class Processor(TriplesQueryService):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def query_triples(self, query):
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
table = (query.user, query.collection)
|
||||
v = msg.value()
|
||||
|
||||
table = (v.user, v.collection)
|
||||
|
||||
if table != self.table:
|
||||
if self.username and self.password:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=query.user, table=query.collection,
|
||||
keyspace=v.user, table=v.collection,
|
||||
username=self.username, password=self.password
|
||||
)
|
||||
else:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=query.user, table=query.collection,
|
||||
keyspace=v.user, table=v.collection,
|
||||
)
|
||||
self.table = table
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
triples = []
|
||||
|
||||
if query.s is not None:
|
||||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
if v.s is not None:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_spo(
|
||||
query.s.value, query.p.value, query.o.value,
|
||||
limit=query.limit
|
||||
v.s.value, v.p.value, v.o.value,
|
||||
limit=v.limit
|
||||
)
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_sp(
|
||||
query.s.value, query.p.value,
|
||||
limit=query.limit
|
||||
v.s.value, v.p.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((query.s.value, query.p.value, t.o))
|
||||
triples.append((v.s.value, v.p.value, t.o))
|
||||
else:
|
||||
if query.o is not None:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_os(
|
||||
query.o.value, query.s.value,
|
||||
limit=query.limit
|
||||
v.o.value, v.s.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((query.s.value, t.p, query.o.value))
|
||||
triples.append((v.s.value, t.p, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_s(
|
||||
query.s.value,
|
||||
limit=query.limit
|
||||
v.s.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((query.s.value, t.p, t.o))
|
||||
triples.append((v.s.value, t.p, t.o))
|
||||
else:
|
||||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_po(
|
||||
query.p.value, query.o.value,
|
||||
limit=query.limit
|
||||
v.p.value, v.o.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, query.p.value, query.o.value))
|
||||
triples.append((t.s, v.p.value, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_p(
|
||||
query.p.value,
|
||||
limit=query.limit
|
||||
v.p.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, query.p.value, t.o))
|
||||
triples.append((t.s, v.p.value, t.o))
|
||||
else:
|
||||
if query.o is not None:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_o(
|
||||
query.o.value,
|
||||
limit=query.limit
|
||||
v.o.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, t.p, query.o.value))
|
||||
triples.append((t.s, t.p, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_all(
|
||||
limit=query.limit
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, t.p, t.o))
|
||||
|
|
@ -131,17 +152,37 @@ class Processor(TriplesQueryService):
|
|||
for t in triples
|
||||
]
|
||||
|
||||
return triples
|
||||
print("Send response...", flush=True)
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
raise e
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TriplesQueryResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
TriplesQueryService.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
|
|
@ -164,5 +205,5 @@ class Processor(TriplesQueryService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .... schema import triples_request_queue
|
|||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "triples-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .... schema import triples_request_queue
|
|||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "triples-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .... schema import triples_request_queue
|
|||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = "triples-query"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
self, rag, user, collection, verbose,
|
||||
doc_limit=20
|
||||
):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
async def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
async def get_docs(self, query):
|
||||
|
||||
vectors = await self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get docs...", flush=True)
|
||||
|
||||
docs = await self.rag.doc_embeddings_client.query(
|
||||
vectors, limit=self.doc_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Docs:", flush=True)
|
||||
for doc in docs:
|
||||
print(doc, flush=True)
|
||||
|
||||
return docs
|
||||
|
||||
class DocumentRag:
|
||||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self.prompt_client = prompt_client
|
||||
self.embeddings_client = embeddings_client
|
||||
self.doc_embeddings_client = doc_embeddings_client
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
async def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20,
|
||||
):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
docs = await q.get_docs(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(docs)
|
||||
print(query)
|
||||
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query = query,
|
||||
documents = docs
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
|
|
@ -5,77 +5,88 @@ Input is query, output is response.
|
|||
"""
|
||||
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||
from . document_rag import DocumentRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... schema import document_rag_request_queue, document_rag_response_queue
|
||||
from ... schema import prompt_request_queue
|
||||
from ... schema import prompt_response_queue
|
||||
from ... schema import embeddings_request_queue
|
||||
from ... schema import embeddings_response_queue
|
||||
from ... schema import document_embeddings_request_queue
|
||||
from ... schema import document_embeddings_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... document_rag import DocumentRag
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
default_ident = "document-rag"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = document_rag_request_queue
|
||||
default_output_queue = document_rag_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
)
|
||||
emb_response_queue = params.get(
|
||||
"embeddings_response_queue", embeddings_response_queue
|
||||
)
|
||||
de_request_queue = params.get(
|
||||
"document_embeddings_request_queue",
|
||||
document_embeddings_request_queue
|
||||
)
|
||||
de_response_queue = params.get(
|
||||
"document_embeddings_response_queue",
|
||||
document_embeddings_response_queue
|
||||
)
|
||||
|
||||
doc_limit = params.get("doc_limit", 5)
|
||||
doc_limit = params.get("doc_limit", 10)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"doc_limit": doc_limit,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentRagQuery,
|
||||
"output_schema": DocumentRagResponse,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"document_embeddings_request_queue": de_request_queue,
|
||||
"document_embeddings_response_queue": de_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.rag = DocumentRag(
|
||||
pulsar_host=self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
pr_request_queue=pr_request_queue,
|
||||
pr_response_queue=pr_response_queue,
|
||||
emb_request_queue=emb_request_queue,
|
||||
emb_response_queue=emb_response_queue,
|
||||
de_request_queue=de_request_queue,
|
||||
de_response_queue=de_response_queue,
|
||||
verbose=True,
|
||||
module=module,
|
||||
)
|
||||
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = DocumentRagQuery,
|
||||
handler = self.on_request,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name = "embeddings-request",
|
||||
response_name = "embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
DocumentEmbeddingsClientSpec(
|
||||
request_name = "document-embeddings-request",
|
||||
response_name = "document-embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = DocumentRagResponse,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
self.rag = DocumentRag(
|
||||
embeddings_client = flow("embeddings-request"),
|
||||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
|
@ -88,15 +99,11 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
response = await self.rag.query(v.query, doc_limit=doc_limit)
|
||||
response = self.rag.query(v.query, doc_limit=doc_limit)
|
||||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = response,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
)
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentRagResponse(response = response, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
|
|
@ -106,21 +113,25 @@ class Processor(FlowProcessor):
|
|||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "document-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
r = DocumentRagResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
properties = {"id": id}
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--doc-limit',
|
||||
|
|
@ -129,7 +140,43 @@ class Processor(FlowProcessor):
|
|||
help=f'Default document fetch limit (default: 10)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-request-queue',
|
||||
default=embeddings_request_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-response-queue',
|
||||
default=embeddings_response_queue,
|
||||
help=f'Embeddings response queue (default: {embeddings_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--document-embeddings-request-queue',
|
||||
default=document_embeddings_request_queue,
|
||||
help=f'Document embeddings request queue (default: {document_embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--document-embeddings-response-queue',
|
||||
default=document_embeddings_response_queue,
|
||||
help=f'Document embeddings response queue (default: {document_embeddings_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,218 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
# Predicate IRI used to look up the human-readable label of a graph node.
LABEL="http://www.w3.org/2000/01/rdf-schema#label"

class Query:
    """Build a labelled knowledge subgraph for a single question.

    The object holds the per-query settings (user, collection and the
    various limits) and performs its lookups through the clients exposed
    by ``rag``: ``embeddings_client``, ``graph_embeddings_client`` and
    ``triples_client``.  Resolved labels are stored in ``rag.label_cache``.
    """

    def __init__(
            self, rag, user, collection, verbose,
            entity_limit=50, triple_limit=30, max_subgraph_size=1000,
            max_path_length=2,
    ):
        """Record the owning ``rag`` object and the per-query limits.

        Args:
            rag: Object providing the clients and the ``label_cache`` dict.
            user: User identifier forwarded to the store queries.
            collection: Collection identifier forwarded to the store queries.
            verbose: When true, progress is printed to stdout.
            entity_limit: Maximum number of entities requested from the
                graph-embeddings client.
            triple_limit: Maximum number of triples requested per
                triple-store query while walking edges.
            max_subgraph_size: Cap on the number of edges returned by
                ``get_labelgraph``; also stops further edge-walking once
                the collected subgraph has reached it.
            max_path_length: Walk depth passed to ``follow_edges`` for
                each seed entity.
        """
        self.rag = rag
        self.user = user
        self.collection = collection
        self.verbose = verbose
        self.entity_limit = entity_limit
        self.triple_limit = triple_limit
        self.max_subgraph_size = max_subgraph_size
        self.max_path_length = max_path_length

    async def get_vector(self, query):
        """Return the embeddings client's result for ``query``."""

        if self.verbose:
            print("Compute embeddings...", flush=True)

        qembeds = await self.rag.embeddings_client.embed(query)

        if self.verbose:
            print("Done.", flush=True)

        return qembeds

    async def get_entities(self, query):
        """Return the entities matching ``query`` as a list of strings."""

        vectors = await self.get_vector(query)

        if self.verbose:
            print("Get entities...", flush=True)

        entities = await self.rag.graph_embeddings_client.query(
            vectors=vectors, limit=self.entity_limit,
            user=self.user, collection=self.collection,
        )

        # Normalise whatever the client returned to plain strings.
        entities = [str(e) for e in entities]

        if self.verbose:
            print("Entities:", flush=True)
            for ent in entities:
                print("  ", ent, flush=True)

        return entities

    async def maybe_label(self, e):
        """Return the label of ``e``, or ``e`` itself if it has none.

        Results (including the "no label" fallback) are memoised in
        ``rag.label_cache`` so each node is looked up at most once.
        """

        # NOTE(review): the cache is keyed by the node alone although the
        # lookup is scoped by user/collection -- confirm the rag object is
        # never shared across collections.
        if e in self.rag.label_cache:
            return self.rag.label_cache[e]

        res = await self.rag.triples_client.query(
            s=e, p=LABEL, o=None, limit=1,
            user=self.user, collection=self.collection,
        )

        if len(res) == 0:
            self.rag.label_cache[e] = e
            return e

        self.rag.label_cache[e] = str(res[0].o)
        return self.rag.label_cache[e]

    async def follow_edges(self, ent, subgraph, path_length):
        """Add the triples touching ``ent`` to ``subgraph`` (a set).

        Triples are collected where ``ent`` is the subject, the predicate
        or the object.  For subject and object matches the walk recurses
        on the node at the other end while ``path_length`` is above 1.
        """

        # Only reachable when the initial depth is not positive: the
        # recursive calls below are made only while path_length > 1.
        if path_length <= 0:
            return

        # Stop spanning around if the subgraph is already maxed out.  The
        # cap is checked on entry only, so one call may overshoot it;
        # get_labelgraph enforces the final size.
        if len(subgraph) >= self.max_subgraph_size:
            return

        # ent as subject: follow the edge to its object.
        res = await self.rag.triples_client.query(
            s=ent, p=None, o=None,
            limit=self.triple_limit,
            user=self.user, collection=self.collection,
        )

        for triple in res:
            subgraph.add(
                (str(triple.s), str(triple.p), str(triple.o))
            )
            if path_length > 1:
                await self.follow_edges(str(triple.o), subgraph, path_length-1)

        # ent as predicate: recorded, but not walked any further.
        res = await self.rag.triples_client.query(
            s=None, p=ent, o=None,
            limit=self.triple_limit,
            user=self.user, collection=self.collection,
        )

        for triple in res:
            subgraph.add(
                (str(triple.s), str(triple.p), str(triple.o))
            )

        # ent as object: follow the edge back to its subject.
        res = await self.rag.triples_client.query(
            s=None, p=None, o=ent,
            limit=self.triple_limit,
            user=self.user, collection=self.collection,
        )

        for triple in res:
            subgraph.add(
                (str(triple.s), str(triple.p), str(triple.o))
            )
            if path_length > 1:
                await self.follow_edges(
                    str(triple.s), subgraph, path_length-1
                )

    async def get_subgraph(self, query):
        """Return the de-duplicated (s, p, o) string triples for ``query``."""

        entities = await self.get_entities(query)

        if self.verbose:
            print("Get subgraph...", flush=True)

        subgraph = set()

        for ent in entities:
            await self.follow_edges(ent, subgraph, self.max_path_length)

        subgraph = list(subgraph)

        return subgraph

    async def get_labelgraph(self, query):
        """Return at most ``max_subgraph_size`` label-resolved edges.

        Edges whose predicate is ``LABEL`` are dropped; for the rest each
        of subject, predicate and object is replaced via ``maybe_label``.
        """

        subgraph = await self.get_subgraph(query)

        sg2 = []

        for edge in subgraph:

            # Stop once the cap is reached: every further edge would cost
            # up to three label lookups only to be discarded.  The edges
            # kept are exactly those a label-everything-then-slice pass
            # would keep.
            if len(sg2) >= self.max_subgraph_size:
                break

            if edge[1] == LABEL:
                continue

            s = await self.maybe_label(edge[0])
            p = await self.maybe_label(edge[1])
            o = await self.maybe_label(edge[2])

            sg2.append((s, p, o))

        if self.verbose:
            print("Subgraph:", flush=True)
            for edge in sg2:
                print("  ", str(edge), flush=True)
            print("Done.", flush=True)

        return sg2
|
||||
|
||||
class GraphRag:
    """Answer questions from a knowledge graph via a set of clients.

    Holds the prompt, embeddings, graph-embeddings and triples clients
    plus a ``label_cache`` dict that ``Query`` objects read and fill.
    """

    def __init__(
            self, prompt_client, embeddings_client, graph_embeddings_client,
            triples_client, verbose=False,
    ):
        """Store the clients and start with an empty label cache.

        Args:
            prompt_client: Client whose ``kg_prompt`` produces the answer.
            embeddings_client: Client stored for use by ``Query``.
            graph_embeddings_client: Client stored for use by ``Query``.
            triples_client: Client stored for use by ``Query``.
            verbose: When true, progress is printed to stdout.
        """
        self.verbose = verbose

        self.triples_client = triples_client
        self.graph_embeddings_client = graph_embeddings_client
        self.embeddings_client = embeddings_client
        self.prompt_client = prompt_client

        # Node -> label memo shared by the Query objects built in query().
        self.label_cache = {}

        if self.verbose:
            print("Initialised", flush=True)

    async def query(
            self, query, user = "trustgraph", collection = "default",
            entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
            max_path_length = 2,
    ):
        """Build the labelled subgraph for ``query`` and prompt with it.

        Returns whatever ``prompt_client.kg_prompt`` returns for the
        question and the list of labelled edges.
        """

        if self.verbose:
            print("Construct prompt...", flush=True)

        # Per-call limits, forwarded unchanged to the Query helper.
        limits = {
            "entity_limit": entity_limit,
            "triple_limit": triple_limit,
            "max_subgraph_size": max_subgraph_size,
            "max_path_length": max_path_length,
        }

        planner = Query(
            rag = self, user = user, collection = collection,
            verbose = self.verbose, **limits
        )

        labelled_edges = await planner.get_labelgraph(query)

        if self.verbose:
            print("Invoke LLM...", flush=True)
            print(labelled_edges)
            print(query)

        answer = await self.prompt_client.kg_prompt(query, labelled_edges)

        if self.verbose:
            print("Done", flush=True)

        return answer
|
||||
|
||||
|
|
@ -5,18 +5,57 @@ Input is query, output is response.
|
|||
"""
|
||||
|
||||
from ... schema import GraphRagQuery, GraphRagResponse, Error
|
||||
from . graph_rag import GraphRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... schema import graph_rag_request_queue, graph_rag_response_queue
|
||||
from ... schema import prompt_request_queue
|
||||
from ... schema import prompt_response_queue
|
||||
from ... schema import embeddings_request_queue
|
||||
from ... schema import embeddings_response_queue
|
||||
from ... schema import graph_embeddings_request_queue
|
||||
from ... schema import graph_embeddings_response_queue
|
||||
from ... schema import triples_request_queue
|
||||
from ... schema import triples_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... graph_rag import GraphRag
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
default_ident = "graph-rag"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
default_input_queue = graph_rag_request_queue
|
||||
default_output_queue = graph_rag_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
)
|
||||
emb_response_queue = params.get(
|
||||
"embeddings_response_queue", embeddings_response_queue
|
||||
)
|
||||
ge_request_queue = params.get(
|
||||
"graph_embeddings_request_queue", graph_embeddings_request_queue
|
||||
)
|
||||
ge_response_queue = params.get(
|
||||
"graph_embeddings_response_queue", graph_embeddings_response_queue
|
||||
)
|
||||
tpl_request_queue = params.get(
|
||||
"triples_request_queue", triples_request_queue
|
||||
)
|
||||
tpl_response_queue = params.get(
|
||||
"triples_response_queue", triples_response_queue
|
||||
)
|
||||
|
||||
entity_limit = params.get("entity_limit", 50)
|
||||
triple_limit = params.get("triple_limit", 30)
|
||||
|
|
@ -25,74 +64,49 @@ class Processor(FlowProcessor):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphRagQuery,
|
||||
"output_schema": GraphRagResponse,
|
||||
"entity_limit": entity_limit,
|
||||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"graph_embeddings_request_queue": ge_request_queue,
|
||||
"graph_embeddings_response_queue": ge_response_queue,
|
||||
"triples_request_queue": triples_request_queue,
|
||||
"triples_response_queue": triples_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.rag = GraphRag(
|
||||
pulsar_host=self.pulsar_host,
|
||||
pulsar_api_key=self.pulsar_api_key,
|
||||
pr_request_queue=pr_request_queue,
|
||||
pr_response_queue=pr_response_queue,
|
||||
emb_request_queue=emb_request_queue,
|
||||
emb_response_queue=emb_response_queue,
|
||||
ge_request_queue=ge_request_queue,
|
||||
ge_response_queue=ge_response_queue,
|
||||
tpl_request_queue=triples_request_queue,
|
||||
tpl_response_queue=triples_response_queue,
|
||||
verbose=True,
|
||||
module=module,
|
||||
)
|
||||
|
||||
self.default_entity_limit = entity_limit
|
||||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = GraphRagQuery,
|
||||
handler = self.on_request,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name = "embeddings-request",
|
||||
response_name = "embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
GraphEmbeddingsClientSpec(
|
||||
request_name = "graph-embeddings-request",
|
||||
response_name = "graph-embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
TriplesClientSpec(
|
||||
request_name = "triples-request",
|
||||
response_name = "triples-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = GraphRagResponse,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
self.rag = GraphRag(
|
||||
embeddings_client = flow("embeddings-request"),
|
||||
graph_embeddings_client = flow("graph-embeddings-request"),
|
||||
triples_client = flow("triples-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
|
@ -120,20 +134,16 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
response = await self.rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
response = self.rag.query(
|
||||
query=v.query, user=v.user, collection=v.collection,
|
||||
entity_limit=entity_limit, triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = response,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
)
|
||||
print("Send response...", flush=True)
|
||||
r = GraphRagResponse(response=response, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
|
|
@ -143,21 +153,25 @@ class Processor(FlowProcessor):
|
|||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "graph-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
r = GraphRagResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
properties = {"id": id}
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--entity-limit',
|
||||
|
|
@ -187,7 +201,55 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-request-queue',
|
||||
default=embeddings_request_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-response-queue',
|
||||
default=embeddings_response_queue,
|
||||
help=f'Embeddings response queue (default: {embeddings_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-embeddings-request-queue',
|
||||
default=graph_embeddings_request_queue,
|
||||
help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graph-embeddings-response-queue',
|
||||
default=graph_embeddings_response_queue,
|
||||
help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--triples-request-queue',
|
||||
default=triples_request_queue,
|
||||
help=f'Triples request queue (default: {triples_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--triples-response-queue',
|
||||
default=triples_response_queue,
|
||||
help=f'Triples response queue (default: {triples_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from .... schema import document_embeddings_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "de-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... schema import document_embeddings_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "de-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -8,21 +8,31 @@ from qdrant_client.models import PointStruct
|
|||
from qdrant_client.models import Distance, VectorParams
|
||||
import uuid
|
||||
|
||||
from .... base import DocumentEmbeddingsStoreService
|
||||
from .... schema import DocumentEmbeddings
|
||||
from .... schema import document_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
default_ident = "de-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
class Processor(DocumentEmbeddingsStoreService):
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
api_key = params.get("api_key", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddings,
|
||||
"store_uri": store_uri,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -30,11 +40,13 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
self.last_collection = None
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
async def handle(self, msg):
|
||||
|
||||
for emb in message.chunks:
|
||||
v = msg.value()
|
||||
|
||||
for emb in v.chunks:
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": return
|
||||
|
|
@ -43,17 +55,16 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + message.metadata.user + "_" +
|
||||
message.metadata.collection + "_" +
|
||||
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
if not self.client.collection_exists(collection):
|
||||
|
||||
try:
|
||||
self.qdrant.create_collection(
|
||||
self.client.create_collection(
|
||||
collection_name=collection,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
|
|
@ -65,7 +76,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
self.last_collection = collection
|
||||
|
||||
self.qdrant.upsert(
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
|
|
@ -81,7 +92,9 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
DocumentEmbeddingsStoreService.add_args(parser)
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -97,5 +110,5 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from .... log_level import LogLevel
|
|||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... base import Consumer
|
||||
|
||||
module = "ge-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .... schema import graph_embeddings_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "ge-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -8,21 +8,31 @@ from qdrant_client.models import PointStruct
|
|||
from qdrant_client.models import Distance, VectorParams
|
||||
import uuid
|
||||
|
||||
from .... base import GraphEmbeddingsStoreService
|
||||
from .... schema import GraphEmbeddings
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
default_ident = "ge-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
class Processor(GraphEmbeddingsStoreService):
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
api_key = params.get("api_key", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddings,
|
||||
"store_uri": store_uri,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -30,7 +40,7 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
self.last_collection = None
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
self.client = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
def get_collection(self, dim, user, collection):
|
||||
|
||||
|
|
@ -40,10 +50,10 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
if cname != self.last_collection:
|
||||
|
||||
if not self.qdrant.collection_exists(cname):
|
||||
if not self.client.collection_exists(cname):
|
||||
|
||||
try:
|
||||
self.qdrant.create_collection(
|
||||
self.client.create_collection(
|
||||
collection_name=cname,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
|
|
@ -57,9 +67,11 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
return cname
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
async def handle(self, msg):
|
||||
|
||||
for entity in message.entities:
|
||||
v = msg.value()
|
||||
|
||||
for entity in v.entities:
|
||||
|
||||
if entity.entity.value == "" or entity.entity.value is None: return
|
||||
|
||||
|
|
@ -68,10 +80,10 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
dim = len(vec)
|
||||
|
||||
collection = self.get_collection(
|
||||
dim, message.metadata.user, message.metadata.collection
|
||||
dim, v.metadata.user, v.metadata.collection
|
||||
)
|
||||
|
||||
self.qdrant.upsert(
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
|
|
@ -87,7 +99,9 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
GraphEmbeddingsStoreService.add_args(parser)
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -103,5 +117,5 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from .... log_level import LogLevel
|
|||
from .... direct.milvus_object_embeddings import ObjectVectors
|
||||
from .... base import Consumer
|
||||
|
||||
module = "oe-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = object_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from .... schema import rows_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "rows-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
|
||||
default_input_queue = rows_store_queue
|
||||
|
|
|
|||
|
|
@ -10,26 +10,35 @@ import argparse
|
|||
import time
|
||||
|
||||
from .... direct.cassandra import TrustGraph
|
||||
from .... base import TriplesStoreService
|
||||
from .... schema import Triples
|
||||
from .... schema import triples_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
default_ident = "triples-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(TriplesStoreService):
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
graph_username = params.get("graph_username", None)
|
||||
graph_password = params.get("graph_password", None)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Triples,
|
||||
"graph_host": graph_host,
|
||||
"graph_username": graph_username
|
||||
"graph_username": graph_username,
|
||||
"graph_password": graph_password,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -38,9 +47,11 @@ class Processor(TriplesStoreService):
|
|||
self.password = graph_password
|
||||
self.table = None
|
||||
|
||||
async def store_triples(self, message):
|
||||
async def handle(self, msg):
|
||||
|
||||
table = (message.metadata.user, message.metadata.collection)
|
||||
v = msg.value()
|
||||
|
||||
table = (v.metadata.user, v.metadata.collection)
|
||||
|
||||
if self.table is None or self.table != table:
|
||||
|
||||
|
|
@ -50,15 +61,13 @@ class Processor(TriplesStoreService):
|
|||
if self.username and self.password:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=message.metadata.user,
|
||||
table=message.metadata.collection,
|
||||
keyspace=v.metadata.user, table=v.metadata.collection,
|
||||
username=self.username, password=self.password
|
||||
)
|
||||
else:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=message.metadata.user,
|
||||
table=message.metadata.collection,
|
||||
keyspace=v.metadata.user, table=v.metadata.collection,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception", e, flush=True)
|
||||
|
|
@ -67,7 +76,7 @@ class Processor(TriplesStoreService):
|
|||
|
||||
self.table = table
|
||||
|
||||
for t in message.triples:
|
||||
for t in v.triples:
|
||||
self.tg.insert(
|
||||
t.s.value,
|
||||
t.p.value,
|
||||
|
|
@ -77,7 +86,9 @@ class Processor(TriplesStoreService):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
TriplesStoreService.add_args(parser)
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
|
|
@ -99,5 +110,5 @@ class Processor(TriplesStoreService):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Processor.launch(module, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... schema import triples_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "triples-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... schema import triples_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "triples-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .... schema import triples_store_queue
|
|||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "triples-write"
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue