Make embedding & text-completion queues configurable in graph-rag

Author: Cyber MacGeddon
Date: 2024-08-05 22:37:11 +01:00
parent 963944f4aa
commit 103251dbf2
2 changed files with 70 additions and 0 deletions

First changed file (the GraphRag class):

```diff
@@ -4,6 +4,8 @@ from trustgraph.triple_vectors import TripleVectors
 from trustgraph.trustgraph import TrustGraph
 from trustgraph.llm_client import LlmClient
 from trustgraph.embeddings_client import EmbeddingsClient
+from . schema import text_completion_request_queue
+from . schema import text_completion_response_queue
 
 LABEL="http://www.w3.org/2000/01/rdf-schema#label"
 DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
@@ -15,6 +17,10 @@ class GraphRag:
         graph_hosts=None,
         pulsar_host="pulsar://pulsar:6650",
         vector_store="http://milvus:19530",
+        completion_request_queue=None,
+        completion_response_queue=None,
+        emb_request_queue=None,
+        emb_response_queue=None,
         verbose=False,
         entity_limit=50,
         triple_limit=30,
@@ -24,6 +30,18 @@ class GraphRag:
         self.verbose=verbose
 
+        if completion_request_queue == None:
+            completion_request_queue = text_completion_request_queue
+
+        if completion_response_queue == None:
+            completion_response_queue = text_completion_response_queue
+
+        if emb_request_queue == None:
+            emb_request_queue = embeddings_request_queue
+
+        if emb_response_queue == None:
+            emb_response_queue = embeddings_response_queue
+
         if graph_hosts == None:
             graph_hosts = ["cassandra"]
@@ -34,6 +52,8 @@ class GraphRag:
         self.embeddings = EmbeddingsClient(
             pulsar_host=pulsar_host,
+            input_queue=emb_request_queue,
+            output_queue=emb_response_queue,
             subscriber=module + "-emb",
         )
@@ -47,6 +67,8 @@ class GraphRag:
         self.llm = LlmClient(
             pulsar_host=pulsar_host,
+            input_queue=completion_request_queue,
+            output_queue=completion_response_queue,
             subscriber=module + "-llm",
         )
```
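With this change, callers can route the graph-rag pipeline's text-completion and embeddings traffic onto specific Pulsar queues instead of the hard-wired defaults. Below is a minimal usage sketch, assuming the class is importable as trustgraph.graph_rag.GraphRag and using illustrative queue names; passing None (or omitting the arguments) falls back to the shared names from the schema module. The embeddings defaults, embeddings_request_queue and embeddings_response_queue, would also need to come from that schema module, although that import is not shown in the hunk above.

```python
# Hedged usage sketch: the import path and queue names are illustrative,
# not taken verbatim from the repository.
from trustgraph.graph_rag import GraphRag   # assumed module path

rag = GraphRag(
    pulsar_host="pulsar://pulsar:6650",
    graph_hosts=["cassandra"],
    vector_store="http://milvus:19530",
    completion_request_queue="tc-request-custom",     # illustrative override
    completion_response_queue="tc-response-custom",   # illustrative override
    emb_request_queue="emb-request-custom",           # illustrative override
    emb_response_queue="emb-response-custom",         # illustrative override
)
# Omitting the four queue arguments (or passing None) keeps the previous
# behaviour: the defaults from the schema module are used.
```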

Second changed file (the graph-rag query processor):

```diff
@@ -6,6 +6,10 @@ Input is query, output is response.
 from ... schema import GraphRagQuery, GraphRagResponse
 from ... schema import graph_rag_request_queue, graph_rag_response_queue
+from ... schema import text_completion_request_queue
+from ... schema import text_completion_response_queue
+from ... schema import embeddings_request_queue
+from ... schema import embeddings_response_queue
 from ... log_level import LogLevel
 from ... graph_rag import GraphRag
 from ... base import ConsumerProducer
@@ -30,6 +34,18 @@ class Processor(ConsumerProducer):
         entity_limit = params.get("entity_limit", 50)
         triple_limit = params.get("triple_limit", 30)
         max_subgraph_size = params.get("max_subgraph_size", 3000)
+        tc_request_queue = params.get(
+            "text_completion_request_queue", text_completion_request_queue
+        )
+        tc_response_queue = params.get(
+            "text_completion_response_queue", text_completion_response_queue
+        )
+        emb_request_queue = params.get(
+            "embeddings_request_queue", embeddings_request_queue
+        )
+        emb_response_queue = params.get(
+            "embeddings_response_queue", embeddings_response_queue
+        )
 
         super(Processor, self).__init__(
             **params | {
@@ -41,12 +57,20 @@ class Processor(ConsumerProducer):
                 "entity_limit": entity_limit,
                 "triple_limit": triple_limit,
                 "max_subgraph_size": max_subgraph_size,
+                "text_completion_request_queue": tc_request_queue,
+                "text_completion_response_queue": tc_response_queue,
+                "embeddings_request_queue": emb_request_queue,
+                "embeddings_response_queue": emb_response_queue,
             }
         )
 
         self.rag = GraphRag(
             pulsar_host=self.pulsar_host,
             graph_hosts=graph_hosts.split(","),
+            completion_request_queue=tc_request_queue,
+            completion_response_queue=tc_response_queue,
+            emb_request_queue=emb_request_queue,
+            emb_response_queue=emb_response_queue,
             vector_store=vector_store,
             verbose=True,
             entity_limit=entity_limit,
```
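One detail in the hunk above is the dict-union operator used to hand the resolved settings to the base class: `**params | {...}`. A small self-contained sketch of that behaviour, with illustrative values:

```python
# `d1 | d2` (Python 3.9+) returns a new dict; on duplicate keys the
# right-hand side wins, so the resolved queue names override (or add to)
# whatever the caller put in params before it reaches the base class.
params = {"entity_limit": 20}                    # illustrative caller params
tc_request_queue = "text-completion-request"     # illustrative resolved value

merged = params | {"text_completion_request_queue": tc_request_queue}
print(merged)
# {'entity_limit': 20, 'text_completion_request_queue': 'text-completion-request'}
```

The final hunk in this file adds the matching command-line flags: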
```diff
@@ -114,6 +138,30 @@ class Processor(ConsumerProducer):
             help=f'Max subgraph size (default: 3000)'
         )
 
+        parser.add_argument(
+            '--text-completion-request-queue',
+            default=text_completion_request_queue,
+            help=f'Text completion request queue (default: {text_completion_request_queue})',
+        )
+
+        parser.add_argument(
+            '--text-completion-response-queue',
+            default=text_completion_response_queue,
+            help=f'Text completion response queue (default: {text_completion_response_queue})',
+        )
+
+        parser.add_argument(
+            '--embeddings-request-queue',
+            default=embeddings_request_queue,
+            help=f'Embeddings request queue (default: {embeddings_request_queue})',
+        )
+
+        parser.add_argument(
+            '--embeddings-response-queue',
+            default=embeddings_response_queue,
+            help=f'Embeddings response queue (default: {embeddings_response_queue})',
+        )
+
 def run():
     Processor.start(module, __doc__)
```
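The new flags mirror the params handled earlier, so a deployment can repoint the processor at non-default queues at launch time without code changes. A self-contained sketch of the defaulting pattern, using illustrative stand-ins for the schema constants:

```python
# Hedged sketch of the argparse pattern added above: each flag defaults to a
# shared constant, so omitting the flag preserves the existing queue names.
# The constant values here are illustrative; the real ones live in the
# package's schema module.
import argparse

text_completion_request_queue = "text-completion-request"   # illustrative default
embeddings_request_queue = "embeddings-request"             # illustrative default

parser = argparse.ArgumentParser()
parser.add_argument(
    "--text-completion-request-queue",
    default=text_completion_request_queue,
    help=f"Text completion request queue (default: {text_completion_request_queue})",
)
parser.add_argument(
    "--embeddings-request-queue",
    default=embeddings_request_queue,
    help=f"Embeddings request queue (default: {embeddings_request_queue})",
)

args = parser.parse_args(["--embeddings-request-queue", "emb-request-custom"])
print(args.text_completion_request_queue)   # schema default, flag omitted
print(args.embeddings_request_queue)        # the override
```

Because every flag defaults to the shared constant, omitting it keeps the current queue name, so existing deployments are unaffected.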