diff --git a/templates/components/document-rag.jsonnet b/templates/components/document-rag.jsonnet index ec125ed5..177e23e0 100644 --- a/templates/components/document-rag.jsonnet +++ b/templates/components/document-rag.jsonnet @@ -5,6 +5,8 @@ local prompts = import "prompts/mixtral.jsonnet"; { + "document-rag-doc-limit":: 20, + "document-rag" +: { create:: function(engine) diff --git a/templates/components/graph-rag.jsonnet b/templates/components/graph-rag.jsonnet index 27035b35..10a8b4ce 100644 --- a/templates/components/graph-rag.jsonnet +++ b/templates/components/graph-rag.jsonnet @@ -4,9 +4,9 @@ local url = import "values/url.jsonnet"; { - "graph-rag-entity-limit":: 50, - "graph-rag-triple-limit":: 30, - "graph-rag-max-subgraph-size":: 3000, + "graph-rag-entity-limit":: 20, + "graph-rag-triple-limit":: 10, + "graph-rag-max-subgraph-size":: 1000, "kg-extract-definitions" +: { diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index 24207f32..4c72d3ca 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -102,11 +102,21 @@ class Api: except: raise ProtocolException(f"Response not formatted correctly") - def graph_rag(self, question): + def graph_rag( + self, question, user="trustgraph", collection="default", + entity_limit=50, triple_limit=30, max_subgraph_size=150, + max_path_length=2, + ): # The input consists of a question input = { - "query": question + "query": question, + "user": user, + "collection": collection, + "entity-limit": entity_limit, + "triple-limit": triple_limit, + "max-subgraph-size": max_subgraph_size, + "max-path-length": max_path_length, } url = f"{self.url}graph-rag" @@ -131,11 +141,17 @@ class Api: except: raise ProtocolException(f"Response not formatted correctly") - def document_rag(self, question): + def document_rag( + self, question, user="trustgraph", collection="default", + doc_limit=10, + ): # The input consists of a question input = { - "query": question + "query": question, + "user": user, + "collection": collection, + "doc-limit": doc_limit, } url = f"{self.url}document-rag" diff --git a/trustgraph-base/trustgraph/schema/retrieval.py b/trustgraph-base/trustgraph/schema/retrieval.py index 9c4361a1..caeb8e67 100644 --- a/trustgraph-base/trustgraph/schema/retrieval.py +++ b/trustgraph-base/trustgraph/schema/retrieval.py @@ -11,6 +11,10 @@ class GraphRagQuery(Record): query = String() user = String() collection = String() + entity_limit = Integer() + triple_limit = Integer() + max_subgraph_size = Integer() + max_path_length = Integer() class GraphRagResponse(Record): error = Error() @@ -31,6 +35,7 @@ class DocumentRagQuery(Record): query = String() user = String() collection = String() + doc_limit = Integer() class DocumentRagResponse(Record): error = Error() diff --git a/trustgraph-cli/scripts/tg-invoke-document-rag b/trustgraph-cli/scripts/tg-invoke-document-rag index a3fc1958..759d4200 100755 --- a/trustgraph-cli/scripts/tg-invoke-document-rag +++ b/trustgraph-cli/scripts/tg-invoke-document-rag @@ -11,13 +11,16 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' +default_doc_limit = 10 -def question(url, question, user, collection): +def question(url, question, user, collection, doc_limit): rag = Api(url) -# user=user, collection=collection, - resp = rag.document_rag(question=question) + resp = rag.document_rag( + question=question, user=user, collection=collection, + doc_limit=doc_limit, + ) print(resp) @@ -41,7 +44,7 @@ def main(): # ) parser.add_argument( - '-q', '--query', + '-q', '--question', required=True, help=f'Question to answer', ) @@ -58,6 +61,12 @@ def main(): help=f'Collection ID (default: {default_collection})' ) + parser.add_argument( + '-d', '--doc-limit', + default=default_doc_limit, + help=f'Document limit (default: {default_doc_limit})' + ) + args = parser.parse_args() try: @@ -67,6 +76,7 @@ def main(): question=args.question, user=args.user, collection=args.collection, + doc_limit=args.doc_limit, ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-invoke-graph-rag b/trustgraph-cli/scripts/tg-invoke-graph-rag index 50de5b74..5bbe5f59 100755 --- a/trustgraph-cli/scripts/tg-invoke-graph-rag +++ b/trustgraph-cli/scripts/tg-invoke-graph-rag @@ -11,13 +11,24 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' +default_entity_limit = 50 +default_triple_limit = 30 +default_max_subgraph_size = 150 +default_max_path_length = 2 -def question(url, question, user, collection): +def question( + url, question, user, collection, entity_limit, triple_limit, + max_subgraph_size, max_path_length +): rag = Api(url) -# user=user, collection=collection, - resp = rag.graph_rag(question=question) + resp = rag.graph_rag( + question=question, user=user, collection=collection, + entity_limit=entity_limit, triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length + ) print(resp) @@ -52,6 +63,30 @@ def main(): help=f'Collection ID (default: {default_collection})' ) + parser.add_argument( + '-e', '--entity-limit', + default=default_entity_limit, + help=f'Entity limit (default: {default_entity_limit})' + ) + + parser.add_argument( + '-t', '--triple-limit', + default=default_triple_limit, + help=f'Triple limit (default: {default_triple_limit})' + ) + + parser.add_argument( + '-s', '--max-subgraph-size', + default=default_max_subgraph_size, + help=f'Max subgraph size (default: {default_max_subgraph_size})' + ) + + parser.add_argument( + '-p', '--max-path-length', + default=default_max_path_length, + help=f'Max path length (default: {default_max_path_length})' + ) + args = parser.parse_args() try: @@ -61,6 +96,10 @@ def main(): question=args.question, user=args.user, collection=args.collection, + entity_limit=args.entity_limit, + triple_limit=args.triple_limit, + max_subgraph_size=args.max_subgraph_size, + max_path_length=args.max_path_length, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/document_rag.py b/trustgraph-flow/trustgraph/document_rag.py index f4676b15..4fc4850a 100644 --- a/trustgraph-flow/trustgraph/document_rag.py +++ b/trustgraph-flow/trustgraph/document_rag.py @@ -18,11 +18,15 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition" class Query: - def __init__(self, rag, user, collection, verbose): + def __init__( + self, rag, user, collection, verbose, + doc_limit=20 + ): self.rag = rag self.user = user self.collection = collection self.verbose = verbose + self.doc_limit = doc_limit def get_vector(self, query): @@ -44,7 +48,7 @@ class Query: print("Get entities...", flush=True) docs = self.rag.de_client.request( - vectors, limit=self.rag.doc_limit + vectors, limit=self.doc_limit ) if self.verbose: @@ -93,9 +97,6 @@ class DocumentRag: if self.verbose: print("Initialising...", flush=True) - # FIXME: Configurable - self.doc_limit = 20 - self.de_client = DocumentEmbeddingsClient( pulsar_host=pulsar_host, subscriber=module + "-de", @@ -123,13 +124,17 @@ class DocumentRag: if self.verbose: print("Initialised", flush=True) - def query(self, query, user="trustgraph", collection="default"): + def query( + self, query, user="trustgraph", collection="default", + doc_limit=20, + ): if self.verbose: print("Construct prompt...", flush=True) q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose + rag=self, user=user, collection=collection, verbose=self.verbose, + doc_limit=doc_limit ) docs = q.get_docs(query) diff --git a/trustgraph-flow/trustgraph/gateway/document_rag.py b/trustgraph-flow/trustgraph/gateway/document_rag.py index e5749197..94d8f788 100644 --- a/trustgraph-flow/trustgraph/gateway/document_rag.py +++ b/trustgraph-flow/trustgraph/gateway/document_rag.py @@ -23,6 +23,7 @@ class DocumentRagRequestor(ServiceRequestor): query=body["query"], user=body.get("user", "trustgraph"), collection=body.get("collection", "default"), + doc_limit=int(body.get("doc-limit", 20)), ) def from_response(self, message): diff --git a/trustgraph-flow/trustgraph/gateway/graph_rag.py b/trustgraph-flow/trustgraph/gateway/graph_rag.py index 59a4fb90..b2b69758 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_rag.py +++ b/trustgraph-flow/trustgraph/gateway/graph_rag.py @@ -23,6 +23,10 @@ class GraphRagRequestor(ServiceRequestor): query=body["query"], user=body.get("user", "trustgraph"), collection=body.get("collection", "default"), + entity_limit=int(body.get("entity-limit", 50)), + triple_limit=int(body.get("triple-limit", 30)), + max_subgraph_size=int(body.get("max-subgraph-size", 1000)), + max_path_length=int(body.get("max-path-length", 2)), ) def from_response(self, message): diff --git a/trustgraph-flow/trustgraph/graph_rag.py b/trustgraph-flow/trustgraph/graph_rag.py index a1f9909e..6a4e11c5 100644 --- a/trustgraph-flow/trustgraph/graph_rag.py +++ b/trustgraph-flow/trustgraph/graph_rag.py @@ -20,11 +20,19 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition" class Query: - def __init__(self, rag, user, collection, verbose): + def __init__( + self, rag, user, collection, verbose, + entity_limit=50, triple_limit=30, max_subgraph_size=1000, + max_path_length=2, + ): self.rag = rag self.user = user self.collection = collection self.verbose = verbose + self.entity_limit = entity_limit + self.triple_limit = triple_limit + self.max_subgraph_size = max_subgraph_size + self.max_path_length = max_path_length def get_vector(self, query): @@ -47,7 +55,7 @@ class Query: entities = self.rag.ge_client.request( user=self.user, collection=self.collection, - vectors=vectors, limit=self.rag.entity_limit, + vectors=vectors, limit=self.entity_limit, ) entities = [ @@ -79,62 +87,67 @@ class Query: self.rag.label_cache[e] = res[0].o.value return self.rag.label_cache[e] + def follow_edges(self, ent, subgraph, path_length): + + # Not needed? + if path_length <= 0: + return + + # Stop spanning around if the subgraph is already maxed out + if len(subgraph) >= self.max_subgraph_size: + return + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=ent, p=None, o=None, + limit=self.triple_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + if path_length > 1: + self.follow_edges(triple.o.value, subgraph, path_length-1) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=ent, o=None, + limit=self.triple_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=None, o=ent, + limit=self.triple_limit, + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + if path_length > 1: + self.follow_edges(triple.s.value, subgraph, path_length-1) + def get_subgraph(self, query): entities = self.get_entities(query) - subgraph = set() - if self.verbose: print("Get subgraph...", flush=True) - for e in entities: + subgraph = set() - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=e, p=None, o=None, - limit=self.rag.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=None, p=e, o=None, - limit=self.rag.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=None, p=None, o=e, - limit=self.rag.query_limit, - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) + for ent in entities: + self.follow_edges(ent, subgraph, self.max_path_length) subgraph = list(subgraph) - subgraph = subgraph[0:self.rag.max_subgraph_size] - - if self.verbose: - print("Subgraph:", flush=True) - for edge in subgraph: - print(" ", str(edge), flush=True) - - if self.verbose: - print("Done.", flush=True) - return subgraph def get_labelgraph(self, query): @@ -154,6 +167,16 @@ class Query: sg2.append((s, p, o)) + sg2 = sg2[0:self.max_subgraph_size] + + if self.verbose: + print("Subgraph:", flush=True) + for edge in sg2: + print(" ", str(edge), flush=True) + + if self.verbose: + print("Done.", flush=True) + return sg2 class GraphRag: @@ -171,9 +194,6 @@ class GraphRag: tpl_request_queue=None, tpl_response_queue=None, verbose=False, - entity_limit=50, - triple_limit=30, - max_subgraph_size=3000, module="test", ): @@ -230,10 +250,6 @@ class GraphRag: subscriber=module + "-emb", ) - self.entity_limit=entity_limit - self.query_limit=triple_limit - self.max_subgraph_size=max_subgraph_size - self.label_cache = {} self.prompt = PromptClient( @@ -247,13 +263,20 @@ class GraphRag: if self.verbose: print("Initialised", flush=True) - def query(self, query, user="trustgraph", collection="default"): + def query( + self, query, user="trustgraph", collection="default", + entity_limit=50, triple_limit=30, max_subgraph_size=1000, + max_path_length=2, + ): if self.verbose: print("Construct prompt...", flush=True) q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose + rag=self, user=user, collection=collection, verbose=self.verbose, + entity_limit=entity_limit, triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, ) kg = q.get_labelgraph(query) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 23b46129..bb8b008e 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -50,6 +50,8 @@ class Processor(ConsumerProducer): document_embeddings_response_queue ) + doc_limit = params.get("doc_limit", 10) + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -79,6 +81,8 @@ class Processor(ConsumerProducer): module=module, ) + self.doc_limit = doc_limit + async def handle(self, msg): try: @@ -90,7 +94,12 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - response = self.rag.query(v.query) + if v.doc_limit: + doc_limit = v.doc_limit + else: + doc_limit = self.doc_limit + + response = self.rag.query(v.query, doc_limit=doc_limit) print("Send response...", flush=True) r = DocumentRagResponse(response = response, error=None) @@ -124,6 +133,13 @@ class Processor(ConsumerProducer): default_output_queue, ) + parser.add_argument( + '-d', '--doc-limit', + type=int, + default=20, + help=f'Default document fetch limit (default: 10)' + ) + parser.add_argument( '--prompt-request-queue', default=prompt_request_queue, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 4095d0c3..2c45ecd4 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -31,9 +31,7 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) - entity_limit = params.get("entity_limit", 50) - triple_limit = params.get("triple_limit", 30) - max_subgraph_size = params.get("max_subgraph_size", 3000) + pr_request_queue = params.get( "prompt_request_queue", prompt_request_queue ) @@ -59,6 +57,11 @@ class Processor(ConsumerProducer): "triples_response_queue", triples_response_queue ) + entity_limit = params.get("entity_limit", 50) + triple_limit = params.get("triple_limit", 30) + max_subgraph_size = params.get("max_subgraph_size", 150) + max_path_length = params.get("max_path_length", 2) + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -92,12 +95,14 @@ class Processor(ConsumerProducer): tpl_request_queue=triples_request_queue, tpl_response_queue=triples_response_queue, verbose=True, - entity_limit=entity_limit, - triple_limit=triple_limit, - max_subgraph_size=max_subgraph_size, module=module, ) + self.default_entity_limit = entity_limit + self.default_triple_limit = triple_limit + self.default_max_subgraph_size = max_subgraph_size + self.default_max_path_length = max_path_length + async def handle(self, msg): try: @@ -106,15 +111,38 @@ class Processor(ConsumerProducer): # Sender-produced ID id = msg.properties()["id"] - + print(f"Handling input {id}...", flush=True) + if v.entity_limit: + entity_limit = v.entity_limit + else: + entity_limit = self.default_entity_limit + + if v.triple_limit: + triple_limit = v.triple_limit + else: + triple_limit = self.default_triple_limit + + if v.max_subgraph_size: + max_subgraph_size = v.max_subgraph_size + else: + max_subgraph_size = self.default_max_subgraph_size + + if v.max_path_length: + max_path_length = v.max_path_length + else: + max_path_length = self.default_max_path_length + response = self.rag.query( - query=v.query, user=v.user, collection=v.collection + query=v.query, user=v.user, collection=v.collection, + entity_limit=entity_limit, triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, ) print("Send response...", flush=True) - r = GraphRagResponse(response = response, error=None) + r = GraphRagResponse(response=response, error=None) await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -149,21 +177,28 @@ class Processor(ConsumerProducer): '-e', '--entity-limit', type=int, default=50, - help=f'Entity vector fetch limit (default: 50)' + help=f'Default entity vector fetch limit (default: 50)' ) parser.add_argument( '-t', '--triple-limit', type=int, default=30, - help=f'Triple query limit, per query (default: 30)' + help=f'Default triple query limit, per query (default: 30)' ) parser.add_argument( '-u', '--max-subgraph-size', type=int, - default=3000, - help=f'Max subgraph size (default: 3000)' + default=150, + help=f'Default max subgraph size (default: 150)' + ) + + parser.add_argument( + '-a', '--max-path-length', + type=int, + default=2, + help=f'Default max path length (default: 2)' ) parser.add_argument(