mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-05 19:32:11 +02:00
Plumbed into generic
This commit is contained in:
parent
a143e4ea4d
commit
428a869a93
11 changed files with 37 additions and 68 deletions
|
|
@ -8,8 +8,8 @@ query = " ".join(sys.argv[1:])
|
|||
gr = GraphRag(
|
||||
verbose=True,
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
completion_request_queue="non-persistent://tg/request/text-completion-rag",
|
||||
completion_response_queue="non-persistent://tg/response/text-completion-rag-response",
|
||||
pr_request_queue="non-persistent://tg/request/prompt",
|
||||
pr_response_queue="non-persistent://tg/response/prompt-response",
|
||||
)
|
||||
|
||||
if query == "":
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -69,15 +69,16 @@ setuptools.setup(
|
|||
"scripts/load-triples",
|
||||
"scripts/loader",
|
||||
"scripts/pdf-decoder",
|
||||
"scripts/prompt-generic",
|
||||
"scripts/query",
|
||||
"scripts/run-processing",
|
||||
"scripts/text-completion-azure",
|
||||
"scripts/text-completion-bedrock",
|
||||
"scripts/text-completion-claude",
|
||||
"scripts/text-completion-cohere",
|
||||
"scripts/text-completion-ollama",
|
||||
"scripts/text-completion-openai",
|
||||
"scripts/text-completion-vertexai",
|
||||
"scripts/text-completion-cohere",
|
||||
"scripts/triples-dump-parquet",
|
||||
"scripts/triples-query-cassandra",
|
||||
"scripts/triples-write-cassandra",
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
|
||||
from . graph_embeddings_client import GraphEmbeddingsClient
|
||||
from . triples_query_client import TriplesQueryClient
|
||||
from . llm_client import LlmClient
|
||||
from . embeddings_client import EmbeddingsClient
|
||||
from . prompt_client import PromptClient
|
||||
|
||||
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import text_completion_request_queue
|
||||
from . schema import text_completion_response_queue
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import graph_embeddings_request_queue
|
||||
|
|
@ -23,8 +23,8 @@ class GraphRag:
|
|||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
completion_request_queue=None,
|
||||
completion_response_queue=None,
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
ge_request_queue=None,
|
||||
|
|
@ -40,11 +40,11 @@ class GraphRag:
|
|||
|
||||
self.verbose=verbose
|
||||
|
||||
if completion_request_queue is None:
|
||||
completion_request_queue = text_completion_request_queue
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if completion_response_queue is None:
|
||||
completion_response_queue = text_completion_response_queue
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
|
@ -94,11 +94,11 @@ class GraphRag:
|
|||
|
||||
self.label_cache = {}
|
||||
|
||||
self.llm = LlmClient(
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=completion_request_queue,
|
||||
output_queue=completion_response_queue,
|
||||
subscriber=module + "-llm",
|
||||
input_queue=prompt_request_queue,
|
||||
output_queue=prompt_response_queue,
|
||||
subscriber=module + "-prompt",
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -229,51 +229,19 @@ class GraphRag:
|
|||
|
||||
return sg2
|
||||
|
||||
def get_cypher(self, query):
|
||||
|
||||
sg = self.get_labelgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for s, p, o in sg:
|
||||
|
||||
sg2.append(f"({s})-[{p}]->({o})")
|
||||
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
|
||||
if self.verbose:
|
||||
print(kg)
|
||||
|
||||
return kg
|
||||
|
||||
def get_graph_prompt(self, query):
|
||||
|
||||
kg = self.get_cypher(query)
|
||||
|
||||
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here's the knowledge statements:
|
||||
{kg}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def query(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
prompt = self.get_graph_prompt(query)
|
||||
kg = self.get_labelgraph(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(prompt)
|
||||
print(kg)
|
||||
print(query)
|
||||
|
||||
resp = self.llm.request(prompt)
|
||||
resp = self.lang.request_kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ Input is query, output is response.
|
|||
|
||||
from ... schema import GraphRagQuery, GraphRagResponse
|
||||
from ... schema import graph_rag_request_queue, graph_rag_response_queue
|
||||
from ... schema import text_completion_request_queue
|
||||
from ... schema import text_completion_response_queue
|
||||
from ... schema import prompt_request_queue
|
||||
from ... schema import prompt_response_queue
|
||||
from ... schema import embeddings_request_queue
|
||||
from ... schema import embeddings_response_queue
|
||||
from ... schema import graph_embeddings_request_queue
|
||||
|
|
@ -34,11 +34,11 @@ class Processor(ConsumerProducer):
|
|||
entity_limit = params.get("entity_limit", 50)
|
||||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 3000)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
|
|
@ -69,8 +69,8 @@ class Processor(ConsumerProducer):
|
|||
"entity_limit": entity_limit,
|
||||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"graph_embeddings_request_queue": ge_request_queue,
|
||||
|
|
@ -82,8 +82,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
self.rag = GraphRag(
|
||||
pulsar_host=self.pulsar_host,
|
||||
completion_request_queue=tc_request_queue,
|
||||
completion_response_queue=tc_response_queue,
|
||||
pr_request_queue=pr_request_queue,
|
||||
pr_response_queue=pr_response_queue,
|
||||
emb_request_queue=emb_request_queue,
|
||||
emb_response_queue=emb_response_queue,
|
||||
ge_request_queue=ge_request_queue,
|
||||
|
|
@ -157,15 +157,15 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue