Feature/configure flows (#345)

- Keeps processing in different flows separate so that data can go to different stores / collections etc.
- Potentially supports different processing flows
- Tidies the processing API with common base-classes for e.g. LLMs, and automatic configuration of 'clients' to use the right queue names in a flow
This commit is contained in:
cybermaggedon 2025-04-22 20:21:38 +01:00 committed by GitHub
parent a06a814a41
commit a9197d11ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
125 changed files with 3751 additions and 2628 deletions

View file

@@ -5,57 +5,18 @@ Input is query, output is response.
"""
from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import graph_embeddings_request_queue
from ... schema import graph_embeddings_response_queue
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from ... log_level import LogLevel
from ... graph_rag import GraphRag
from ... base import ConsumerProducer
from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "graph-rag"
default_input_queue = graph_rag_request_queue
default_output_queue = graph_rag_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
ge_request_queue = params.get(
"graph_embeddings_request_queue", graph_embeddings_request_queue
)
ge_response_queue = params.get(
"graph_embeddings_response_queue", graph_embeddings_response_queue
)
tpl_request_queue = params.get(
"triples_request_queue", triples_request_queue
)
tpl_response_queue = params.get(
"triples_response_queue", triples_response_queue
)
id = params.get("id", default_ident)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
@@ -64,49 +25,74 @@ class Processor(ConsumerProducer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphRagQuery,
"output_schema": GraphRagResponse,
"id": id,
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"graph_embeddings_request_queue": ge_request_queue,
"graph_embeddings_response_queue": ge_response_queue,
"triples_request_queue": triples_request_queue,
"triples_response_queue": triples_response_queue,
"max_path_length": max_path_length,
}
)
self.rag = GraphRag(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,
emb_response_queue=emb_response_queue,
ge_request_queue=ge_request_queue,
ge_response_queue=ge_response_queue,
tpl_request_queue=triples_request_queue,
tpl_response_queue=triples_response_queue,
verbose=True,
module=module,
)
self.default_entity_limit = entity_limit
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
async def handle(self, msg):
self.register_specification(
ConsumerSpec(
name = "request",
schema = GraphRagQuery,
handler = self.on_request,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
GraphEmbeddingsClientSpec(
request_name = "graph-embeddings-request",
response_name = "graph-embeddings-response",
)
)
self.register_specification(
TriplesClientSpec(
request_name = "triples-request",
response_name = "triples-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = GraphRagResponse,
)
)
async def on_request(self, msg, consumer, flow):
try:
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
verbose=True,
)
v = msg.value()
# Sender-produced ID
@@ -134,16 +120,20 @@ class Processor(ConsumerProducer):
else:
max_path_length = self.default_max_path_length
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
response = await self.rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
print("Send response...", flush=True)
r = GraphRagResponse(response=response, error=None)
await self.send(r, properties={"id": id})
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
print("Done.", flush=True)
@@ -153,25 +143,21 @@ class Processor(ConsumerProducer):
print("Send error response...", flush=True)
r = GraphRagResponse(
error=Error(
type = "llm-error",
message = str(e),
await flow("response").send(
GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
),
response=None,
properties = {"id": id}
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
FlowProcessor.add_args(parser)
parser.add_argument(
'-e', '--entity-limit',
@@ -201,55 +187,7 @@ class Processor(ConsumerProducer):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
parser.add_argument(
'--graph-embeddings-request-queue',
default=graph_embeddings_request_queue,
help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})',
)
parser.add_argument(
'--graph-embeddings-response-queue',
default=graph_embeddings_response_queue,
help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})',
)
parser.add_argument(
'--triples-request-queue',
default=triples_request_queue,
help=f'Triples request queue (default: {triples_request_queue})',
)
parser.add_argument(
'--triples-response-queue',
default=triples_response_queue,
help=f'Triples response queue (default: {triples_response_queue})',
)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)