Feature/rag parameters (#311)

* Change document-rag and graph-rag processing so that the user can
specify query parameters.  This changes the Pulsar services, the Pulsar
message schemas, the gateway and the command-line tools.  The
user-visible change is a set of new parameters on the command-line tools.

* Fix bugs, graph-rag working

* Get subgraph truncation in the right place

* Graph RAG and document RAG working and configurable

* Multi-hop path traversal GraphRAG

* Add safety valve for max_path_length set too high (traversal stops once the subgraph reaches max_subgraph_size)
This commit is contained in:
cybermaggedon 2025-03-13 00:38:18 +00:00 committed by GitHub
parent f1559c5944
commit ef845d6c9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 247 additions and 91 deletions

View file

@ -5,6 +5,8 @@ local prompts = import "prompts/mixtral.jsonnet";
{
"document-rag-doc-limit":: 20,
"document-rag" +: {
create:: function(engine)

View file

@ -4,9 +4,9 @@ local url = import "values/url.jsonnet";
{
"graph-rag-entity-limit":: 50,
"graph-rag-triple-limit":: 30,
"graph-rag-max-subgraph-size":: 3000,
"graph-rag-entity-limit":: 20,
"graph-rag-triple-limit":: 10,
"graph-rag-max-subgraph-size":: 1000,
"kg-extract-definitions" +: {

View file

@ -102,11 +102,21 @@ class Api:
except:
raise ProtocolException(f"Response not formatted correctly")
def graph_rag(self, question):
def graph_rag(
self, question, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=150,
max_path_length=2,
):
# The input consists of a question
input = {
"query": question
"query": question,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
}
url = f"{self.url}graph-rag"
@ -131,11 +141,17 @@ class Api:
except:
raise ProtocolException(f"Response not formatted correctly")
def document_rag(self, question):
def document_rag(
self, question, user="trustgraph", collection="default",
doc_limit=10,
):
# The input consists of a question
input = {
"query": question
"query": question,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
}
url = f"{self.url}document-rag"

View file

@ -11,6 +11,10 @@ class GraphRagQuery(Record):
query = String()
user = String()
collection = String()
entity_limit = Integer()
triple_limit = Integer()
max_subgraph_size = Integer()
max_path_length = Integer()
class GraphRagResponse(Record):
error = Error()
@ -31,6 +35,7 @@ class DocumentRagQuery(Record):
query = String()
user = String()
collection = String()
doc_limit = Integer()
class DocumentRagResponse(Record):
error = Error()

View file

@ -11,13 +11,16 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_collection = 'default'
default_doc_limit = 10
def question(url, question, user, collection):
def question(url, question, user, collection, doc_limit):
rag = Api(url)
# user=user, collection=collection,
resp = rag.document_rag(question=question)
resp = rag.document_rag(
question=question, user=user, collection=collection,
doc_limit=doc_limit,
)
print(resp)
@ -41,7 +44,7 @@ def main():
# )
parser.add_argument(
'-q', '--query',
'-q', '--question',
required=True,
help=f'Question to answer',
)
@ -58,6 +61,12 @@ def main():
help=f'Collection ID (default: {default_collection})'
)
# Document limit: parsed as an integer so that a value given on the command
# line has the same type as the integer default (argparse would otherwise
# hand back a str) and non-numeric input is rejected at parse time.
parser.add_argument(
    '-d', '--doc-limit',
    type=int,
    default=default_doc_limit,
    help=f'Document limit (default: {default_doc_limit})'
)
args = parser.parse_args()
try:
@ -67,6 +76,7 @@ def main():
question=args.question,
user=args.user,
collection=args.collection,
doc_limit=args.doc_limit,
)
except Exception as e:

View file

@ -11,13 +11,24 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_collection = 'default'
default_entity_limit = 50
default_triple_limit = 30
default_max_subgraph_size = 150
default_max_path_length = 2
def question(url, question, user, collection):
def question(
url, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length
):
rag = Api(url)
# user=user, collection=collection,
resp = rag.graph_rag(question=question)
resp = rag.graph_rag(
question=question, user=user, collection=collection,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length
)
print(resp)
@ -52,6 +63,30 @@ def main():
help=f'Collection ID (default: {default_collection})'
)
# Graph RAG tuning parameters.  Each is parsed as an integer so that values
# given on the command line have the same type as the integer defaults
# (argparse would otherwise hand back str) and non-numeric input is rejected
# at parse time.
parser.add_argument(
    '-e', '--entity-limit',
    type=int,
    default=default_entity_limit,
    help=f'Entity limit (default: {default_entity_limit})'
)
parser.add_argument(
    '-t', '--triple-limit',
    type=int,
    default=default_triple_limit,
    help=f'Triple limit (default: {default_triple_limit})'
)
parser.add_argument(
    '-s', '--max-subgraph-size',
    type=int,
    default=default_max_subgraph_size,
    help=f'Max subgraph size (default: {default_max_subgraph_size})'
)
parser.add_argument(
    '-p', '--max-path-length',
    type=int,
    default=default_max_path_length,
    help=f'Max path length (default: {default_max_path_length})'
)
args = parser.parse_args()
try:
@ -61,6 +96,10 @@ def main():
question=args.question,
user=args.user,
collection=args.collection,
entity_limit=args.entity_limit,
triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length,
)
except Exception as e:

View file

@ -18,11 +18,15 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(self, rag, user, collection, verbose):
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
def get_vector(self, query):
@ -44,7 +48,7 @@ class Query:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.rag.doc_limit
vectors, limit=self.doc_limit
)
if self.verbose:
@ -93,9 +97,6 @@ class DocumentRag:
if self.verbose:
print("Initialising...", flush=True)
# FIXME: Configurable
self.doc_limit = 20
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
@ -123,13 +124,17 @@ class DocumentRag:
if self.verbose:
print("Initialised", flush=True)
def query(self, query, user="trustgraph", collection="default"):
def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = q.get_docs(query)

View file

@ -23,6 +23,7 @@ class DocumentRagRequestor(ServiceRequestor):
query=body["query"],
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
doc_limit=int(body.get("doc-limit", 20)),
)
def from_response(self, message):

View file

@ -23,6 +23,10 @@ class GraphRagRequestor(ServiceRequestor):
query=body["query"],
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
entity_limit=int(body.get("entity-limit", 50)),
triple_limit=int(body.get("triple-limit", 30)),
max_subgraph_size=int(body.get("max-subgraph-size", 1000)),
max_path_length=int(body.get("max-path-length", 2)),
)
def from_response(self, message):

View file

@ -20,11 +20,19 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(self, rag, user, collection, verbose):
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.entity_limit = entity_limit
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
def get_vector(self, query):
@ -47,7 +55,7 @@ class Query:
entities = self.rag.ge_client.request(
user=self.user, collection=self.collection,
vectors=vectors, limit=self.rag.entity_limit,
vectors=vectors, limit=self.entity_limit,
)
entities = [
@ -79,62 +87,67 @@ class Query:
self.rag.label_cache[e] = res[0].o.value
return self.rag.label_cache[e]
def follow_edges(self, ent, subgraph, path_length):
    """Gather triples that mention `ent` into `subgraph`, hopping outward.

    Three triple queries are issued through `self.rag.triples_client`, with
    `ent` used as the subject, then the predicate, then the object.  Every
    triple returned is added to `subgraph` as an `(s, p, o)` tuple of the
    `.value` attributes.  While more than one hop remains, the traversal
    recurses from the object of each subject-match and from the subject of
    each object-match with `path_length - 1`; predicate-matches are
    collected but not followed.

    Args:
        ent: value to match in the subject / predicate / object position.
        subgraph: set of `(s, p, o)` tuples, mutated in place.
        path_length: number of hops still allowed from `ent`.

    Returns:
        None.
    """

    # Out of hops, or the subgraph is already at its size cap: stop before
    # issuing any further queries.
    if path_length <= 0 or len(subgraph) >= self.max_subgraph_size:
        return

    more_hops = path_length > 1

    # Each entry pairs a query pattern with the attribute of a matched
    # triple to continue the traversal from (None = do not follow).
    patterns = (
        ({"s": ent, "p": None, "o": None}, "o"),
        ({"s": None, "p": ent, "o": None}, None),
        ({"s": None, "p": None, "o": ent}, "s"),
    )

    for pattern, onward in patterns:

        matches = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            limit=self.triple_limit,
            **pattern
        )

        for t in matches:
            subgraph.add((t.s.value, t.p.value, t.o.value))
            if more_hops and onward is not None:
                self.follow_edges(
                    getattr(t, onward).value, subgraph, path_length - 1
                )
def get_subgraph(self, query):
entities = self.get_entities(query)
subgraph = set()
if self.verbose:
print("Get subgraph...", flush=True)
for e in entities:
subgraph = set()
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=e, p=None, o=None,
limit=self.rag.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=e, o=None,
limit=self.rag.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.rag.triples_client.request(
user=self.user, collection=self.collection,
s=None, p=None, o=e,
limit=self.rag.query_limit,
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
for ent in entities:
self.follow_edges(ent, subgraph, self.max_path_length)
subgraph = list(subgraph)
subgraph = subgraph[0:self.rag.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in subgraph:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return subgraph
def get_labelgraph(self, query):
@ -154,6 +167,16 @@ class Query:
sg2.append((s, p, o))
sg2 = sg2[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in sg2:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return sg2
class GraphRag:
@ -171,9 +194,6 @@ class GraphRag:
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
entity_limit=50,
triple_limit=30,
max_subgraph_size=3000,
module="test",
):
@ -230,10 +250,6 @@ class GraphRag:
subscriber=module + "-emb",
)
self.entity_limit=entity_limit
self.query_limit=triple_limit
self.max_subgraph_size=max_subgraph_size
self.label_cache = {}
self.prompt = PromptClient(
@ -247,13 +263,20 @@ class GraphRag:
if self.verbose:
print("Initialised", flush=True)
def query(self, query, user="trustgraph", collection="default"):
def query(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
rag=self, user=user, collection=collection, verbose=self.verbose,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
kg = q.get_labelgraph(query)

View file

@ -50,6 +50,8 @@ class Processor(ConsumerProducer):
document_embeddings_response_queue
)
doc_limit = params.get("doc_limit", 10)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
@ -79,6 +81,8 @@ class Processor(ConsumerProducer):
module=module,
)
self.doc_limit = doc_limit
async def handle(self, msg):
try:
@ -90,7 +94,12 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
response = self.rag.query(v.query)
if v.doc_limit:
doc_limit = v.doc_limit
else:
doc_limit = self.doc_limit
response = self.rag.query(v.query, doc_limit=doc_limit)
print("Send response...", flush=True)
r = DocumentRagResponse(response = response, error=None)
@ -124,6 +133,13 @@ class Processor(ConsumerProducer):
default_output_queue,
)
# Service-wide default document limit, used when a request does not carry
# its own doc_limit.  The help text states the value actually passed as
# `default` below.
parser.add_argument(
    '-d', '--doc-limit',
    type=int,
    default=20,
    help=f'Default document fetch limit (default: 20)'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,

View file

@ -31,9 +31,7 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 3000)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
@ -59,6 +57,11 @@ class Processor(ConsumerProducer):
"triples_response_queue", triples_response_queue
)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
@ -92,12 +95,14 @@ class Processor(ConsumerProducer):
tpl_request_queue=triples_request_queue,
tpl_response_queue=triples_response_queue,
verbose=True,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
module=module,
)
self.default_entity_limit = entity_limit
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
async def handle(self, msg):
try:
@ -106,15 +111,38 @@ class Processor(ConsumerProducer):
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
if v.entity_limit:
entity_limit = v.entity_limit
else:
entity_limit = self.default_entity_limit
if v.triple_limit:
triple_limit = v.triple_limit
else:
triple_limit = self.default_triple_limit
if v.max_subgraph_size:
max_subgraph_size = v.max_subgraph_size
else:
max_subgraph_size = self.default_max_subgraph_size
if v.max_path_length:
max_path_length = v.max_path_length
else:
max_path_length = self.default_max_path_length
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection
query=v.query, user=v.user, collection=v.collection,
entity_limit=entity_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
print("Send response...", flush=True)
r = GraphRagResponse(response = response, error=None)
r = GraphRagResponse(response=response, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
@ -149,21 +177,28 @@ class Processor(ConsumerProducer):
'-e', '--entity-limit',
type=int,
default=50,
help=f'Entity vector fetch limit (default: 50)'
help=f'Default entity vector fetch limit (default: 50)'
)
parser.add_argument(
'-t', '--triple-limit',
type=int,
default=30,
help=f'Triple query limit, per query (default: 30)'
help=f'Default triple query limit, per query (default: 30)'
)
parser.add_argument(
'-u', '--max-subgraph-size',
type=int,
default=3000,
help=f'Max subgraph size (default: 3000)'
default=150,
help=f'Default max subgraph size (default: 150)'
)
parser.add_argument(
'-a', '--max-path-length',
type=int,
default=2,
help=f'Default max path length (default: 2)'
)
parser.add_argument(