Plumbed into generic

This commit is contained in:
Cyber MacGeddon 2024-08-13 16:58:16 +01:00
parent a143e4ea4d
commit 428a869a93
11 changed files with 37 additions and 68 deletions

View file

@ -8,8 +8,8 @@ query = " ".join(sys.argv[1:])
gr = GraphRag(
verbose=True,
pulsar_host="pulsar://localhost:6650",
completion_request_queue="non-persistent://tg/request/text-completion-rag",
completion_response_queue="non-persistent://tg/response/text-completion-rag-response",
pr_request_queue="non-persistent://tg/request/prompt",
pr_response_queue="non-persistent://tg/response/prompt-response",
)
if query == "":

View file

@ -69,15 +69,16 @@ setuptools.setup(
"scripts/load-triples",
"scripts/loader",
"scripts/pdf-decoder",
"scripts/prompt-generic",
"scripts/query",
"scripts/run-processing",
"scripts/text-completion-azure",
"scripts/text-completion-bedrock",
"scripts/text-completion-claude",
"scripts/text-completion-cohere",
"scripts/text-completion-ollama",
"scripts/text-completion-openai",
"scripts/text-completion-vertexai",
"scripts/text-completion-cohere",
"scripts/triples-dump-parquet",
"scripts/triples-query-cassandra",
"scripts/triples-write-cassandra",

View file

@ -1,13 +1,13 @@
from . graph_embeddings_client import GraphEmbeddingsClient
from . triples_query_client import TriplesQueryClient
from . llm_client import LlmClient
from . embeddings_client import EmbeddingsClient
from . prompt_client import PromptClient
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import text_completion_request_queue
from . schema import text_completion_response_queue
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import graph_embeddings_request_queue
@ -23,8 +23,8 @@ class GraphRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
completion_request_queue=None,
completion_response_queue=None,
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
ge_request_queue=None,
@ -40,11 +40,11 @@ class GraphRag:
self.verbose=verbose
if completion_request_queue is None:
completion_request_queue = text_completion_request_queue
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if completion_response_queue is None:
completion_response_queue = text_completion_response_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
@ -94,11 +94,11 @@ class GraphRag:
self.label_cache = {}
self.llm = LlmClient(
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=completion_request_queue,
output_queue=completion_response_queue,
subscriber=module + "-llm",
input_queue=prompt_request_queue,
output_queue=prompt_response_queue,
subscriber=module + "-prompt",
)
if self.verbose:
@ -229,51 +229,19 @@ class GraphRag:
return sg2
def get_cypher(self, query):
sg = self.get_labelgraph(query)
sg2 = []
for s, p, o in sg:
sg2.append(f"({s})-[{p}]->({o})")
kg = "\n".join(sg2)
kg = kg.replace("\\", "-")
if self.verbose:
print(kg)
return kg
def get_graph_prompt(self, query):
kg = self.get_cypher(query)
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here's the knowledge statements:
{kg}
Use only the provided knowledge statements to respond to the following:
{query}
"""
return prompt
def query(self, query):
if self.verbose:
print("Construct prompt...", flush=True)
prompt = self.get_graph_prompt(query)
kg = self.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(prompt)
print(kg)
print(query)
resp = self.llm.request(prompt)
resp = self.lang.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)

View file

@ -6,8 +6,8 @@ Input is query, output is response.
from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue, graph_rag_response_queue
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import graph_embeddings_request_queue
@ -34,11 +34,11 @@ class Processor(ConsumerProducer):
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 3000)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
@ -69,8 +69,8 @@ class Processor(ConsumerProducer):
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"graph_embeddings_request_queue": ge_request_queue,
@ -82,8 +82,8 @@ class Processor(ConsumerProducer):
self.rag = GraphRag(
pulsar_host=self.pulsar_host,
completion_request_queue=tc_request_queue,
completion_response_queue=tc_response_queue,
pr_request_queue=pr_request_queue,
pr_response_queue=pr_response_queue,
emb_request_queue=emb_request_queue,
emb_response_queue=emb_response_queue,
ge_request_queue=ge_request_queue,
@ -157,15 +157,15 @@ class Processor(ConsumerProducer):
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(