diff --git a/trustgraph/graph_rag.py b/trustgraph/graph_rag.py index 0b7ccbdc..967bd68a 100644 --- a/trustgraph/graph_rag.py +++ b/trustgraph/graph_rag.py @@ -4,6 +4,10 @@ from trustgraph.triple_vectors import TripleVectors from trustgraph.trustgraph import TrustGraph from trustgraph.llm_client import LlmClient from trustgraph.embeddings_client import EmbeddingsClient +from . schema import text_completion_request_queue +from . schema import text_completion_response_queue +from . schema import embeddings_request_queue +from . schema import embeddings_response_queue LABEL="http://www.w3.org/2000/01/rdf-schema#label" DEFINITION="http://www.w3.org/2004/02/skos/core#definition" @@ -15,6 +19,10 @@ class GraphRag: graph_hosts=None, pulsar_host="pulsar://pulsar:6650", vector_store="http://milvus:19530", + completion_request_queue=None, + completion_response_queue=None, + emb_request_queue=None, + emb_response_queue=None, verbose=False, entity_limit=50, triple_limit=30, @@ -24,6 +32,18 @@ class GraphRag: self.verbose=verbose + if completion_request_queue == None: + completion_request_queue = text_completion_request_queue + + if completion_response_queue == None: + completion_response_queue = text_completion_response_queue + + if emb_request_queue == None: + emb_request_queue = embeddings_request_queue + + if emb_response_queue == None: + emb_response_queue = embeddings_response_queue + if graph_hosts == None: graph_hosts = ["cassandra"] @@ -34,6 +54,8 @@ class GraphRag: self.embeddings = EmbeddingsClient( pulsar_host=pulsar_host, + input_queue=emb_request_queue, + output_queue=emb_response_queue, subscriber=module + "-emb", ) @@ -47,6 +69,8 @@ class GraphRag: self.llm = LlmClient( pulsar_host=pulsar_host, + input_queue=completion_request_queue, + output_queue=completion_response_queue, subscriber=module + "-llm", ) diff --git a/trustgraph/retrieval/graph_rag/rag.py b/trustgraph/retrieval/graph_rag/rag.py index 80cbfafb..e43aa0d8 100755 --- a/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph/retrieval/graph_rag/rag.py @@ -6,6 +6,10 @@ Input is query, output is
response. from ... schema import GraphRagQuery, GraphRagResponse from ... schema import graph_rag_request_queue, graph_rag_response_queue +from ... schema import text_completion_request_queue +from ... schema import text_completion_response_queue +from ... schema import embeddings_request_queue +from ... schema import embeddings_response_queue from ... log_level import LogLevel from ... graph_rag import GraphRag from ... base import ConsumerProducer @@ -30,6 +34,18 @@ class Processor(ConsumerProducer): entity_limit = params.get("entity_limit", 50) triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 3000) + tc_request_queue = params.get( + "text_completion_request_queue", text_completion_request_queue + ) + tc_response_queue = params.get( + "text_completion_response_queue", text_completion_response_queue + ) + emb_request_queue = params.get( + "embeddings_request_queue", embeddings_request_queue + ) + emb_response_queue = params.get( + "embeddings_response_queue", embeddings_response_queue + ) super(Processor, self).__init__( **params | { @@ -41,12 +57,20 @@ class Processor(ConsumerProducer): "entity_limit": entity_limit, "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, + "text_completion_request_queue": tc_request_queue, + "text_completion_response_queue": tc_response_queue, + "embeddings_request_queue": emb_request_queue, + "embeddings_response_queue": emb_response_queue, } ) self.rag = GraphRag( pulsar_host=self.pulsar_host, graph_hosts=graph_hosts.split(","), + completion_request_queue=tc_request_queue, + completion_response_queue=tc_response_queue, + emb_request_queue=emb_request_queue, + emb_response_queue=emb_response_queue, vector_store=vector_store, verbose=True, entity_limit=entity_limit, @@ -114,6 +138,30 @@ class Processor(ConsumerProducer): help=f'Max subgraph size (default: 3000)' ) + parser.add_argument( + '--text-completion-request-queue', + default=text_completion_request_queue, + 
help=f'Text completion request queue (default: {text_completion_request_queue})', + ) + + parser.add_argument( + '--text-completion-response-queue', + default=text_completion_response_queue, + help=f'Text completion response queue (default: {text_completion_response_queue})', + ) + + parser.add_argument( + '--embeddings-request-queue', + default=embeddings_request_queue, + help=f'Embeddings request queue (default: {embeddings_request_queue})', + ) + + parser.add_argument( + '--embeddings-response-queue', + default=embeddings_response_queue, + help=f'Embeddings response queue (default: {embeddings_response_queue})', + ) + def run(): Processor.start(module, __doc__)