Feature/configure flows (#345)

- Keeps processing in different flows separate so that data can go to different stores / collections etc.
- Potentially supports different processing flows
- Tidies the processing API with common base-classes for e.g. LLMs, and automatic configuration of 'clients' to use the right queue names in a flow
This commit is contained in:
cybermaggedon 2025-04-22 20:21:38 +01:00 committed by GitHub
parent a06a814a41
commit a9197d11ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
125 changed files with 3751 additions and 2628 deletions

View file

@ -0,0 +1,218 @@
import asyncio
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
async def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = await self.rag.embeddings_client.embed(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
async def get_entities(self, query):
vectors = await self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
)
entities = [
str(e)
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
async def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
)
if len(res) == 0:
self.rag.label_cache[e] = e
return e
self.rag.label_cache[e] = str(res[0].o)
return self.rag.label_cache[e]
async def follow_edges(self, ent, subgraph, path_length):
# Not needed?
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = await self.rag.triples_client.query(
s=ent, p=None, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(str(triple.o), subgraph, path_length-1)
res = await self.rag.triples_client.query(
s=None, p=ent, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
res = await self.rag.triples_client.query(
s=None, p=None, o=ent,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(
str(triple.s), subgraph, path_length-1
)
async def get_subgraph(self, query):
entities = await self.get_entities(query)
if self.verbose:
print("Get subgraph...", flush=True)
subgraph = set()
for ent in entities:
await self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
return subgraph
async def get_labelgraph(self, query):
subgraph = await self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = await self.maybe_label(edge[0])
p = await self.maybe_label(edge[1])
o = await self.maybe_label(edge[2])
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
triples_client, verbose=False,
):
self.verbose = verbose
self.prompt_client = prompt_client
self.embeddings_client = embeddings_client
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.label_cache = {}
if self.verbose:
print("Initialised", flush=True)
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag = self, user = user, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
kg = await q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = await self.prompt_client.kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -5,57 +5,18 @@ Input is query, output is response.
"""
from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import graph_embeddings_request_queue
from ... schema import graph_embeddings_response_queue
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from ... log_level import LogLevel
from ... graph_rag import GraphRag
from ... base import ConsumerProducer
from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
module = ".".join(__name__.split(".")[1:-1])
default_ident = "graph-rag"
default_input_queue = graph_rag_request_queue
default_output_queue = graph_rag_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
class Processor(FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
ge_request_queue = params.get(
"graph_embeddings_request_queue", graph_embeddings_request_queue
)
ge_response_queue = params.get(
"graph_embeddings_response_queue", graph_embeddings_response_queue
)
tpl_request_queue = params.get(
"triples_request_queue", triples_request_queue
)
tpl_response_queue = params.get(
"triples_response_queue", triples_response_queue
)
id = params.get("id", default_ident)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
@ -64,49 +25,74 @@ class Processor(ConsumerProducer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphRagQuery,
"output_schema": GraphRagResponse,
"id": id,
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"graph_embeddings_request_queue": ge_request_queue,
"graph_embeddings_response_queue": ge_response_queue,
"triples_request_queue": triples_request_queue,
"triples_response_queue": triples_response_queue,
"max_path_length": max_path_length,
}
)
self.rag = GraphRag(
pulsar_host=self.pulsar_host,
pulsar_api_key=self.pulsar_api_key,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,
emb_response_queue=emb_response_queue,
ge_request_queue=ge_request_queue,
ge_response_queue=ge_response_queue,
tpl_request_queue=triples_request_queue,
tpl_response_queue=triples_response_queue,
verbose=True,
module=module,
)
self.default_entity_limit = entity_limit
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
async def handle(self, msg):
self.register_specification(
ConsumerSpec(
name = "request",
schema = GraphRagQuery,
handler = self.on_request,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
GraphEmbeddingsClientSpec(
request_name = "graph-embeddings-request",
response_name = "graph-embeddings-response",
)
)
self.register_specification(
TriplesClientSpec(
request_name = "triples-request",
response_name = "triples-response",
)
)
self.register_specification(
PromptClientSpec(
request_name = "prompt-request",
response_name = "prompt-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = GraphRagResponse,
)
)
async def on_request(self, msg, consumer, flow):
try:
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
verbose=True,
)
v = msg.value()
# Sender-produced ID
@ -134,16 +120,20 @@ class Processor(ConsumerProducer):
else:
max_path_length = self.default_max_path_length
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
response = await self.rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
print("Send response...", flush=True)
r = GraphRagResponse(response=response, error=None)
await self.send(r, properties={"id": id})
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
print("Done.", flush=True)
@ -153,25 +143,21 @@ class Processor(ConsumerProducer):
print("Send error response...", flush=True)
r = GraphRagResponse(
error=Error(
type = "llm-error",
message = str(e),
await flow("response").send(
GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
),
response=None,
properties = {"id": id}
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
FlowProcessor.add_args(parser)
parser.add_argument(
'-e', '--entity-limit',
@ -201,55 +187,7 @@ class Processor(ConsumerProducer):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings response queue (default: {embeddings_response_queue})',
)
parser.add_argument(
'--graph-embeddings-request-queue',
default=graph_embeddings_request_queue,
help=f'Graph embeddings request queue (default: {graph_embeddings_request_queue})',
)
parser.add_argument(
'--graph-embeddings-response-queue',
default=graph_embeddings_response_queue,
help=f'Graph embeddings response queue (default: {graph_embeddings_response_queue})',
)
parser.add_argument(
'--triples-request-queue',
default=triples_request_queue,
help=f'Triples request queue (default: {triples_request_queue})',
)
parser.add_argument(
'--triples-response-queue',
default=triples_response_queue,
help=f'Triples response queue (default: {triples_response_queue})',
)
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)