Feature/rag parameters (#311)

* Change document-rag and graph-rag processing so that the user can
specify parameters.  Changes in Pulsar services, Pulsar message
schemas, gateway and command-line tools.  User-visible changes in
new parameters on command-line tools.

* Fix bugs, graph-rag working

* Get subgraph truncation in the right place

* Graph RAG and document RAG working and configurable

* Multi-hop path traversal GraphRAG

* Add safety valve for path_size set too high
This commit is contained in:
cybermaggedon 2025-03-13 00:38:18 +00:00 committed by GitHub
parent f1559c5944
commit ef845d6c9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 247 additions and 91 deletions

View file

@ -18,11 +18,15 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(self, rag, user, collection, verbose):
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
def get_vector(self, query):
@ -44,7 +48,7 @@ class Query:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.rag.doc_limit
vectors, limit=self.doc_limit
)
if self.verbose:
@ -93,9 +97,6 @@ class DocumentRag:
if self.verbose:
print("Initialising...", flush=True)
# FIXME: Configurable
self.doc_limit = 20
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
@ -123,13 +124,17 @@ class DocumentRag:
if self.verbose:
print("Initialised", flush=True)
def query(self, query, user="trustgraph", collection="default"):
def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = q.get_docs(query)

View file

@ -23,6 +23,7 @@ class DocumentRagRequestor(ServiceRequestor):
query=body["query"],
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
doc_limit=int(body.get("doc-limit", 20)),
)
def from_response(self, message):

View file

@ -23,6 +23,10 @@ class GraphRagRequestor(ServiceRequestor):
query=body["query"],
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
entity_limit=int(body.get("entity-limit", 50)),
triple_limit=int(body.get("triple-limit", 30)),
max_subgraph_size=int(body.get("max-subgraph-size", 1000)),
max_path_length=int(body.get("max-path-length", 2)),
)
def from_response(self, message):

View file

@ -20,11 +20,19 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(self, rag, user, collection, verbose):
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
def get_vector(self, query):
@ -47,7 +55,7 @@ class Query:
entities = self.rag.ge_client.request(
user=self.user, collection=self.collection,
vectors=vectors, limit=self.rag.entity_limit,
vectors=vectors, limit=self.entity_limit,
)
entities = [
@ -79,62 +87,67 @@ class Query:
self.rag.label_cache[e] = res[0].o.value
return self.rag.label_cache[e]
def follow_edges(self, ent, subgraph, path_length):
# Not needed?
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=ent, p=None, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.o.value, subgraph, path_length-1)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=ent, o=None,
limit=self.triple_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=None, o=ent,
limit=self.triple_limit,
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
if path_length > 1:
self.follow_edges(triple.s.value, subgraph, path_length-1)
def get_subgraph(self, query):
entities = self.get_entities(query)
subgraph = set()
if self.verbose:
print("Get subgraph...", flush=True)
for e in entities:
subgraph = set()
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=e, p=None, o=None,
limit=self.rag.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=e, o=None,
limit=self.rag.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=None, o=e,
limit=self.rag.query_limit,
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
for ent in entities:
self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
subgraph = subgraph[0:self.rag.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in subgraph:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return subgraph
def get_labelgraph(self, query):
@ -154,6 +167,16 @@ class Query:
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
@ -171,9 +194,6 @@ class GraphRag:
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
entity_limit=50,
triple_limit=30,
max_subgraph_size=3000,
module="test",
):
@ -230,10 +250,6 @@ class GraphRag:
subscriber=module + "-emb",
)
self.entity_limit=entity_limit
self.query_limit=triple_limit
self.max_subgraph_size=max_subgraph_size
self.label_cache = {}
self.prompt = PromptClient(
@ -247,13 +263,20 @@ class GraphRag:
if self.verbose:
print("Initialised", flush=True)
def query(self, query, user="trustgraph", collection="default"):
def query(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
rag=self, user=user, collection=collection, verbose=self.verbose,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
kg = q.get_labelgraph(query)

View file

@ -50,6 +50,8 @@ class Processor(ConsumerProducer):
document_embeddings_response_queue
)
doc_limit = params.get("doc_limit", 10)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
@ -79,6 +81,8 @@ class Processor(ConsumerProducer):
module=module,
)
self.doc_limit = doc_limit
async def handle(self, msg):
try:
@ -90,7 +94,12 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
response = self.rag.query(v.query)
if v.doc_limit:
doc_limit = v.doc_limit
else:
doc_limit = self.doc_limit
response = self.rag.query(v.query, doc_limit=doc_limit)
print("Send response...", flush=True)
r = DocumentRagResponse(response = response, error=None)
@ -124,6 +133,13 @@ class Processor(ConsumerProducer):
default_output_queue,
)
parser.add_argument(
'-d', '--doc-limit',
type=int,
default=20,
help=f'Default document fetch limit (default: 10)'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,

View file

@ -31,9 +31,7 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 3000)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
@ -59,6 +57,11 @@ class Processor(ConsumerProducer):
"triples_response_queue", triples_response_queue
)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
@ -92,12 +95,14 @@ class Processor(ConsumerProducer):
tpl_request_queue=triples_request_queue,
tpl_response_queue=triples_response_queue,
verbose=True,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
module=module,
)
self.default_entity_limit = entity_limit
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
async def handle(self, msg):
try:
@ -106,15 +111,38 @@ class Processor(ConsumerProducer):
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
if v.entity_limit:
entity_limit = v.entity_limit
else:
entity_limit = self.default_entity_limit
if v.triple_limit:
triple_limit = v.triple_limit
else:
triple_limit = self.default_triple_limit
if v.max_subgraph_size:
max_subgraph_size = v.max_subgraph_size
else:
max_subgraph_size = self.default_max_subgraph_size
if v.max_path_length:
max_path_length = v.max_path_length
else:
max_path_length = self.default_max_path_length
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection
query=v.query, user=v.user, collection=v.collection,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
print("Send response...", flush=True)
r = GraphRagResponse(response = response, error=None)
r = GraphRagResponse(response=response, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@ -149,21 +177,28 @@ class Processor(ConsumerProducer):
'-e', '--entity-limit',
type=int,
default=50,
help=f'Entity vector fetch limit (default: 50)'
help=f'Default entity vector fetch limit (default: 50)'
)
parser.add_argument(
'-t', '--triple-limit',
type=int,
default=30,
help=f'Triple query limit, per query (default: 30)'
help=f'Default triple query limit, per query (default: 30)'
)
parser.add_argument(
'-u', '--max-subgraph-size',
type=int,
default=3000,
help=f'Max subgraph size (default: 3000)'
default=150,
help=f'Default max subgraph size (default: 150)'
)
parser.add_argument(
'-a', '--max-path-length',
type=int,
default=2,
help=f'Default max path length (default: 2)'
)
parser.add_argument(