mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-04 12:52:36 +02:00
Feature/rag parameters (#311)
* Change document-rag and graph-rag processing so that the user can specify parameters. This changes the Pulsar services, the Pulsar message schemas, the gateway and the command-line tools; the user-visible change is the new parameters on the command-line tools.
* Fix bugs; graph-rag working
* Get subgraph truncation in the right place
* Graph RAG and document RAG working and configurable
* Multi-hop path traversal for GraphRAG
* Add safety valve for path_size (the max path length) being set too high
This commit is contained in:
parent
f1559c5944
commit
ef845d6c9b
12 changed files with 247 additions and 91 deletions
|
|
@ -18,11 +18,15 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
|||
|
||||
class Query:
|
||||
|
||||
def __init__(self, rag, user, collection, verbose):
|
||||
def __init__(
|
||||
self, rag, user, collection, verbose,
|
||||
doc_limit=20
|
||||
):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
|
|
@ -44,7 +48,7 @@ class Query:
|
|||
print("Get entities...", flush=True)
|
||||
|
||||
docs = self.rag.de_client.request(
|
||||
vectors, limit=self.rag.doc_limit
|
||||
vectors, limit=self.doc_limit
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -93,9 +97,6 @@ class DocumentRag:
|
|||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
# FIXME: Configurable
|
||||
self.doc_limit = 20
|
||||
|
||||
self.de_client = DocumentEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-de",
|
||||
|
|
@ -123,13 +124,17 @@ class DocumentRag:
|
|||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def query(self, query, user="trustgraph", collection="default"):
|
||||
def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20,
|
||||
):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
docs = q.get_docs(query)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class DocumentRagRequestor(ServiceRequestor):
|
|||
query=body["query"],
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
doc_limit=int(body.get("doc-limit", 20)),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ class GraphRagRequestor(ServiceRequestor):
|
|||
query=body["query"],
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
entity_limit=int(body.get("entity-limit", 50)),
|
||||
triple_limit=int(body.get("triple-limit", 30)),
|
||||
max_subgraph_size=int(body.get("max-subgraph-size", 1000)),
|
||||
max_path_length=int(body.get("max-path-length", 2)),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
|
|
|
|||
|
|
@ -20,11 +20,19 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
|||
|
||||
class Query:
|
||||
|
||||
def __init__(self, rag, user, collection, verbose):
|
||||
def __init__(
|
||||
self, rag, user, collection, verbose,
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2,
|
||||
):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.entity_limit = entity_limit
|
||||
self.triple_limit = triple_limit
|
||||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
|
|
@ -47,7 +55,7 @@ class Query:
|
|||
|
||||
entities = self.rag.ge_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
vectors=vectors, limit=self.rag.entity_limit,
|
||||
vectors=vectors, limit=self.entity_limit,
|
||||
)
|
||||
|
||||
entities = [
|
||||
|
|
@ -79,62 +87,67 @@ class Query:
|
|||
self.rag.label_cache[e] = res[0].o.value
|
||||
return self.rag.label_cache[e]
|
||||
|
||||
def follow_edges(self, ent, subgraph, path_length):
|
||||
|
||||
# Not needed?
|
||||
if path_length <= 0:
|
||||
return
|
||||
|
||||
# Stop spanning around if the subgraph is already maxed out
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=ent, p=None, o=None,
|
||||
limit=self.triple_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
if path_length > 1:
|
||||
self.follow_edges(triple.o.value, subgraph, path_length-1)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=ent, o=None,
|
||||
limit=self.triple_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=None, o=ent,
|
||||
limit=self.triple_limit,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
if path_length > 1:
|
||||
self.follow_edges(triple.s.value, subgraph, path_length-1)
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
subgraph = set()
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=e, p=None, o=None,
|
||||
limit=self.rag.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=e, o=None,
|
||||
limit=self.rag.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=None, o=e,
|
||||
limit=self.rag.query_limit,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
for ent in entities:
|
||||
self.follow_edges(ent, subgraph, self.max_path_length)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.rag.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
|
@ -154,6 +167,16 @@ class Query:
|
|||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
sg2 = sg2[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in sg2:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return sg2
|
||||
|
||||
class GraphRag:
|
||||
|
|
@ -171,9 +194,6 @@ class GraphRag:
|
|||
tpl_request_queue=None,
|
||||
tpl_response_queue=None,
|
||||
verbose=False,
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=3000,
|
||||
module="test",
|
||||
):
|
||||
|
||||
|
|
@ -230,10 +250,6 @@ class GraphRag:
|
|||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.entity_limit=entity_limit
|
||||
self.query_limit=triple_limit
|
||||
self.max_subgraph_size=max_subgraph_size
|
||||
|
||||
self.label_cache = {}
|
||||
|
||||
self.prompt = PromptClient(
|
||||
|
|
@ -247,13 +263,20 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def query(self, query, user="trustgraph", collection="default"):
|
||||
def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2,
|
||||
):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
entity_limit=entity_limit, triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
)
|
||||
|
||||
kg = q.get_labelgraph(query)
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ class Processor(ConsumerProducer):
|
|||
document_embeddings_response_queue
|
||||
)
|
||||
|
||||
doc_limit = params.get("doc_limit", 10)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
|
|
@ -79,6 +81,8 @@ class Processor(ConsumerProducer):
|
|||
module=module,
|
||||
)
|
||||
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
|
@ -90,7 +94,12 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
response = self.rag.query(v.query)
|
||||
if v.doc_limit:
|
||||
doc_limit = v.doc_limit
|
||||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
response = self.rag.query(v.query, doc_limit=doc_limit)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentRagResponse(response = response, error=None)
|
||||
|
|
@ -124,6 +133,13 @@ class Processor(ConsumerProducer):
|
|||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--doc-limit',
|
||||
type=int,
|
||||
default=20,
|
||||
help=f'Default document fetch limit (default: 10)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
|
|
|
|||
|
|
@ -31,9 +31,7 @@ class Processor(ConsumerProducer):
|
|||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
entity_limit = params.get("entity_limit", 50)
|
||||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 3000)
|
||||
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
|
|
@ -59,6 +57,11 @@ class Processor(ConsumerProducer):
|
|||
"triples_response_queue", triples_response_queue
|
||||
)
|
||||
|
||||
entity_limit = params.get("entity_limit", 50)
|
||||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
|
|
@ -92,12 +95,14 @@ class Processor(ConsumerProducer):
|
|||
tpl_request_queue=triples_request_queue,
|
||||
tpl_response_queue=triples_response_queue,
|
||||
verbose=True,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
module=module,
|
||||
)
|
||||
|
||||
self.default_entity_limit = entity_limit
|
||||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
|
||||
async def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
|
@ -106,15 +111,38 @@ class Processor(ConsumerProducer):
|
|||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
if v.entity_limit:
|
||||
entity_limit = v.entity_limit
|
||||
else:
|
||||
entity_limit = self.default_entity_limit
|
||||
|
||||
if v.triple_limit:
|
||||
triple_limit = v.triple_limit
|
||||
else:
|
||||
triple_limit = self.default_triple_limit
|
||||
|
||||
if v.max_subgraph_size:
|
||||
max_subgraph_size = v.max_subgraph_size
|
||||
else:
|
||||
max_subgraph_size = self.default_max_subgraph_size
|
||||
|
||||
if v.max_path_length:
|
||||
max_path_length = v.max_path_length
|
||||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
response = self.rag.query(
|
||||
query=v.query, user=v.user, collection=v.collection
|
||||
query=v.query, user=v.user, collection=v.collection,
|
||||
entity_limit=entity_limit, triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphRagResponse(response = response, error=None)
|
||||
r = GraphRagResponse(response=response, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
|
@ -149,21 +177,28 @@ class Processor(ConsumerProducer):
|
|||
'-e', '--entity-limit',
|
||||
type=int,
|
||||
default=50,
|
||||
help=f'Entity vector fetch limit (default: 50)'
|
||||
help=f'Default entity vector fetch limit (default: 50)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--triple-limit',
|
||||
type=int,
|
||||
default=30,
|
||||
help=f'Triple query limit, per query (default: 30)'
|
||||
help=f'Default triple query limit, per query (default: 30)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--max-subgraph-size',
|
||||
type=int,
|
||||
default=3000,
|
||||
help=f'Max subgraph size (default: 3000)'
|
||||
default=150,
|
||||
help=f'Default max subgraph size (default: 150)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-a', '--max-path-length',
|
||||
type=int,
|
||||
default=2,
|
||||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue